In the evolving world of artificial intelligence, the impact of Vision Transformers (ViTs) on the field of computer vision has been both profound and transformative. Following up on our detailed exploration of the architecture of ViTs in the previous article, this article shifts focus to the practical aspects of training and fine-tuning these powerful models.
Before diving in, if you need help, guidance, or want to ask questions, join our Community and a member of the Marqo team will be there to help.
1. Recap of Vision Transformers
Before we dive into the specifics of training and fine-tuning, let's briefly recap the fundamental aspects of Vision Transformers. Vision Transformers represent a paradigm shift in how machines perceive images, moving away from conventional convolutional neural networks (CNNs) to an approach driven by the self-attention mechanisms originally developed for sequence processing in Natural Language Processing (NLP).
The core idea behind ViTs is to treat image patches as tokens—similar to words in text—allowing the model to learn contextual relationships between different parts of an image. Each image is split into fixed-size patches, linearly embedded, and then processed through multiple layers of transformer blocks that apply self-attention across the patches. This architecture enables ViTs to capture complex patterns and dependencies, offering a more flexible and potentially more powerful approach to image recognition than traditional methods.
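To make the patch-to-token idea concrete, here is a minimal sketch (in PyTorch, assuming a 224x224 input and 16x16 patches, as in the original ViT setup) of how an image is split into patches and linearly embedded:

```python
import torch
import torch.nn as nn

# A 224x224 image split into 16x16 patches gives (224/16)^2 = 196 patches.
image = torch.randn(1, 3, 224, 224)        # (batch, channels, height, width)
patch_size, embed_dim = 16, 768

# A strided convolution is the usual trick for "split into patches + linear embedding".
patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

tokens = patch_embed(image)                 # (1, 768, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)  # (1, 196, 768): one 768-d token per patch
print(tokens.shape)
```

Each of these 196 tokens is then processed by the transformer blocks, exactly as word embeddings would be in NLP.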
Vision Transformers also scale efficiently with model size and dataset size, often surpassing the performance of CNNs when trained on large-scale datasets. This scalability, combined with their ability to generalize from fewer data when pre-trained on large datasets, makes ViTs a compelling choice for a wide range of vision tasks.
Let’s now take a look at how we can fine-tune our own Vision Transformer for image classification!
2. Fine-Tuning Vision Transformers
For this article, we will be using Google Colab (it’s free!). If you are new to Google Colab, you can follow this guide on getting set up - it’s super easy! For this module, you can find the notebook on Google Colab here or on GitHub here. As always, if you face any issues, join our Slack Community and a member of our team will help!
Install and Import Relevant Libraries
As always, we need to install the relevant libraries:
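The exact versions may differ depending on your environment, but a typical install for this tutorial looks something like:

```bash
pip install transformers torch torchvision
```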
We will be utilising datasets provided by Hugging Face:
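Again, something like:

```bash
pip install datasets
```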
Amazing! We have installed the relevant modules needed to start fine-tuning.
Load a Dataset
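The notebook's exact dataset isn't reproduced here, but a cats-and-dogs dataset such as Hugging Face's `microsoft/cats_vs_dogs` matches the examples that follow, so we'll assume that one:

```python
from datasets import load_dataset

# Assuming the cats_vs_dogs dataset; swap in whichever dataset you are using.
dataset = load_dataset("microsoft/cats_vs_dogs")
print(dataset)
```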
This returns,
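For `cats_vs_dogs`, the printed structure looks roughly like this (the exact row count depends on the dataset version):

```
DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 23410
    })
})
```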
We can clearly see the features of the dataset:
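an `image` column holding the pictures themselves and a `labels` column holding the class. You can also inspect them directly:

```python
print(dataset["train"].features)
# {'image': Image(...), 'labels': ClassLabel(names=['cat', 'dog'], ...)}
```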
Cool, let’s look at the image!
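Indexing an entry returns a PIL image, which a Colab notebook will render for you. For example (assuming the first training example):

```python
image = dataset["train"][0]["image"]
image  # in a notebook, the PIL image is displayed inline
```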
This returns the image:
Despite the image being quite blurry, we can easily detect that this is an image of a cat. When we print out the class label for this image, it should return ‘cat’. Let’s look at how we can do that.
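Assuming the `labels` feature is a `ClassLabel` (as it is for `cats_vs_dogs`), we can read off its names:

```python
labels = dataset["train"].features["labels"]
print(labels.names)
```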
This returns,
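```
['cat', 'dog']
```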
So, the class label names are indeed ‘cat’ and ‘dog’. Let’s obtain the class label for the image above.
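Continuing with the same example entry:

```python
# Map the integer label of the first example back to its human-readable name.
label_id = dataset["train"][0]["labels"]
print(labels.names[label_id])
```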
Indeed, we get:
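```
cat
```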
Nice!
There are thousands of images of cats and dogs in this dataset, so let’s write a function to see a few more with their corresponding labels.
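A small helper along these lines does the job (a sketch using matplotlib; the plotting code in the notebook may differ):

```python
import random
import matplotlib.pyplot as plt

def show_random_images(ds, n=6):
    """Display n randomly chosen images from the dataset with their label names."""
    indices = random.sample(range(len(ds)), n)
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3))
    for ax, idx in zip(axes, indices):
        example = ds[idx]
        ax.imshow(example["image"])
        ax.set_title(ds.features["labels"].names[example["labels"]])
        ax.axis("off")
    plt.show()

show_random_images(dataset["train"])
```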
As expected, we have images of both dogs and cats. Note that because we’re selecting images at random, you won’t necessarily see the same images when executing the code yourself.
Awesome! Now that we've seen what our dataset looks like, it's time to process this data!
Preparing the Images - ViT Image Processor
We’ve seen what the images in this dataset look like, so we’re in a good position to start preparing them for our model!
When Vision Transformers are trained, the images fed into the model undergo specific transformations (resizing, rescaling, and normalisation). If you fine-tune with different transformations, your model won’t know what it’s looking at!
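The safest way to get the right transformations is to load the image processor that was saved with the pre-trained checkpoint. We’ll assume the `google/vit-base-patch16-224-in21k` checkpoint here:

```python
from transformers import ViTImageProcessor

model_name = "google/vit-base-patch16-224-in21k"
processor = ViTImageProcessor.from_pretrained(model_name)
print(processor)
```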
Here’s the output:
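For that checkpoint, it looks roughly like this. Note the 224x224 target size and the mean and standard deviation used for normalisation:

```
ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [0.5, 0.5, 0.5],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [0.5, 0.5, 0.5],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {"height": 224, "width": 224}
}
```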
We’ve now got the right transformations for our images. Let’s look at processing the dataset with them.
Processing the Dataset
We’ve now covered how you can read and transform images into numerical representations. Let’s combine both of these to process a single entry from the dataset.
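Putting the two together for the first training example (a sketch, using the processor loaded above):

```python
example = dataset["train"][0]

# Convert to RGB to guarantee three channels, then apply the processor.
inputs = processor(example["image"].convert("RGB"), return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```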
Awesome!
We want to do this for every entry in our dataset, but doing it up front can be slow, especially if you have a large dataset. Instead, we can attach a transform to the dataset that is only applied to entries when you index them.
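A batch transform along these lines (a sketch; it relies on the processor from earlier) is what we will attach to the dataset:

```python
def transform(example_batch):
    """Turn a batch of PIL images into the pixel_values tensors the model expects."""
    images = [img.convert("RGB") for img in example_batch["image"]]
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["labels"]
    return inputs
```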
Of course, the ideal situation is to have a dataset like beans that is already split into train, validation and test sets. However, for the purposes of this tutorial, we wanted to show you how you can fine-tune using a dataset that only contains 'train' data.
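One way to do this is to sample a subset of the 'train' data and carve a validation set out of it (the subset size and split fraction here are illustrative, not the notebook's exact values):

```python
# Take a random subset to keep training fast, then split off a validation set.
# No seed is fixed, so you'll get a different subset each time you run this.
subset = dataset["train"].shuffle().select(range(5000))
ds = subset.train_test_split(test_size=0.2)
ds["validation"] = ds.pop("test")  # rename the 'test' split to 'validation'
print(ds)
```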
Pretty cool. We now have a smaller dataset with train and validation splits.
We now apply the transform,
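using `with_transform` from the `datasets` library:

```python
prepared_ds = ds.with_transform(transform)
```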
This means that whenever you get an entry from the dataset, the transform will be applied in real time! Take the first two entries for example:
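```python
# Slicing triggers the transform on just these two entries.
batch = prepared_ds["train"][0:2]
print(batch["pixel_values"].shape)
```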
The output:
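```
torch.Size([2, 3, 224, 224])
```

Two images, each resized to 3x224x224 and ready for the model.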
Now the dataset is prepared, let's move onto the training and fine-tuning!
Training and Fine-Tuning
Data Collator
As we mentioned, batches of data arrive as lists of dicts, and we need to unpack these into the batched tensors the model expects.
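A minimal collate function (assuming the transform above, which gives each example a `pixel_values` tensor and an integer `labels` value) stacks everything into batched tensors:

```python
import torch

def collate_fn(batch):
    """Stack a list of transformed examples into one batch of tensors."""
    return {
        "pixel_values": torch.stack([example["pixel_values"] for example in batch]),
        "labels": torch.tensor([example["labels"] for example in batch]),
    }
```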
Evaluation Metric
We want to write a function that takes in the model's predictions and computes the accuracy.
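A simple sketch that computes accuracy directly with NumPy:

```python
import numpy as np

def compute_metrics(eval_pred):
    """Compute accuracy from the model's logits and the true labels."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": (predictions == labels).mean()}
```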
This function is used during the evaluation phase of model training to assess how well the model is performing.
Loading Our Model
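We load the pre-trained ViT with a fresh classification head sized for our two classes (a sketch, assuming the same `google/vit-base-patch16-224-in21k` checkpoint as before):

```python
from transformers import ViTForImageClassification

label_names = ds["train"].features["labels"].names  # ['cat', 'dog']

model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(label_names),
    id2label={i: name for i, name in enumerate(label_names)},
    label2id={name: i for i, name in enumerate(label_names)},
)
```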
Defining Training Arguments
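The exact hyperparameters in the notebook may differ; a representative set of arguments looks like this:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-cats-vs-dogs",
    per_device_train_batch_size=16,
    num_train_epochs=3,
    learning_rate=2e-4,
    evaluation_strategy="steps",
    eval_steps=10,
    logging_steps=10,
    save_steps=100,
    remove_unused_columns=False,   # keep the 'image' column for our lazy transform
    load_best_model_at_end=True,
    report_to="none",              # disable external experiment tracking
)
```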
Let’s break down each of these.
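For the illustrative arguments above:

- `output_dir`: where checkpoints and logs are written.
- `per_device_train_batch_size`: how many images each device processes per training step.
- `num_train_epochs` and `learning_rate`: how long and how aggressively the model is trained.
- `evaluation_strategy`, `eval_steps` and `logging_steps`: evaluate on the validation set, and log metrics, every 10 steps.
- `save_steps`: how often a checkpoint is saved.
- `remove_unused_columns=False`: keeps the raw 'image' column so the lazy transform can still see it.
- `load_best_model_at_end`: reloads the checkpoint with the best validation performance once training finishes.
- `report_to="none"`: turns off external experiment tracking such as Weights & Biases.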
Let’s Start Training!
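Everything is tied together in a `Trainer` (a sketch using the pieces defined above):

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=processor,
)
```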
Let’s break down the entries:
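- `model`: the pre-trained ViT with its new two-class classification head.
- `args`: the training arguments defined above.
- `data_collator`: our `collate_fn`, which stacks individual examples into batched tensors.
- `compute_metrics`: our accuracy function, called at every evaluation step.
- `train_dataset` and `eval_dataset`: the lazily transformed train and validation splits.
- `tokenizer`: the image processor, stored alongside the model so the same preprocessing is applied at inference time.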
Let’s Run the Fine-Tuning!
All that’s left to do is to run the fine-tuning.
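It's a single call, plus a save at the end:

```python
train_results = trainer.train()
trainer.save_model()  # writes the fine-tuned weights (and processor) to output_dir
```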
The output is a table of training loss, validation loss, and accuracy, evaluated every 10 steps.
Now, there are a few things we need to talk about here. Let's first discuss the results from the fine-tuning.
These fine-tuning results showcase several important trends and behaviors in the model's learning process over time:
- Training Loss: The training loss generally shows a declining trend as the steps increase, which is an encouraging sign of the model learning and improving from the training data. Notably, the training loss decreases substantially from the initial to the final step, with occasional upticks (such as at steps 30 and 50), which could be due to the model adjusting to complexities or nuances in the dataset.
- Validation Loss: The validation loss shows more variability compared to the training loss. It starts low, increases at certain points (notably at step 90, reaching the peak), and then decreases again. This pattern suggests that the model might be experiencing some challenges in generalizing to unseen data at certain training stages, particularly around step 90.
- Accuracy: The accuracy of the model on validation data starts very high at 99% at step 10 and fluctuates with a general decreasing trend up to step 90, where it drops to its lowest at 94.5%. However, it recovers well towards the end, returning to 99% at step 120. The high starting accuracy could suggest that the model was already quite effective even at early fine-tuning stages, possibly due to pre-training on a similar task or dataset.
- Potential Overfitting: At step 90, where the validation loss is at its highest and accuracy is at its lowest, the model is likely experiencing overfitting. This is indicated by a low training loss coupled with high validation loss and reduced accuracy.
- Overall Trends: The final steps (110 and 120) show an optimal balance with low validation loss and high accuracy, suggesting that the model has achieved a good generalization capability by the end of this fine-tuning phase. This is an encouraging sign that the fine-tuning process has successfully enhanced the model's performance on the validation dataset.
It's important to note that the subset of the dataset we fine-tune on is sampled randomly, so the results you get from running this Google Colab script will not be the same every time.
We selected this dataset because it illustrates several trends you should be aware of when fine-tuning:
- High Variance in Validation Metrics: If you observe significant fluctuations in validation loss or accuracy, as compared to more stable or consistently improving training metrics, it might indicate that the model is fitting too closely to the training data and not generalizing well to new data.
- Disparity Between Training and Validation Loss: If the training loss continues to decrease while the validation loss starts to increase, it's a classic sign of overfitting. A low training loss accompanied by a high validation loss generally indicates overfitting.
- Complexity of the Model: Larger models with more parameters are more prone to overfitting because they have the capacity to learn extremely detailed patterns in the training data. This can be problematic if those detailed patterns do not apply to new data.
- Insufficient Training Data: Overfitting is more likely when the model is trained on a small dataset. A model trained on a limited amount of data might not encounter enough variability to generalize well to unseen data.
Why don't you try out different datasets yourself and let us know what results you get in our Community channel!
3. Conclusion
In this article we’ve expanded on our knowledge of Vision Transformers and fine-tuned a base model for image classification. In the next article, we’ll be taking a look at multi-modal embedding models such as CLIP; diving into how they work and how they can also be fine-tuned!