AI encoders demystified
When beginning the process of implementing a new deep learning / AI model, alongside deciding on the inputs and outputs of the model and converting your data to tfrecord files (file extension: .tfrecord.gz; you really should, the performance boost is incredible), you also need to design the actual architecture of the model itself. It might be tempting to shove a bunch of CNN layers at the problem to make it go away, but in this post I hope to convince you to think again.
As with any field, extensive research has already been conducted, and when it comes to AI that means that the effective encoding of various different types of data with deep learning models has already been carefully studied.
To my understanding, there are 2 broad categories of data that you're likely to encounter:
- Images: photos, 2D maps of sensor readings, basically anything that could be plotted as an image
- Sequenced data: text, sensor readings with only a temporal and no spatial dimension, audio, etc
Sometimes you'll encounter a combination of these two. That's beyond the scope of this blog post, but my general advice is:
- If it's 3D, consider it a 2D image with multiple channels to save VRAM
- Go look up specialised model architectures for your problem and how well they did or didn't work
- Use a convolutional variant of a sequenced model (e.g. a ConvLSTM) or something similar
- Pick one of these encoders and modify it slightly to fit your use case
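To make the first point in that list a bit more concrete, here's a minimal sketch of treating 3D data as a 2D image with multiple channels, so that ordinary 2D layers can be used on it. It uses plain nested Python lists so it runs anywhere; in a real project you'd most likely do this with a single transpose in your tensor library of choice. The function name depth_to_channels and the [depth][height][width] input layout are assumptions of mine for illustration.

```python
def depth_to_channels(volume):
    """Convert a [depth][height][width] volume to a [height][width][channels] image.

    Each depth slice of the volume becomes one channel of the resulting 2D image.
    """
    depth = len(volume)
    height = len(volume[0])
    width = len(volume[0][0])
    return [
        [[volume[d][y][x] for d in range(depth)] for x in range(width)]
        for y in range(height)
    ]


# A tiny 2x2x2 volume: two depth slices, each a 2x2 grid
volume = [
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
]
image = depth_to_channels(volume)
print(image[0][0])  # pixel (0, 0) now holds one value per depth slice: [1, 5]
```

Whether this is a sensible trade-off depends on your data: the depth axis is no longer treated spatially, so it works best when the number of slices is small and fixed.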
In this blog post, I'm going to give a quick overview of the various state of the art encoders for the 2 categories of data I've found on my travels so far. Hopefully, this should give you a good starting point for building your own model. By picking an encoder from this list as a starting point for your model and reframing your problem just a bit to fit, you should find that you have not only a much simpler problem/solution, but a more effective one too.
I'll provide links to the original papers for each model, along with links to any materials I've found helpful in understanding how they work. This can be useful if you find yourself in the awkward situation of needing to implement them yourself, but it's also generally good to have an understanding of the encoder you ultimately pick.
I've read a number of papers, but there's always the possibility that I've missed something. If you think this is the case, please comment below. This is especially likely if you are dealing with audio data, as I haven't looked into handling that too much yet.
Finally, also out of scope of this blog post are the various ways of framing a problem for an AI to understand it better. If this interests you, please comment below and I'll write a separate post on that.
Images / spatial data
ConvNeXt: Image encoders are a shining example of how a simple stack of CNNs can be iteratively improved through extensive testing, and in no encoder is this more apparent than in ConvNeXt.
In 1 sentence, a ConvNeXt model can be summarised as "a ResNet but with the features of a transformer". Researchers took a bog-standard ResNet, and then iteratively tested and improved it by adding various features that you'd normally find in a transformer, which resulted in the most performant (read: highest accuracy) encoder when compared to the other 2 on this list.
ConvNeXt is proof that CNN-based models can be significantly improved by incorporating more than just a simple stack of CNN layers.
- Paper: A ConvNet for the 2020s
- Explanation: ConvNext: The Return Of Convolution Networks
- Implementation: https://github.com/leanderme/ConvNeXt-Tensorflow
Vision Transformer: Born from when someone asked what would happen if they tried to put an image into a normal transformer (see below), Vision Transformers are a variant of a normal transformer that handles images and other spatial data instead of sequenced data (e.g. sentences).
The current state of the art is the Swin Transformer to the best of my knowledge. Note that ConvNeXt outperforms Vision Transformers, but the specifics of course depend on your task.
- Paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
- Explanation: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows - Committed towards better future
- Implementation: https://github.com/microsoft/Swin-Transformer (PyTorch, if you know of a simple Tensorflow implementation please let me know)
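To give a feel for how an image can be put into a transformer at all: to my understanding, the original Vision Transformer cuts the image into fixed-size patches and treats each flattened patch as one item in a sequence (the Swin Transformer then adds its hierarchical shifted windows on top, which isn't shown here). Below is a minimal sketch for a single-channel image stored as nested lists. The names image_to_patches and patch_size are mine, and a real model would go on to project each patch into an embedding vector.

```python
def image_to_patches(image, patch_size):
    """Split a [height][width] image into a sequence of flattened square patches."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image dimensions must be divisible by patch_size")

    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten one patch_size x patch_size patch, row by row
            patch = [
                image[y][x]
                for y in range(top, top + patch_size)
                for x in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


# A 4x4 image whose pixel values are 0..15, cut into four 2x2 patches
image = [[row * 4 + col for col in range(4)] for row in range(4)]
sequence = image_to_patches(image, patch_size=2)
print(len(sequence))  # 4 patches, each of which becomes one "token"
```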
ResNet: Extremely popular, you may have heard of the humble ResNet before. It came about when someone wanted to know what would happen if they took the number of layers in a CNN-based model to the extreme. It has a 'residual connection' between blocks of layers, which avoids the 'vanishing gradient problem'. In short, this is when your model is too deep: as the error is backpropagated, the gradient used to adjust the weights becomes so small that it has hardly any effect.
Since this model was invented, skip connections (or 'residual connections', as they are sometimes known) have become a regular feature in all sorts of models - especially those deeper than just a couple of layers.
Despite this significant advancement though, I recommend using a ConvNeXt encoder instead for images - it performs better, and its architecture is more finely tuned than a bog-standard ResNet.
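To illustrate the residual connection described above, here's a toy sketch. residual_block shows the core idea (the block's input is added back onto its output), and gradient_through_stack is a deliberately simplified model of my own: it assumes every layer is a scalar function with the same local gradient, which real networks are not, but it shows why the gradient stops vanishing.

```python
def residual_block(x, layers):
    """Apply `layers` to x, then add the original x back on (the skip connection)."""
    transformed = layers(x)
    return [xi + ti for xi, ti in zip(x, transformed)]


def gradient_through_stack(local_gradient, depth, residual):
    """Gradient reaching the first layer of a stack of identical scalar layers.

    Without a skip connection each layer multiplies the gradient by its local
    gradient. With one, y = x + f(x), so each layer multiplies it by
    (1 + local gradient) instead.
    """
    per_layer = 1.0 + local_gradient if residual else local_gradient
    return per_layer ** depth


# A hypothetical "layer" that just halves its input
print(residual_block([1.0, 2.0], lambda v: [0.5 * item for item in v]))

print(gradient_through_stack(0.1, 20, residual=False))  # vanishingly small
print(gradient_through_stack(0.1, 20, residual=True))   # still large enough to learn from
```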
Sequenced Data
Transformer: The big one that is quite possibly the most famous encoder-decoder ever invented. The Transformer replaces the LSTM (see below) for handling sequenced data. It has 2 key advantages over an LSTM:
- It's more easily parallelisable, which is great for e.g. training on GPUs
- The attention mechanism it implements has revolutionised the performance of pretty much every kind of network it has been added to
The impact of the Transformer on AI cannot be overstated. In short: If you think you want an LSTM, use a transformer instead.
- Paper: Attention is all you need
- Explanation: The Illustrated Transformer
- Implementation: Many around, but I have yet to find a simple enough one, so I implemented it myself from scratch. While I've tested the encoder and I know it works, I have yet to fully test the decoder. Once I have done this, I will write a separate blog post about it and add the link here.
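For a rough idea of what the attention mechanism actually computes, here's a minimal pure-Python sketch of scaled dot-product attention as described in the paper: softmax(QK^T / sqrt(d_k))V. This is not my implementation mentioned above, just an illustration. It covers a single attention head with no masking and no learned projections, and the function names are my own.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def scaled_dot_product_attention(queries, keys, values):
    """Compute softmax(Q K^T / sqrt(d_k)) V for lists of row vectors."""
    scale = math.sqrt(len(keys[0]))
    outputs = []
    for query in queries:  # each query is independent of the others
        scores = [
            sum(q * k for q, k in zip(query, key)) / scale
            for key in keys
        ]
        weights = softmax(scores)
        # Weighted sum of the value vectors
        outputs.append([
            sum(weight * value[j] for weight, value in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return outputs


# One query that matches both keys equally, so it averages the two values
result = scaled_dot_product_attention(
    queries=[[0.0, 0.0]],
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[1.0, 2.0], [3.0, 4.0]],
)
print(result)
```

Note that no query depends on the result for any other query. That's what makes this so much easier to parallelise than a recurrent model: on a GPU the outer loop becomes a single matrix multiplication.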
LSTM: Short for Long Short-Term Memory, LSTMs were invented in 1997 to solve the vanishing gradient problem, in which gradients shrink to vanishingly small values as they are backpropagated through recurrent models that handle sequenced data, rendering those models ineffective at learning long-term relationships in that data.
Superseded by the (non recurrent) Transformer, LSTMs are a recurrent model architecture, which means that their output feeds back into themselves. While this enables some unique model architectures for learning sequenced data (exhibit A), it also makes them horrible at being parallelised on a GPU (the dilated LSTM architecture attempts to remedy this, but Transformers are just better). The only thing that sets them apart from the Transformer is that the Transformer is somehow not built into Tensorflow as standard yet, whereas LSTMs are.
Just as Vision Transformers adapted the Transformer architecture for multidimensional data, Grid LSTMs do the same for normal LSTMs.
- Paper: Long Short-Term Memory
- Explanation: Understanding LSTMs
- Implementation: Tensorflow: tf.keras.layers.LSTM, PyTorch: torch.nn.LSTM
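To show what "their output feeds back into themselves" means in practice, here's a minimal sketch of a single-unit LSTM cell using the commonly taught gate formulation, with scalars instead of vectors to keep it short. The weights dictionary layout and the function names are assumptions of mine for illustration; in real code you'd use one of the built-in layers above.

```python
import math


def sigmoid(value):
    return 1.0 / (1.0 + math.exp(-value))


def lstm_step(x, h_prev, c_prev, weights):
    """One step of a single-unit LSTM cell with scalar input and state.

    `weights` maps each gate name to a (w_x, w_h, bias) tuple.
    """
    def gate(name, activation):
        w_x, w_h, bias = weights[name]
        return activation(w_x * x + w_h * h_prev + bias)

    forget = gate("forget", sigmoid)           # how much old memory to keep
    write = gate("input", sigmoid)             # how much new information to store
    expose = gate("output", sigmoid)           # how much memory to output
    candidate = gate("candidate", math.tanh)   # the new information itself

    c_new = forget * c_prev + write * candidate
    h_new = expose * math.tanh(c_new)
    return h_new, c_new


def run_lstm(sequence, weights):
    """Run the cell over a whole sequence, returning the output at every step."""
    h, c = 0.0, 0.0
    outputs = []
    for x in sequence:
        # Each step needs the h and c from the step before it
        h, c = lstm_step(x, h, c, weights)
        outputs.append(h)
    return outputs


# Arbitrary example weights, not trained values
example_weights = {
    "forget": (0.5, 0.1, 0.0),
    "input": (0.4, 0.2, 0.0),
    "output": (0.3, 0.3, 0.0),
    "candidate": (0.6, 0.1, 0.0),
}
print(run_lstm([1.0, 0.5, -1.0], example_weights))
```

The loop in run_lstm is the problem for GPUs: step 3 can't start until step 2 has finished, no matter how many cores you have available.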
Conclusion
In summary, for images the encoders in priority order are: ConvNeXt, Vision Transformer (Swin Transformer), ResNet, and for text/sequenced data: Transformer, LSTM.
We've looked at a small selection of model architectures for handling a variety of different data types. This is not an exhaustive list, however. You might have another awkward type of data to handle that doesn't fit into either of these categories - e.g. specialised models exist for handling audio, but the general rule of thumb is an encoder architecture probably already exists for your use-case - even if I haven't listed it here.
Also of note are alternative use cases for data types I've covered here. For example, if I'm working with images I would use a ConvNeXt, but if model prediction latency and/or resource consumption mattered I would consider using a MobileNet (https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v3), which, while a smaller model, is designed for producing rapid predictions in lower resource environments - e.g. on mobile phones.
Finally, while these are all encoders, decoders also exist for various tasks. Often, they are tightly integrated into the encoder. For example, the U-Net is designed for image segmentation. Listing these is out of scope of this article, but if you are getting into AI to solve a specific problem (as is often the case), I strongly recommend looking to see if an existing model/decoder architecture has been designed to solve your particular problem. It is often much easier to adjust your problem to fit an existing model architecture than it is to design a completely new architecture to fit your particular problem (trust me, I've tried this already and it was a Bad Idea).
The first one you see might not even be the best / state of the art out there - e.g. Transformers are better than the more widely used LSTMs. Surveying the landscape for your particular task (and figuring out how to frame it in the first place) is critical to the success of your model.