How to fine-tune for transfer learning
Off with his head!
fastai has specific applications at the top layer: computer vision, natural language processing, and tabular. We've already covered the architectures that we can use to train such models, but we haven't explored what fastai does in the application APIs that allow us to use these models, either to train them from scratch or to fine-tune them.
All deep learning models have a body and a head. The body is where the majority of the learning occurs: it takes the input and outputs activations. These activations are passed to the head, where the decision making happens for the task the model is specifically trained on. So, when we do transfer learning, we cut the head off the pretrained model and give it a new head. Then, we train the model using discriminative learning rates: different learning rates for the body and the head (and for early and later epochs).
With computer vision, we either use `cnn_learner` for classification, or `unet_learner` for generative vision models.
In `cnn_learner`, we pass the architecture we want to use for the body of the network. When we pass a pretrained network, fastai downloads the pretrained weights and prepares the model for transfer learning.
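As a quick, hedged sketch of how this looks in practice (this is essentially the standard cats-vs-dogs example; the dataset and the `is_cat` labelling rule are just for illustration):

```python
from fastai.vision.all import *

# Download the Pets dataset; in it, cat images have capitalized filenames.
path = untar_data(URLs.PETS)/'images'

def is_cat(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=is_cat, item_tfms=Resize(224))

# cnn_learner downloads resnet50's pretrained weights, cuts off its head,
# and attaches a new head sized for our two classes.
learn = cnn_learner(dls, resnet50, metrics=accuracy)
learn.fine_tune(3)
```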
First, fastai cuts the head of the network; with resnet, we cut off everything from the adaptive average pooling layer onwards. However, we can't just search for that layer. Instead, fastai has a `model_meta` dictionary that stores, for each architecture, the index to cut at, the function used to split the model into parameter groups (for discriminative learning rates), and the stats needed for normalization. For instance, the `model_meta` entry for resnet50 is:

```python
model_meta[resnet50]
```
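For reference, here's roughly what that lookup gives us (the exact contents can differ between fastai versions):

```python
from fastai.vision.all import *

meta = model_meta[resnet50]
print(meta['cut'])    # -2: cut everything from the adaptive average pooling layer onwards
print(meta['stats'])  # ImageNet normalization statistics (per-channel means and standard deviations)
print(meta['split'])  # function that splits the model into parameter groups for discriminative learning rates
```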
So, for a resnet50 architecture, we keep all the layers prior to the cut point of `-2` to get the body of the model that we can use for transfer learning. The head, which is specialized for ImageNet classification, is replaced with a new head, which we can make using `create_head`:
```python
nf, n_out = 20, 2
create_head(nf, n_out)
```
With `create_head`, we have to specify how many in-channels and how many out-channels we need for our last layer. Optionally, we can change how many additional linear layers to add (`lin_ftrs`), how much dropout to use after each one (`ps`), whether to use batch normalization (`first_bn` and `bn_final`), and what kind of pooling to use (`pool` and `concat_pool`).
By default, fastai uses `AdaptiveConcatPool2d`, which applies both average pooling and max pooling and concatenates the results.
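Here's a tiny sketch of what that pooling layer does to a batch of activations (the shapes are illustrative):

```python
import torch
from fastai.layers import AdaptiveConcatPool2d

pool = AdaptiveConcatPool2d(1)   # adaptive max pooling and adaptive average pooling, concatenated
x = torch.randn(8, 512, 7, 7)    # e.g. activations coming out of a resnet body
print(pool(x).shape)             # torch.Size([8, 1024, 1, 1]): the channel count doubles
```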
Additionally, fastai adds two linear layers, since having more than one linear layer allows transfer learning to be used more quickly and easily when transferring a pretrained model to a very different domain; one linear layer is unlikely to be enough.
To get the new body, we use the `create_body` function:
```python
arch, cut = resnet50, model_meta[resnet50]['cut']
create_body(arch, cut=cut)
```
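Putting the two together, here's a rough sketch of what `cnn_learner` builds internally (the exact assembly, and details like how `create_head` accounts for the doubled channels from concat pooling, can vary between fastai versions):

```python
from fastai.vision.all import *
import torch
import torch.nn as nn

body = create_body(resnet50, pretrained=True, cut=model_meta[resnet50]['cut'])
head = create_head(2048, n_out=2)   # resnet50's body outputs 2048 feature maps
model = nn.Sequential(body, head)

x = torch.randn(1, 3, 224, 224)
print(model(x).shape)               # expected: torch.Size([1, 2])
```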
Before covering `unet_learner`, let's talk about generative vision models.
Generative vision models are different from regular classification models in that we're trying to predict an image, not labels.
Some generative vision tasks include segmentation, where you predict an image in which each pixel is given a label; super-resolution, where you increase the resolution of an image; colorization, where you add colour to a greyscale image; and style transfer, where you convert an image to a different style, like from a photo to a painting.
So, here comes the unet, which gets its name from its shape: a U. `unet_learner` takes the body of a desired architecture, like resnet, and attaches a new head, which performs the generative task.
How would we create the new head? One way, called nearest neighbour interpolation, is to take each pixel and replace it with four new pixels of the same value. These nearest neighbour interpolation layers are then interspersed with stride-1 convolutional layers. In a way, you can think of it as upscaling the image (nearest neighbour interpolation) and letting the model learn how to refine the upscaled image (stride-1 convolutional layers).
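A minimal sketch of that idea in plain PyTorch (the layer sizes are illustrative):

```python
import torch
import torch.nn as nn

# Nearest-neighbour upsampling followed by a stride-1 convolution that learns
# to refine the upscaled activations.
upsample_block = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),              # each pixel becomes a 2x2 block
    nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
)

x = torch.randn(1, 256, 7, 7)
print(upsample_block(x).shape)   # torch.Size([1, 128, 14, 14])
```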
Another approach is called transposed convolution, where instead of downscaling with strides, we upscale by inserting zeros between the pixels of the input before convolving; hence, this approach is also called stride-half convolution. To implement transposed convolution in fastai, you can pass `transpose=True` to `ConvLayer`.
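Here's a small sketch comparing plain PyTorch's transposed convolution with fastai's `ConvLayer(transpose=True)` (the channel counts are illustrative):

```python
import torch
import torch.nn as nn
from fastai.layers import ConvLayer

x = torch.randn(1, 256, 7, 7)

# Plain PyTorch: a 2x2 transposed convolution with stride 2 doubles the spatial size.
deconv = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
print(deconv(x).shape)           # torch.Size([1, 128, 14, 14])

# fastai equivalent (adds batchnorm and an activation by default).
fastai_deconv = ConvLayer(256, 128, ks=2, stride=2, transpose=True)
print(fastai_deconv(x).shape)    # torch.Size([1, 128, 14, 14])
```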
However, neither of these methods works really well on its own for training a model, though they do show how we can upscale an image. Why don't they work well? Because we're trying to upscale from the really small image that comes out of the body.
Like resnet, the unet uses skip connections: they connect the activations in the body of the resnet to the activations of the transposed convolutions in the new head of the architecture. The architecture diagram from the original U-Net paper gives an idea of what this looks like, although it uses a plain CNN rather than a resnet, since the U-Net paper predates resnet. In that diagram, there are 2 $\times$ 2 max pooling layers ("max pool 2x2"; red arrows) instead of stride-2 convolutions; transposed convolutions ("up-conv"; green arrows); and skip ("cross") connections ("copy and crop"; grey arrows).
Through the skip connections, the input to the transposed convolutions isn't just the lower-resolution activations from the previous layer, but also the higher-resolution activations from the opposite side of the U (the body).
The only downside with unets, like normal CNNs, is that they're dependent on the image size. So, `unet_learner` uses a `DynamicUnet` class that automatically generates an architecture of the right size based on the given data.
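As a sketch of how this looks in practice (assuming `dls` is a segmentation `DataLoaders` built elsewhere, for example with `SegmentationDataLoaders`):

```python
from fastai.vision.all import *

# `dls` is assumed to be a segmentation DataLoaders built elsewhere.
learn = unet_learner(dls, resnet34)   # builds a DynamicUnet sized to match the data
learn.fine_tune(5)
```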
We apply the same body-and-head idea to NLP. Here, we have a pretrained AWD-LSTM language model that we want to use for classification. There's no `model_meta` entry for NLP since we mainly use AWD-LSTM. Instead, we just take the stacked RNN encoder from the language model, which is a single PyTorch module (the body). We remove the head that takes the activations from the body and maps them to a word in the vocab. Ultimately, we're left with a model that gives an activation for each word in a given sequence.
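In fastai, reusing that encoder for classification is handled by `text_classifier_learner`. A minimal sketch, assuming `dls_clas` is a text-classification `DataLoaders` and `'finetuned_lm'` is the (hypothetical) name of an encoder saved after fine-tuning the language model:

```python
from fastai.text.all import *

learn = text_classifier_learner(dls_clas, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
learn.load_encoder('finetuned_lm')   # reuse the fine-tuned AWD-LSTM encoder (the body)
learn.fit_one_cycle(1, 2e-2)
```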
To fine-tune this model for text classification, we use BPTT for Text Classification (BPT3C):
Each time we call `forward` in our `Module` subclass, we get a document `x` that's divided into fixed-length batches of size `b` (`n` $\times$ `b`). We have a `for` loop over each batch. At the beginning of each batch, the model is initialized with the final state of the previous batch, and the activations of each batch are stored for average and max concatenated pooling. Then, gradients are back-propagated to the batches whose hidden states contributed to the final prediction (but in practice, we use variable-length backpropagation for truncated BPTT to avoid GPU memory overload and exploding gradients).
Similar to computer vision, we add linear layers to the head of the model for classification instead of predicting the next word. We even apply average and max concatenated pooling, except we pool over RNN sequences instead of CNN grid cells.
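Conceptually, that pooling over the RNN outputs looks something like this (not fastai's exact code, just a sketch of the idea):

```python
import torch

# Stand-in for the encoder activations: (batch size, sequence length, hidden size)
outs = torch.randn(16, 72, 400)

last = outs[:, -1]                  # final hidden state
avg  = outs.mean(dim=1)             # average pool over the sequence
mx   = outs.max(dim=1).values       # max pool over the sequence
pooled = torch.cat([last, mx, avg], dim=1)   # shape (16, 1200), passed to the classifier head
```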
What fastai does in the `DataLoaders` for BPT3C is ensure each sequence in `x` is of size `b` by padding it with a special token called `xxpad`. To be efficient, the texts are sorted so that texts of similar lengths end up in the same batch, which minimizes the number of `xxpad` tokens needed.
Tabular domains are kind of special in that we can't really apply transfer learning. Instead, for tabular data sets (and collaborative filtering using deep learning), we use fastai's `TabularModel`. In its `forward` function, we have:
```python
def forward(self, x_cat, x_cont=None):
    if self.n_emb != 0:
        x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
        x = torch.cat(x, 1)
        x = self.emb_drop(x)
    if self.n_cont != 0:
        if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)
        x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
    return self.layers(x)
```
First, we check whether there are embeddings for the categorical variables:
```python
if self.n_emb != 0:
```
If there are, we look up the activations from each embedding, concatenate them into a single tensor, and apply dropout:
```python
x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
x = torch.cat(x, 1)
x = self.emb_drop(x)
```
Similarly, for the continuous variables, we check whether there are any:
```python
if self.n_cont != 0:
```
If there are, we apply batchnorm (if enabled) and then concatenate the activations for the categorical and continuous inputs:
```python
if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)
x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
```
Finally, we pass these activations into the layers of the model (batchnorm, dropout, and linear layers):
```python
return self.layers(x)
```
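To tie this together, here's a minimal sketch of training a tabular model with fastai (the DataFrame `df` and its column names are hypothetical):

```python
from fastai.tabular.all import *

dls = TabularDataLoaders.from_df(
    df, procs=[Categorify, FillMissing, Normalize],
    cat_names=['job', 'education'], cont_names=['age', 'hours_per_week'],
    y_names='income', bs=64)

learn = tabular_learner(dls, layers=[200, 100], metrics=accuracy)   # builds a TabularModel under the hood
learn.fit_one_cycle(3)
```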
Now that we've covered most of the theory in the fastai course, we move on to practical deep learning. In theory, you have unlimited data, memory, and time; in that case, you'd train a huge model on all the data for a long time and get the ultimate model. However, we're limited in practice, so we have to find ways to get more data and make our models more efficient and more effective.
So, we first need to get our model to overfit, since that means we're reaching the limit of our model given our current data, memory, and time.
To improve our training, we have to:
- Get, or create, more data: we can simply get more data, but sometimes we can't; in that case, we can add more labels to our existing data, which creates additional tasks for our model to solve.
- Create more data through data augmentation: maybe creating more labels isn't enough; then, we can create additional synthetic data through more or different data augmentation techniques. With computer vision, Mixup tends to work very well (see the sketch after this list).
- More generalizable architecture: we've gotten as much data as we can and we're taking advantage of all the labels we can use, but we're still overfitting. Now we can actually start making changes to the model itself, beginning with ways to make the architecture more generalizable. The most basic step is to add batch normalization.
- More regularization: if more generalizability still wasn't enough, we try regularization. We can add dropout to the last layer or two (or more, as in AWD-LSTM). In general, a larger model with more regularization tends to be more flexible and more accurate than a smaller model with less regularization.
- Simpler architectures: we leave this step for last; if more data, more labels, more generalizability, and regularization all didn't help with the overfitting, maybe our model is too complicated for our task. As a last resort, we move to a smaller version of our chosen architecture, or even a simpler architecture altogether.
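As referenced in the data augmentation step above, here's a minimal sketch of turning on Mixup in fastai (assuming `dls` is an image `DataLoaders` built elsewhere):

```python
from fastai.vision.all import *

# MixUp blends random pairs of images and their labels during training.
learn = cnn_learner(dls, resnet34, metrics=accuracy, cbs=MixUp(0.4))
learn.fit_one_cycle(5)
```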
Overall, we don't want to start at step 5 (simpler architectures) and work our way up (unless the current architecture is taking too much time or memory); we want to begin at the top and make our way down, because reducing the size of the model reduces its capability to learn subtle relationships in the data.
With computer vision and NLP, we often use transfer learning instead of training from scratch, since we'll usually get better results while using less data and spending less time, money, and effort getting started. To fine-tune a pretrained model, we have to cut the head off and add a new head to the pretrained body (the body is called the encoder for language models). Often, the new head is two linear layers with average and max concatenated pooling, dropout, and batchnorm mixed in.
For NLP, we also have to apply truncated BPTT before the new head, since we're no longer predicting the next word but classifying the text. Truncated BPTT gives us activations for each batch, where each batch's activations retain something from the preceding batches. These activations are then passed to the new head.
With tabular data, we can't really apply transfer learning since the tasks tend to be very different from one data set to another. So, we just covered how fastai's `TabularModel` works by going over its `forward` function. Overall, we prepare the activations for the categorical and continuous variables before passing them into the layers of the model.
Lastly, we went over how to train deep learning models in practice. We want to get into a state where we overfit with our model before we try anything else. Then, we follow a procedure of getting more data, applying data augmentation, adding more generalizability, implementing regularization, and finally, reducing architecture complexity. Ideally, we only move to smaller models when we run out of time or memory.
In the next blog, I'll remake the Siamese pair model for pet breeds we did here. Instead of having two passes to the model and comparing output labels, we'll have a single model that takes in two images and tells us if they're of the same breed. We'll review fastai's mid-level API and fine-tune the resnet architecture using the method we discussed here. We'll also have to go over how the Learner's splitter works since we optimize the head and body differently for the first few epochs.