Voluntary achromatopsy and its cure in deep networks: convert a CNN from color to grayscale (and vice-versa)
Fri. 03 June 2022
Translations: fr
Converting a convolutional network between color spaces
Convolutional neural networks are now the first-class citizens of the computer vision world. The most popular baseline for image understanding tasks now consists of fine-tuning off-the-shelf deep networks, pretrained on large datasets (e.g. ImageNet), on your own specialized labeled data.
However, there can be a domain gap between the images used for pretraining and yours. Off-the-shelf models only go so far, and your images do not always have the same dimensions as those from ImageNet. For example, most CNNs are pretrained on ImageNet with 224x224x3 images, i.e. square RGB images 224 pixels wide. Convolutional networks can easily deal with varying spatial dimensions 1, but not really with a varying number of channels.
So, how do you deal with that? More precisely, how do you fine-tune a model trained on RGB images when your images do not have the same number of channels? It is not uncommon to want to fine-tune a model on grayscale images. Medical imaging, for example, is a domain where this happens a lot, and various publications mention this problem 4,5. Tajbakhsh (2016) (with more than 1000 citations!) explains that:
Because the AlexNet architecture used in our study required color patches as its input, each extracted gray-scale patch was converted to a color patch by repeating the gray channel thrice.
I'm sorry, what? This solution (replicating the grayscale image three times to shoehorn it into the RGB format) is extremely common. But despite what most people seem to think, it is really unnecessary. In fact, to fine-tune your ImageNet-pretrained CNN on a grayscale dataset, you need neither
- to retrain your deep network from scratch on a grayscale version of ImageNet,
- nor to stack your grayscale images on 3 channels to emulate an RGB tensor.
And the cherry on top? This is straightforward to achieve and faster! So, let's look into a little-known property of convolutional layers.
Converting from color to grayscale
You have a dimension problem. Your input image has a shape \((C, W, H)\) with \(C\) the number of channels, \(W\) its width and \(H\) its height. Therefore, your first convolutional layer is parametrized by a tensor of shape \((C, C', k_w, k_h)\), where \(C'\) is the number of filters of this layer and \((k_w, k_h)\) is the size of each kernel. The next layers only deal with the \(C'\) feature maps, so we only need to change the first layer so that it accepts an input of shape \((1, W, H)\).
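To make these shapes concrete, here is a minimal PyTorch sketch. It assumes an ImageNet-style first layer (64 filters of size 7x7, stride 2, padding 3, as in the usual ResNet stem). Keep in mind that PyTorch orders things differently from the notation above: inputs are (N, C, H, W) and the weight tensor is stored as (C', C, k_h, k_w), so the input-channel axis of the weights is dim=1.

```python
import torch
from torch import nn

# Assumed ImageNet-style stem: 3 input channels, 64 filters of size 7x7
conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7,
                  stride=2, padding=3, bias=False)
print(conv1.weight.shape)   # torch.Size([64, 3, 7, 7]), i.e. (C', C, k_h, k_w)

rgb = torch.randn(1, 3, 224, 224)   # (N, C, H, W)
print(conv1(rgb).shape)             # torch.Size([1, 64, 112, 112])

gray = torch.randn(1, 1, 224, 224)
try:
    conv1(gray)                     # the layer expects 3 channels, gets 1
except RuntimeError as err:
    print("RuntimeError:", err)
```

Only the channel count is the problem here: the spatial size could change freely, but the number of input channels is baked into the weight tensor.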
Convolutions: RGB to grayscale
Feeding a grayscale image into an RGB model is the most frequent use case, judging by the number of questions on various forums (StackOverflow 10, Data Science Stack Exchange 7,8, the MathWorks forum 9, GitHub 6, ResearchGate 11...). Worryingly, only one user (Ross Wightman) has given the correct answer to this problem, and on StackOverflow it is not even the accepted one (!). Moreover, Ross does not give the mathematical justification.
Let's consider a convolutional layer that takes an image with \(p\) input channels and outputs \(q\) feature maps. This layer has \(p \times q\) filters (or kernels) \(w_{i,j}\). Let \(b_j\) denote the bias of the \(j\)-th output feature map, \(x_i\) the input activations and \(z_j\) the output activations. By definition, the convolutional layer follows the equation:
\[ z_j = \sum_{i=1}^{p} w_{i,j} \star x_i + b_j, \qquad \forall j \in [ 1, q ] \]
where \(\star\) denotes either cross-correlation or convolution (there is no difference for our purpose, we only need distributivity). With an RGB image, we have \(p = 3\) and the \(x_i\) represent the red, green and blue channels.
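This equation can be checked numerically. The sketch below (random tensors, sizes chosen arbitrarily for illustration) compares PyTorch's multi-channel `F.conv2d`, which actually computes a cross-correlation, with an explicit sum of single-channel convolutions plus the bias. Double precision is used so that the comparison is not blurred by float32 rounding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p, q, k = 3, 8, 5                                   # input channels, output maps, kernel size
x = torch.randn(1, p, 32, 32, dtype=torch.float64)  # input activations x_i
w = torch.randn(q, p, k, k, dtype=torch.float64)    # kernels w_{i,j}, stored as (q, p, k, k)
b = torch.randn(q, dtype=torch.float64)             # one bias b_j per output map

# Full multi-channel layer
z = F.conv2d(x, w, b)

# Same thing, written as z_j = sum_i w_{i,j} * x_i + b_j
z_sum = sum(F.conv2d(x[:, i:i + 1], w[:, i:i + 1]) for i in range(p))
z_sum = z_sum + b.view(1, q, 1, 1)                  # bias broadcast over each map

print(torch.allclose(z, z_sum))  # True
```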
If all channels are identical, then all \(x_i\) are equal: \(\forall i \in [ 1, p ] ~~ x_i = x^*\).
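Then the convolutional layer equation can be rewritten as:
\[ z_j = \sum_{i=1}^{p} w_{i,j} \star x^* + b_j = \Big( \sum_{i=1}^{p} w_{i,j} \Big) \star x^* + b_j = w^*_j \star x^* + b_j, \qquad \text{with } w^*_j = \sum_{i=1}^{p} w_{i,j}. \]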
So, if you duplicate your grayscale images three times to emulate an RGB image, the convolution on 3 channels is actually identical to a convolution on 1 channel using a filter \(w^*_j\) corresponding to the sum of the kernels \(w_{i,j}\) over \(i\). The bias vector remains the same.
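A quick numerical sanity check of this identity, with random kernels and an arbitrary 64x64 grayscale image (the sizes, stride and padding are illustrative assumptions): convolving the three stacked copies with the RGB kernels gives the same activations as convolving the single channel with the summed kernels \(w^*_j\) and the same bias.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gray = torch.randn(1, 1, 64, 64, dtype=torch.float64)  # grayscale image x*
w = torch.randn(16, 3, 7, 7, dtype=torch.float64)      # RGB kernels w_{i,j}
b = torch.randn(16, dtype=torch.float64)               # biases b_j

# Option 1: replicate the grayscale channel three times, keep the RGB kernels
z_stacked = F.conv2d(gray.repeat(1, 3, 1, 1), w, b, stride=2, padding=3)

# Option 2: one-channel convolution with w*_j = sum_i w_{i,j}, same bias
w_star = w.sum(dim=1, keepdim=True)                     # shape (16, 1, 7, 7)
z_gray = F.conv2d(gray, w_star, b, stride=2, padding=3)

print(torch.allclose(z_stacked, z_gray))  # True
```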
See? This is an easy modification. In PyTorch, you only need a few lines of code to achieve this transformation:
import torch
from torch import nn

def net_rgb2gray(net, layer_name='conv1'):
    # Convert the first conv layer of an RGB network so that it accepts 1-channel images
    layer = getattr(net, layer_name)
    use_bias = layer.bias is not None
    gray_layer = nn.Conv2d(in_channels=1, out_channels=layer.out_channels,
                           kernel_size=layer.kernel_size, stride=layer.stride,
                           padding=layer.padding, dilation=layer.dilation,
                           bias=use_bias)
    with torch.no_grad():
        # w*_j = sum_i w_{i,j}: sum the kernels over the input channels (dim=1 in PyTorch)
        gray_layer.weight.copy_(layer.weight.sum(dim=1, keepdim=True))
        if use_bias:
            # The bias is left unchanged
            gray_layer.bias.copy_(layer.bias)
    setattr(net, layer_name, gray_layer)
    return net
Grayscale to color
Conversely, let us consider a linear transformation \(\mathcal{T}\) that maps a color image to a grayscale image, such that:
\[ x^* = \mathcal{T}(x_1, \dots, x_p) = \sum_{i=1}^{p} \lambda_i x_i \]
For example, the ITU-R BT.709 standard (used for HDTV and sRGB) computes luminance with \(\lambda_R = 0.2126\), \(\lambda_G = 0.7152\) and \(\lambda_B = 0.0722\), to account for the varying sensitivity of the human eye to different colors. Such linear conversions are standard and implemented in most image processing libraries, such as PIL and scikit-image, although the exact coefficients vary from one library to another.
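As an illustration, the transformation \(\mathcal{T}\) is just a weighted sum over the channel axis. The helper `to_gray` below is hypothetical (it is not a library function) and assumes batches of shape (N, p, H, W), with the luminance coefficients quoted above as defaults. Since these coefficients sum to 1, a white image stays white.

```python
import torch

# Luminance coefficients (lambda_R, lambda_G, lambda_B) from ITU-R BT.709
LAMBDAS_709 = (0.2126, 0.7152, 0.0722)

def to_gray(images, lambdas=LAMBDAS_709):
    # Hypothetical helper: x* = sum_i lambda_i x_i for a batch of shape (N, p, H, W)
    lam = torch.as_tensor(lambdas, dtype=images.dtype, device=images.device)
    lam = lam.view(1, -1, 1, 1)                    # broadcast over N, H and W
    return (lam * images).sum(dim=1, keepdim=True)  # shape (N, 1, H, W)

rgb = torch.rand(2, 3, 4, 4)
print(to_gray(rgb).shape)               # torch.Size([2, 1, 4, 4])
print(to_gray(torch.ones(1, 3, 1, 1)))  # close to 1: white stays white
```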
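Consider a deep net trained on grayscale images \(\mathcal{I}^*\) obtained by applying the transformation \(\mathcal{T}\) to color images \(\mathcal{I}\). Using the linearity of \(\star\), the activations \(z_j\) from the first layer are obtained through the equation:
\[ z_j = w^*_j \star x^* + b_j = w^*_j \star \Big( \sum_{i=1}^{p} \lambda_i x_i \Big) + b_j = \sum_{i=1}^{p} (\lambda_i \, w^*_j) \star x_i + b_j \]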
By identification, we can now build a new model that takes an image with \(p\) channels as input (typically, \(p=3\)). We only need to replace the original convolution with a new layer initialized with \(w_{i,j} = \lambda_i \cdot w^*_j\), keeping the same bias \(b_j\). The following code snippet illustrates how to achieve this in a few lines of PyTorch:
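import torch
from torch import nn

def net_gray2rgb(net, lambdas, layer_name='conv1'):
    # Convert the first conv layer of a grayscale network so that it accepts p-channel images
    layer = getattr(net, layer_name)
    p = len(lambdas)
    # Reshape the lambdas to (1, p, 1, 1) so they broadcast over the input-channel axis
    lambdas = torch.as_tensor(lambdas, dtype=layer.weight.dtype,
                              device=layer.weight.device).view(1, p, 1, 1)
    use_bias = layer.bias is not None
    rgb_layer = nn.Conv2d(in_channels=p, out_channels=layer.out_channels,
                          kernel_size=layer.kernel_size, stride=layer.stride,
                          padding=layer.padding, dilation=layer.dilation,
                          bias=use_bias)
    with torch.no_grad():
        # w_{i,j} = lambda_i * w*_j
        rgb_layer.weight.copy_(lambdas * layer.weight.repeat(1, p, 1, 1))
        if use_bias:
            rgb_layer.bias.copy_(layer.bias)
    setattr(net, layer_name, rgb_layer)
    return net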
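The identity behind this conversion can be checked without building a full network. This sketch (random kernels and image, arbitrary sizes) compares the grayscale pipeline, i.e. converting with \(\mathcal{T}\) and then convolving with \(w^*_j\), against a single convolution of the color image with \(w_{i,j} = \lambda_i \cdot w^*_j\). Note that the coefficients have to be reshaped to (1, p, 1, 1) to broadcast over the input-channel axis of the weights; a flat vector of length 3 would instead broadcast over the last (kernel width) axis.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
lambdas = torch.tensor([0.2126, 0.7152, 0.0722], dtype=torch.float64)
rgb = torch.rand(1, 3, 64, 64, dtype=torch.float64)     # color image x_1, ..., x_p
w_star = torch.randn(16, 1, 7, 7, dtype=torch.float64)  # grayscale kernels w*_j
b = torch.randn(16, dtype=torch.float64)                # biases b_j

# Grayscale model: x* = sum_i lambda_i x_i, then convolve with w*_j
gray = (lambdas.view(1, 3, 1, 1) * rgb).sum(dim=1, keepdim=True)
z_gray = F.conv2d(gray, w_star, b, padding=3)

# Converted model: convolve the color image with w_{i,j} = lambda_i * w*_j
w_rgb = lambdas.view(1, 3, 1, 1) * w_star.repeat(1, 3, 1, 1)  # shape (16, 3, 7, 7)
z_rgb = F.conv2d(rgb, w_rgb, b, padding=3)

print(torch.allclose(z_gray, z_rgb))  # True
```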
Benefits
Why do this? Since the activations are the same, the models should achieve identical performance (up to floating-point rounding), whether you use this conversion or stack your grayscale image three times. But stacking is so inelegant! Moreover, there is a slight (theoretical) advantage to doing the proper conversion: (slightly) fewer parameters and a (very small) improvement in computation time.
Indeed, since you don't duplicate your channels, the first layer that had \((3, m, k, k)\) parameters now only has \((1, m, k, k)\). That is 3 times fewer parameters! For a standard ResNet (\(m = 64\), \(k = 7\)), this means 6272 fewer parameters! Ha! Well, sure, even for a small, reduced ResNet-20, this accounts for less than 1% of the total. But think of all the free memory!2
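The arithmetic behind the 6272 figure, assuming the usual ImageNet ResNet stem (64 filters of size 7x7 and no bias, as in torchvision's ResNets):

```python
from torch import nn

# Assumed ImageNet ResNet stem: 64 filters of size 7x7, no bias
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

n_rgb = rgb_conv.weight.numel()    # 3 * 64 * 7 * 7
n_gray = gray_conv.weight.numel()  # 1 * 64 * 7 * 7
print(n_rgb, n_gray, n_rgb - n_gray)  # 9408 3136 6272
```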
This also removes \(2/3\) of the convolutions from the first layer. Those convolutions tend to be the most expensive since a) they are applied on the full-resolution image, b) they use larger kernels (e.g. 7x7 instead of 3x3). This gives a small speedup both during training and evaluation. During inference, this results in a 1.5% reduction in latency for a ResNet-20. Huzzah!3
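To put a number on the saved computation, here is a back-of-the-envelope count of multiply-accumulate operations (MACs) for the first layer. It assumes the same ImageNet ResNet stem on a 224x224 input and ignores the bias; these are theoretical counts, not measured latencies, and the first layer's share of the total cost depends on the architecture.

```python
def conv_out_size(size, kernel, stride, padding):
    # Spatial output size of a convolution (no dilation)
    return (size + 2 * padding - kernel) // stride + 1

def conv_macs(c_in, c_out, kernel, in_size, stride, padding):
    # Multiply-accumulates of a square-kernel conv layer on a square input
    out = conv_out_size(in_size, kernel, stride, padding)
    return c_in * c_out * kernel * kernel * out * out

# Assumed stem: 7x7, stride 2, padding 3, so 224x224 -> 112x112
rgb_macs = conv_macs(3, 64, 7, 224, stride=2, padding=3)   # 118,013,952
gray_macs = conv_macs(1, 64, 7, 224, stride=2, padding=3)  # 39,337,984
print(rgb_macs, gray_macs, rgb_macs - gray_macs)
```

The grayscale layer does exactly one third of the work of the RGB one, which is where the \(2/3\) above comes from.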
So, overall, the main benefit is that this is the proper way to do it: it is more elegant to avoid redundant computations and useless parameters. Barely anything more. But sometimes, isn't that enough?
1. Unless you use fully connected layers, but you shouldn't. At least since ResNet, nearly all CNNs use global average pooling to deal with images of different sizes. In any case, resizing is always an option but yuk!
2. And for VGG-16, the number of parameters is heavily dominated (by several orders of magnitude) by the fully connected layers, so you won't see the difference. But, hey, I tried, okay.
3. I did not measure any significant improvement on GPU, likely because the convolutions are parallelized anyway.
4. Nima Tajbakhsh et al. "Convolutional Neural Networks for Medical Image Analysis: Full Training or Fine Tuning?". IEEE Transactions on Medical Imaging 35.5 (May 2016), pp. 1299-1312.
5. Brady Kieffer et al. "Convolutional neural networks for histopathology image classification: Training vs. Using pre-trained networks". 2017 Seventh International Conference on Image Processing Theory, Tools and Applications (IPTA), Nov. 2017, pp. 1-6.
7. https://datascience.stackexchange.com/questions/22684/is-it-possible-to-use-grayscale-images-to-existing-model
8. https://datascience.stackexchange.com/questions/12717/convert-filters-pre-trained-with-imagenet-to-grayscale
9. https://fr.mathworks.com/matlabcentral/answers/374950-how-can-i-make-alexnet-accept-277x277x1-images
10. https://stackoverflow.com/questions/51995977/how-can-i-use-a-pre-trained-neural-network-with-grayscale-images
11. https://www.researchgate.net/post/How-To-use-gray-level-images-Ultrasound-in-CNN-as-they-have-1-channel-and-CNN-use-3-channels-RGB
Category: science Tagged: cnn deep learning computer vision