Machine Learning Frontiers in 2024
Mixtures of Experts, Understanding LLMs, the Deep and Cross architecture, causal debiasing, user action sequence modeling, and more - an overview of topics covered in 2024
It’s this time of the year again — let’s rewind and see what we learned this year.
1. Mixtures of Experts
In “The Rise of Sparse Mixtures of Experts: Switch Transformers”, we explored the Switch Transformer, which for the first time introduced the concept of hard routing inside a Transformer model. Hard routing means that even though we add a large number of experts that replace the standard feed-forward layer in the Transformer block, only a single expert is activated per input token. This trick allows us to scale up the capacity of the model at a theoretically ~O(1) computational cost per token with respect to the number of experts, provided we have the memory to store the additional experts. There are many details to get right, which we covered in the article (including expert parallelism, load balancing losses, expert capacity, etc.), but empirically, once one gets those details right, the Switch Transformer can be trained to the same accuracy as a standard, dense Transformer 7X faster!
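To make hard routing more concrete, here is a minimal plain-Python sketch of top-1 routing with an expert capacity, together with the Switch-style auxiliary load balancing loss (N · Σ f_i · P_i). The function names are hypothetical. A real implementation would batch all of this on the accelerator and would multiply each expert's output by the returned gate probability.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def switch_route(router_logits: List[List[float]],
                 capacity_factor: float = 1.0) -> List[Tuple[int, float]]:
    """Top-1 (hard) routing: each token is sent to exactly one expert.

    Returns (expert_id, gate_prob) per token. expert_id is -1 when the chosen
    expert is already full; such a dropped token skips the MoE layer and is
    carried forward by the residual connection.
    """
    num_tokens, num_experts = len(router_logits), len(router_logits[0])
    # Expert capacity: how many tokens each expert may process in this batch.
    capacity = math.ceil(capacity_factor * num_tokens / num_experts)
    load = [0] * num_experts
    routes = []
    for logits in router_logits:
        probs = softmax(logits)
        expert = max(range(num_experts), key=probs.__getitem__)
        if load[expert] < capacity:
            load[expert] += 1
            routes.append((expert, probs[expert]))  # expert output is scaled by this
        else:
            routes.append((-1, 0.0))  # overflow: token is dropped
    return routes


def load_balancing_loss(router_logits: List[List[float]]) -> float:
    """Auxiliary loss N * sum_i f_i * P_i, where f_i is the fraction of tokens whose
    top-1 expert is i and P_i is the mean router probability of expert i.
    It equals 1.0 when both f and P are uniform."""
    num_tokens, num_experts = len(router_logits), len(router_logits[0])
    f = [0.0] * num_experts
    p = [0.0] * num_experts
    for logits in router_logits:
        probs = softmax(logits)
        f[max(range(num_experts), key=probs.__getitem__)] += 1 / num_tokens
        for e in range(num_experts):
            p[e] += probs[e] / num_tokens
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Four tokens, two experts, capacity factor 1.0 -> at most 2 tokens per expert.
routes = switch_route([[2.0, 0.0], [3.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print([e for e, _ in routes])  # [0, 0, -1, 1]
```

In this example, the third token also prefers expert 0. That expert is already at capacity, so the token is dropped. Keeping such overflows rare is exactly what the load balancing loss is for.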
In Mixtures of Experts as an Auction: The BASE Transformer, we explored a different mathematical formulation of the hard routing problem that has a key advantage: BASE (“balanced assignment of experts”) guarantees perfect load balance, i.e. each expert is assigned the exact same number of tokens in each training iteration, without the need for any additional modeling tricks such as auxiliary losses. It achieves this by replacing the standard, greedy routing with an auction algorithm, where tokens bid on experts, and experts are assigned according to the highest bidder (which is, technically, a variant of expert-choice routing, whereas Switch uses token-choice routing). The maximum bids are simply determined by the outputs from the MoE gate. Empirically, BASE’s predictive accuracy is on par with, if not slightly better than, that of the Switch Transformer, while at the same time achieving 16% better token throughput.
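BASE optimizes a balanced linear assignment objective: maximize the total token-expert score subject to every expert receiving exactly num_tokens / num_experts tokens. The sketch below illustrates that objective but swaps in brute-force enumeration for the auction algorithm. It is exact, but usable only on toy inputs. It is shown next to a greedy token-choice baseline, and both function names are hypothetical.

```python
from itertools import permutations
from typing import List


def balanced_assignment(scores: List[List[float]]) -> List[int]:
    """Expert id per token such that each expert gets exactly
    num_tokens // num_experts tokens and the summed score is maximal.

    Brute force over all distinct slot orderings (tiny inputs only);
    BASE solves the same objective with an auction algorithm instead.
    """
    num_tokens, num_experts = len(scores), len(scores[0])
    if num_tokens % num_experts:
        raise ValueError("num_tokens must be divisible by num_experts")
    per_expert = num_tokens // num_experts
    # Every expert offers `per_expert` slots; each ordering of the slots is one assignment.
    slots = [e for e in range(num_experts) for _ in range(per_expert)]
    best_total, best = float("-inf"), []
    for perm in sorted(set(permutations(slots))):
        total = sum(scores[t][perm[t]] for t in range(num_tokens))
        if total > best_total:
            best_total, best = total, list(perm)
    return best


def greedy_assignment(scores: List[List[float]]) -> List[int]:
    """Token-choice baseline: each token simply picks its highest-scoring expert."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


scores = [[5.0, 1.0], [4.0, 0.0], [3.0, 2.0], [0.0, 1.0]]
print(greedy_assignment(scores))    # [0, 0, 0, 1]: expert 0 gets 3 of 4 tokens
print(balanced_assignment(scores))  # [0, 0, 1, 1]: exactly 2 tokens per expert
```

The balanced solution moves token 2 to expert 1. Token 2 loses the least score by switching (3.0 vs 2.0), so perfect balance costs as little total affinity as possible.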
This led us to one of the most comprehensive works in the domain of sparse MoE, MegaBlocks, which we explored in Efficient Mixtures of Experts with Block-Sparse Matrices: MegaBlocks. Both the Switch Transformer and BASE assumed that MoEs are still treated as a set of dense matrix multiplications, that is, if we have N experts we need to run N separate matrix multiplications, one for each expert’s output. MegaBlocks changed this assumption by introducing a new compute paradigm that is better suited to the inherently sparse nature of hard routing. The key idea is to reformulate the MoE forward pass as a single block-sparse matrix multiplication instead of one matrix multiplication per expert. For this to work, we need an efficient way to store and manipulate block-sparse matrices on GPUs: that’s the “BCSR” (block compressed sparse row) format, for which the authors write their own custom GPU kernels. This sparse reformulation allows MegaBlocks to scale up the number of experts without the need for a capacity factor or load balancing losses. Empirically, the authors find a 2.4x speed-up compared to a dense Transformer and a 1.4x speed-up compared to Tutel MoE (a variant of the Switch Transformer).
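As a rough illustration of the block-sparse idea, the sketch below builds only the BCSR metadata for the first MoE matrix multiplication: the row pointers, plus the column indices of the non-zero blocks. It assumes that each expert's token count has already been rounded up to whole blocks. The function names are hypothetical and this is not the MegaBlocks API. The actual block-sparse products are computed by dedicated GPU code that is not shown here.

```python
from typing import List, Tuple


def moe_bcsr_topology(token_blocks_per_expert: List[int],
                      hidden_blocks_per_expert: int) -> Tuple[List[int], List[int]]:
    """BCSR row pointers and column indices for a block-sparse matrix with one
    block-row per block of routed tokens and one block-column per block of expert
    weights. A block-row routed to expert e has non-zeros only in expert e's
    columns, so a busier expert simply gets more block-rows (no capacity limit)."""
    row_ptr, col_idx = [0], []
    for expert, num_rows in enumerate(token_blocks_per_expert):
        first_col = expert * hidden_blocks_per_expert
        for _ in range(num_rows):
            col_idx.extend(range(first_col, first_col + hidden_blocks_per_expert))
            row_ptr.append(len(col_idx))
    return row_ptr, col_idx


def dense_block_mask(row_ptr: List[int], col_idx: List[int],
                     num_cols: int) -> List[List[int]]:
    """Expand the BCSR metadata into a 0/1 block mask, for inspection only."""
    mask = []
    for r in range(len(row_ptr) - 1):
        row = [0] * num_cols
        for c in col_idx[row_ptr[r]:row_ptr[r + 1]]:
            row[c] = 1
        mask.append(row)
    return mask


# Expert 0 received 2 blocks of tokens, expert 1 received 1; each expert has 2 weight blocks.
row_ptr, col_idx = moe_bcsr_topology([2, 1], hidden_blocks_per_expert=2)
print(row_ptr)  # [0, 2, 4, 6]
print(col_idx)  # [0, 1, 0, 1, 2, 3]
```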
Then there was Mistral AI’s “Mixtral of Experts”, which created some hype in the media (“shaking up the AI world”, wrote one online magazine). As I explained in Demystifying Mixtral of Experts, Mixtral is like a Switch Transformer, but with a relatively small number of experts (8 vs 128), where 2 experts (not 1) are active for a given input token. This small number may explain why Mixtral gets away without expert parallelism and load balancing losses, which were critical to scale up Switch. The paper also references MegaBlocks, so it is possible that Mixtral also uses the block-sparse paradigm under the hood; unfortunately, the paper itself is too sparse on the details to know for sure.
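Here is a minimal sketch of top-2 gating in the style described for Mixtral. The gate takes a softmax over only the top-k router logits, and the layer output is the gate-weighted sum of the selected experts' outputs. The toy "experts" are plain scaling functions, and all names are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def top_k_gating(gate_logits: Sequence[float], k: int = 2) -> List[Tuple[int, float]]:
    """Pick the k highest-scoring experts, then softmax over only those k logits
    so that the selected gate weights sum to 1."""
    top = sorted(range(len(gate_logits)), key=lambda e: gate_logits[e], reverse=True)[:k]
    m = max(gate_logits[e] for e in top)
    exps = [math.exp(gate_logits[e] - m) for e in top]
    total = sum(exps)
    return [(e, w / total) for e, w in zip(top, exps)]


def sparse_moe(x: Vector, experts: List[Callable[[Vector], Vector]],
               gate_logits: Sequence[float], k: int = 2) -> Vector:
    """y = sum of gate_weight * expert(x) over the k selected experts;
    the remaining experts are never evaluated."""
    y = [0.0] * len(x)
    for e, w in top_k_gating(gate_logits, k):
        y = [yi + w * oi for yi, oi in zip(y, experts[e](x))]
    return y


# Eight toy experts: expert e multiplies its input by (e + 1).
experts = [lambda v, s=e + 1: [s * vi for vi in v] for e in range(8)]
logits = [0.0, 2.0, 1.0, -1.0, 0.5, -2.0, 0.0, 0.3]
print([e for e, _ in top_k_gating(logits)])  # [1, 2]
```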
Lastly, in Experts Everywhere: How Mixtures of Experts Turbocharge Large Language Models, we explored various other ways to “MoE-ify” LLMs (yes, this is a verb, meaning to incorporate MoE into a model), including:
MoA (Mixture of Attention heads), which integrates MoE into the Q and O projection matrices in the attention module, and outperforms the standard Transformer by 1.1 BLEU points on the WMT15 English-German translation dataset with the same number of FLOPs,
SwitchHead, which integrates MoE into the V and O projection matrices in the attention module, the combination the authors find to work best (with O alone accounting for most of the gains). The advantage of SwitchHead is that it does not require a large number of heads: with just 2 attention heads, SwitchHead beats an 8-head MoA model by 0.38 perplexity points,
MoD (Mixture-of-Depths), which integrates MoE into the skip connections, allowing tokens to bypass a Transformer block entirely. This enables adaptive computation that adjusts automatically to the difficulty of the inputs. In experiments, the authors show that they can either get 50% savings in compute with equal loss, or 1.5% better loss with equal compute, compared to a standard Transformer model.
Fundamentally, the key challenge in sparse MoE is to marry an inherently sparse operation with a compute architecture designed for dense operations: GPUs. My sense is that MegaBlocks-like customized compilers and compute primitives are the right direction here, but this involves much more than just modeling work and is thus more difficult to operationalize.
2. On training LLMs
In A Friendly Introduction to Large Language Models (LLMs), we learned about LLM science basics, including tokenization and embeddings, self-attention, causal masking, multi-head attention, LayerNorm and the FFN layer, the Transformer block, the 3 Transformer variants, and how they are trained.
In What Exactly Happens When We Fine-Tune Large Language Models?, we dug deeper into the fine-tuning process in the context of BERT. We learned about the fine-tuning instability phenomenon, where fine-tuning results have surprisingly high variance. This instability has been shown to go away when training for a larger number of epochs, and when using the original Adam optimizer instead of the modified version used in the original BERT paper (which omits Adam’s bias correction). We also learned that fine-tuning can be divided into 3 phases: fitting (~epoch 1), during which the model learns simple patterns; setting (~epochs 2-5), during which both training and validation performance plateau as there are no more simple patterns left to learn; and memorization (~epochs 6+), during which the model starts memorizing individual training examples, including noise.
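A toy calculation shows why dropping Adam's bias correction matters early in training. Consider a single parameter that receives a constant gradient, with the default betas (0.9, 0.999). With bias correction, every step has a size of about the learning rate. Without it, the first step is roughly 3.16x larger, because the second-moment estimate starts far smaller than the first. This is a hypothetical helper written for illustration, not BERT's optimizer code.

```python
import math
from typing import List


def adam_update_sizes(grad: float, steps: int, lr: float = 1e-3, beta1: float = 0.9,
                      beta2: float = 0.999, eps: float = 1e-8,
                      bias_correction: bool = True) -> List[float]:
    """Size of each Adam update for one parameter under a constant gradient."""
    m = v = 0.0  # moment estimates start at zero
    sizes = []
    for t in range(1, steps + 1):
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2
        m_hat, v_hat = m, v
        if bias_correction:
            m_hat = m / (1 - beta1 ** t)  # undo the zero-initialization bias
            v_hat = v / (1 - beta2 ** t)
        sizes.append(lr * m_hat / (math.sqrt(v_hat) + eps))
    return sizes


lr = 1e-3
with_bc = adam_update_sizes(grad=1.0, steps=3, lr=lr)
without_bc = adam_update_sizes(grad=1.0, steps=3, lr=lr, bias_correction=False)
print([round(s / lr, 3) for s in with_bc])  # [1.0, 1.0, 1.0]
print(round(without_bc[0] / lr, 3))         # 3.162
```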
Lastly, in LoRA: Revolutionizing Large Language Model Adaptation without Fine-Tuning, we learned about LoRA, which has become a compelling alternative to full fine-tuning. The key hypothesis behind LoRA is that the weight update matrices during fine-tuning have low intrinsic rank, and can therefore be aggressively compressed with low-rank factorization, that is, we can write the weight update ∆W for the model weights W as
∆W = AB,
where the shared dimension between the matrices A and B is the rank r, a free hyperparameter in the model. At serving time, all we need are the frozen pre-trained weights W as well as the (much smaller) matrices A and B to compute predictions y as:
y = Wx + ABx
The key finding in the LoRA paper is that r can be surprisingly small, even as small as r=1 in some problems, which drastically reduces the number of trainable parameters (and with it the memory needed for gradients and optimizer states) compared to updating W directly.
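The sketch below follows the notation above: ∆W = AB, with A of shape d_out × r and B of shape r × d_in. It uses plain Python lists for readability. It computes y = Wx + A(Bx) without ever materializing AB, checks that folding AB into W gives the same output, and compares parameter counts. All function names are hypothetical. The merge shows why LoRA adds no extra matrix multiplication at serving time once the update is folded in.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(M: Matrix, x: Vector) -> Vector:
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in M]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector) -> Vector:
    """y = Wx + A(Bx): W (d_out x d_in) is frozen, A (d_out x r) and B (r x d_in)
    are the trainable low-rank factors."""
    base = matvec(W, x)
    update = matvec(A, matvec(B, x))  # cost scales with r, not d_out * d_in
    return [b + u for b, u in zip(base, update)]


def merge_lora(W: Matrix, A: Matrix, B: Matrix) -> Matrix:
    """Fold the update into the weights: W' = W + AB."""
    r = len(B)
    return [[W[i][j] + sum(A[i][k] * B[k][j] for k in range(r))
             for j in range(len(W[0]))] for i in range(len(W))]


def lora_param_counts(d_out: int, d_in: int, r: int) -> Tuple[int, int]:
    """(trainable params of a full update, trainable params of the LoRA update)."""
    return d_out * d_in, r * (d_out + d_in)


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen pre-trained weights
A = [[1.0], [2.0]]            # d_out x r, with r = 1
B = [[3.0, 4.0]]              # r x d_in
x = [1.0, 1.0]
print(lora_forward(W, A, B, x))           # [8.0, 15.0]
print(matvec(merge_lora(W, A, B), x))     # [8.0, 15.0]
print(lora_param_counts(4096, 4096, r=8))  # (16777216, 65536)
```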
3. The Deep and Cross architecture
Let’s switch gears from LLMs to recommenders. In The Rise of Deep and Cross Networks in Recommender Systems, we explored the history of DCN, which is one of the most successful neural architectures for click-through rate prediction problems and recommender systems in general. First introduced by Google in 2017, DCN’s key idea is to generate all possible feature “crosses” in a brute-force manner. This was a significant breakthrough in recommender systems and replaced the previous practice, popularized by Google’s Wide&Deep architecture, of tediously engineering cross features by hand.
The crux of DCN is that the more cross layers we stack, the higher the order of feature interactions we can model. One layer results in second-order interactions (features with features), two layers result in third-order interactions (features with features with features) and so on. In theory, this should allow us to make our model extremely capable simply by adding more cross layers. In practice though, researchers soon found that DCN’s performance saturates after just 2-3 layers, which was a disappointing result.
The reasons for DCN’s failure to scale were twofold:
DCN’s interactions were vector-wise (i.e., using the dot product), not bit-wise (i.e., using the element-wise product), and
As we stack more cross layers, we introduce not only more signal into the model but also disproportionately more “cross noise” from entirely meaningless feature crosses, increasing the risk of overfitting.
The first of these two limitations was addressed in DCNv2, the first DCN-like model to model bit-wise feature interactions. Combined with several other innovations, such as Mixtures of Experts and low-rank factorization of the cross-layer weights, this allowed the authors to stack 4-5 cross layers and still see improvements in model performance. This observation showed that the higher expressiveness of bit-wise interactions is indeed useful for DCN-like models.
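The difference between vector-wise and bit-wise crossing is easiest to see side by side. The sketch below implements one DCN cross layer, x_{l+1} = x_0 (x_l · w) + b + x_l, and one DCNv2 cross layer, x_{l+1} = x_0 ⊙ (W x_l + b) + x_l, in plain Python. The function names are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def cross_layer_v1(x0: Vector, xl: Vector, w: Vector, b: Vector) -> Vector:
    """DCN: x_{l+1} = x0 * (x_l . w) + b + x_l. The interaction x_l . w is one
    scalar, so all dimensions of x0 are rescaled by the same amount (vector-wise)."""
    s = sum(xi * wi for xi, wi in zip(xl, w))
    return [x0i * s + bi + xli for x0i, bi, xli in zip(x0, b, xl)]


def cross_layer_v2(x0: Vector, xl: Vector, W: Matrix, b: Vector) -> Vector:
    """DCNv2: x_{l+1} = x0 ⊙ (W x_l + b) + x_l. Each dimension gets its own
    learned combination of x_l (bit-wise)."""
    Wx = [sum(wij * xj for wij, xj in zip(row, xl)) for row in W]
    return [x0i * (wxi + bi) + xli for x0i, wxi, bi, xli in zip(x0, Wx, b, xl)]


x0 = [1.0, 2.0]
print(cross_layer_v1(x0, x0, [1.0, 1.0], [0.0, 0.0]))                    # [4.0, 8.0]
print(cross_layer_v2(x0, x0, [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]))      # [2.0, 6.0]
# With zero bias, DCN is the special case of DCNv2 whose matrix rows all equal w:
print(cross_layer_v2(x0, x0, [[1.0, 1.0], [1.0, 1.0]], [0.0, 0.0]))      # [4.0, 8.0]
```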
The second of DCN’s limitations, cross noise, was addressed in GDCN, short for Gated DCN, which is essentially a variant of DCNv2 with the addition of a gating network that learns a scalar weight for each feature cross. Interestingly, as we train the model, the gating network learns to suppress noisy crosses (the weight will be near 0), while up-weighting informative crosses (the weight will be near 1).
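As we understand the GDCN paper, the gated cross layer multiplies the DCNv2 cross term by a sigmoid gate: x_{l+1} = x_0 ⊙ (W_c x_l + b) ⊙ σ(W_g x_l) + x_l. The sketch below uses hand-picked, hypothetical gate weights to show one gate closing (the cross is suppressed) and one opening (the cross passes through).

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(M: Matrix, x: Vector) -> Vector:
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in M]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated_cross_layer(x0: Vector, xl: Vector, Wc: Matrix, Wg: Matrix, b: Vector) -> Vector:
    """x_{l+1} = x0 ⊙ (Wc x_l + b) ⊙ sigmoid(Wg x_l) + x_l.
    A gate near 0 suppresses that cross; a gate near 1 keeps it."""
    cross = matvec(Wc, xl)
    gates = [sigmoid(z) for z in matvec(Wg, xl)]
    return [x0i * (ci + bi) * gi + xli
            for x0i, ci, bi, gi, xli in zip(x0, cross, b, gates, xl)]


x = [1.0, 1.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
gate_weights = [[-50.0, 0.0], [0.0, 50.0]]  # first gate closed, second open
print(gated_cross_layer(x, x, identity, gate_weights, [0.0, 0.0]))  # [1.0, 2.0]
```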
DCNv3, which we studied in DCNv3: Unlocking Extremely High-Order Feature Interactions with Exponential Cross Layers, is one of the latest members of the DCN family, and at the time of this writing leads the Criteo benchmark. The key idea in DCNv3 is to cross the crosses with themselves instead of with the original feature vector, allowing for exponential instead of linear growth of feature cross depth. On one dataset, the authors found the best performance when stacking 6 DCNv3 layers, corresponding to feature interactions of 64th order. For comparison, a standard DCN model would need to stack 63 layers, not 6, to reach the same order (each cross layer adds one order), resulting in ~10X the number of model parameters!
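To check the linear-vs-exponential arithmetic, the sketch tracks the interaction order symbolically. It uses a single scalar feature x with weight 1 and bias 0, and represents each layer output as a polynomial in x. A DCN-style layer crosses x_l with x_0 and adds one order per layer. An exponential layer, x_{l+1} = x_l ⊙ (W x_l + b) + x_l, crosses x_l with itself and doubles the order. This is a simplified form that leaves out DCNv3's other components, and the function names are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]
Poly = List[int]  # polynomial coefficients, index = degree


def exponential_cross_layer(xl: Vector, W: Matrix, b: Vector) -> Vector:
    """x_{l+1} = x_l ⊙ (W x_l + b) + x_l: x_l is crossed with itself, not with x_0."""
    Wx = [sum(w * x for w, x in zip(row, xl)) for row in W]
    return [x * (wx + bi) + x for x, wx, bi in zip(xl, Wx, b)]


def _mul(p: Poly, q: Poly) -> Poly:
    out = [0] * (len(p) + len(q) - 1)
    for i, pi in enumerate(p):
        for j, qj in enumerate(q):
            out[i + j] += pi * qj
    return out


def _add(p: Poly, q: Poly) -> Poly:
    n = max(len(p), len(q))
    return [(p[i] if i < len(p) else 0) + (q[i] if i < len(q) else 0) for i in range(n)]


def interaction_order(num_layers: int, exponential: bool) -> int:
    """Highest interaction order after stacking cross layers on one scalar feature
    (W = 1, b = 0), tracked as a polynomial in x."""
    x0: Poly = [0, 1]  # the polynomial "x"
    xl = x0
    for _ in range(num_layers):
        partner = xl if exponential else x0  # DCNv3-style vs. DCN-style crossing
        xl = _add(_mul(partner, xl), xl)
    return len(xl) - 1


print(interaction_order(6, exponential=True))    # 64
print(interaction_order(6, exponential=False))   # 7
print(interaction_order(63, exponential=False))  # 64
```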
4. Debiasing recommenders
In (Some) Biases in Recommender Systems You Should Know, we examined some of the most prevalent biases in recommender systems, including clickbait bias, duration bias, position bias, popularity bias, and single-interest bias:
clickbait bias means that the model is biased in favor of clickbait content, and can be fixed with techniques such as weighted logistic regression,
duration bias means that the model is biased in favor of long videos (and against short videos). One way to fix it is to use quantile-based watch-time prediction instead of predicting watch times directly,
position bias means that users are more likely to click on the first thing they see, whether it’s relevant or not. We can fix it by either weighting each training example by an estimate of its position bias, or using the positions directly as a feature in the model (but zeroing them out at serving time),
popularity bias means that the model is biased in favor of popular content instead of the unique interests of a particular user. One way to fix it is by scaling the model’s prediction logits by inverse item popularity,
single-interest bias means that the model fails to learn multiple user interests at the same time and instead over-exploits the single, most prevalent user interest. This can be fixed by calibrating the prediction scores, for example using Platt scaling.
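The first position-bias fix above, weighting each training example by an estimate of its position bias, is commonly implemented as inverse propensity weighting. Each example's log loss is multiplied by 1 / P(examined | position). The sketch below assumes the examination probabilities have already been estimated. The values shown are made up for illustration, and the function name is hypothetical.

```python
import math
from typing import Dict, List


def ips_weighted_log_loss(labels: List[int], preds: List[float], positions: List[int],
                          examination_prob: Dict[int, float]) -> float:
    """Log loss where each example is weighted by 1 / P(examined | position).

    Examples from rarely examined positions count more, which counteracts the
    extra clicks that top positions collect regardless of relevance."""
    total = 0.0
    for y, p, pos in zip(labels, preds, positions):
        weight = 1.0 / examination_prob[pos]
        total += -weight * (y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


# Hypothetical propensities, e.g. estimated from a small randomization experiment.
propensity = {1: 1.0, 2: 0.5, 3: 0.25}
loss = ips_weighted_log_loss([1, 1], [0.5, 0.5], [1, 2], propensity)
print(round(loss, 4))  # 1.0397
```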
It’s not enough to simply assume that ranking models are neutral or objective: they’ll always reflect the biases that exist in the data they are trained on. De-biasing is far from being a solved problem, and as recommender systems continue to evolve, we can expect new biases to emerge. Coming up with innovative ways to detect, quantify and alleviate these biases remains one of the most important research domains in the industry today.
In The Problem With the Problem With Popularity Bias, we dug deeper into the problem of popularity bias, and learned that popularity itself is not really the issue. Instead, we need to distinguish between popularity driven by item quality and popularity driven by user conformity, and then only leverage the former when making predictions, ignoring the latter. We’ve seen one example of a model architecture that achieves this, the TIDE model, which uses temporal information in the popularity signal to estimate the user conformity contribution.
In Causal Modeling in Recommender Systems: A Primer, we introduced the concept of causal modeling in recommenders, which simply means that we employ a causal graph to inform the model architecture. We looked at 3 particular examples:
PAL models clicks as the consequence of item positions and user/item relevance, which motivates the use of a two-tower model architecture, where the second tower learns from positions alone and is only used during offline training (not serving).
DICE models clicks as the consequence of interest and conformity, which motivates the use of two sets of user/item embeddings, one for interest and one for conformity, that are learned in parallel on distinct data partitions.
MACR models clicks as the consequence of relevance, item, and user, which motivates the use of a three-tower neural network.
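PAL's factorization can be sketched in a few lines. During training, the click probability is P(seen | position) × P(click | seen), and both factors are fit jointly through one log loss on clicks. At serving time, only the relevance factor is used. Here the "towers" are reduced to precomputed logits, and the per-position logits and function names are hypothetical.

```python
import math
from typing import Dict, List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def pal_train_prob(position: int, relevance_logit: float,
                   position_logits: Dict[int, float]) -> float:
    """Training-time click probability: P(seen | position) * P(click | seen)."""
    return sigmoid(position_logits[position]) * sigmoid(relevance_logit)


def pal_log_loss(clicks: List[int], positions: List[int], relevance_logits: List[float],
                 position_logits: Dict[int, float]) -> float:
    """Binary cross-entropy on the product, so both factors are trained jointly."""
    total = 0.0
    for y, pos, r in zip(clicks, positions, relevance_logits):
        p = pal_train_prob(pos, r, position_logits)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(clicks)


def pal_serving_score(relevance_logit: float) -> float:
    """At serving time the position factor is dropped entirely."""
    return sigmoid(relevance_logit)


position_logits = {1: 0.0, 2: -1.0}  # hypothetical learned values
print(pal_train_prob(1, 0.0, position_logits))  # 0.25
print(pal_serving_score(0.0))                   # 0.5
```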
Lastly, in How Long Is Long? Duration Bias in Short-Form Video Recommendation, we explored the relatively new problem of duration bias in short-form video recommendation. In the absence of clicks, watch time is the main proxy for video relevance, but watch times are biased by video duration: longer videos tend to generate longer watch times, no matter how relevant they are. Thus, we need novel debiasing approaches to deal with this particular form of bias, and we’ve met 3 of them:
D2Q debiases the model by replacing watch times with watch time quantiles that are computed per video duration quantile. In A/B tests, the authors of the D2Q paper find 0.75% improvement in total watch time compared to predicting watch times directly.
DVR debiases the model by replacing watch times with “watch time gain” (WTG), i.e. the watch time standardized (z-scored) within its video duration bucket. The authors report an improvement in WTG@10 of 12.3% over the naive baseline of simply predicting watch times directly.
D2Co, one of the latest debiasing approaches, explicitly models both the duration bias term as well as the noise term, the latter of which is caused by users not immediately swiping away an irrelevant video. The resulting Gaussian Mixture model outperforms both D2Q and DVR on two different production datasets.
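The D2Q and DVR label transformations can be sketched as follows. Videos are grouped into equal-frequency duration buckets. D2Q replaces each watch time with its quantile within the bucket. DVR's watch time gain is computed here as the within-bucket z-score, which is our reading of the DVR paper. The bucket count, the bucketing scheme, and the function names are illustrative assumptions.

```python
from bisect import bisect_right
from statistics import mean, pstdev
from typing import Dict, List


def duration_bucket_ids(durations: List[float], num_buckets: int) -> List[int]:
    """Assign each video to one of `num_buckets` equal-frequency duration buckets."""
    s = sorted(durations)
    boundaries = [s[len(s) * k // num_buckets] for k in range(1, num_buckets)]
    return [bisect_right(boundaries, d) for d in durations]


def _group(bucket_ids: List[int]) -> Dict[int, List[int]]:
    groups: Dict[int, List[int]] = {}
    for i, b in enumerate(bucket_ids):
        groups.setdefault(b, []).append(i)
    return groups


def d2q_labels(durations: List[float], watch_times: List[float], num_buckets: int) -> List[float]:
    """D2Q-style label: quantile of each watch time among videos in the same
    duration bucket (empirical CDF, in (0, 1])."""
    labels = [0.0] * len(watch_times)
    for idx in _group(duration_bucket_ids(durations, num_buckets)).values():
        ranked = sorted(watch_times[i] for i in idx)
        for i in idx:
            labels[i] = bisect_right(ranked, watch_times[i]) / len(ranked)
    return labels


def wtg_labels(durations: List[float], watch_times: List[float], num_buckets: int) -> List[float]:
    """DVR-style watch time gain: watch time z-scored within its duration bucket."""
    labels = [0.0] * len(watch_times)
    for idx in _group(duration_bucket_ids(durations, num_buckets)).values():
        values = [watch_times[i] for i in idx]
        mu, sigma = mean(values), pstdev(values)
        for i in idx:
            labels[i] = (watch_times[i] - mu) / sigma if sigma > 0 else 0.0
    return labels


durations = [10.0, 10.0, 60.0, 60.0]   # seconds
watch_times = [5.0, 10.0, 20.0, 40.0]  # seconds
print(d2q_labels(durations, watch_times, num_buckets=2))  # [0.5, 1.0, 0.5, 1.0]
print(wtg_labels(durations, watch_times, num_buckets=2))  # [-1.0, 1.0, -1.0, 1.0]
```

After the transformation, a 10-second video watched to the end gets the same label as a 60-second video watched for 40 seconds. Raw watch time would have favored the longer video.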
5. User action sequence modeling
In User Action Sequence Modeling: From Attention to Transformers and Beyond, we traced the evolution of modeling techniques for user action sequences in recommender systems. Just like LLMs, recommender systems learn from sequences, which opens up the question of how much of the science behind modeling word sequences translates into modeling user action sequences.
DIN (Deep Interest Network) was among the first works that showed the potential of leveraging cross-attention with respect to the candidate instead of simply pooling all actions in the user sequence with equal weights. BERT4Rec demonstrated for the first time a ranking model trained using masked token prediction and bi-directional attention. PinnerFormer combined a Transformer architecture with the new Dense All Action loss, which resulted in a novel modeling framework that’s remarkably robust against staleness. However, both BERT4Rec and PinnerFormer failed to scale to very long user action sequences: BERT4Rec reports best results with lengths up to 200, PinnerFormer with 256.
The game changed considerably with the introduction of HSTU, which showed for the first time that an architecture purpose-built for user action sequences beats the vanilla Transformer: the authors scale HSTU up to sequence lengths of 8K, an order of magnitude longer than BERT4Rec and PinnerFormer.
In Towards Life-Long User History Modeling in Recommender Systems, we saw 3 approaches to get around the target attention bottleneck, i.e. the fact that the computational complexity of target attention scales with O(L × B × d) (user history length × batch size × embedding dimension):
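The contrast between equal-weight pooling and candidate-aware pooling can be sketched as below. For simplicity, the sketch scores each past action with a dot product and a softmax. DIN itself scores with a small MLP "activation unit" and does not constrain the weights to sum to 1. The names and toy embeddings are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def mean_pool(history: List[Vector]) -> Vector:
    """Candidate-agnostic baseline: every past action gets the same weight."""
    return [sum(col) / len(history) for col in zip(*history)]


def target_attention_pool(history: List[Vector], candidate: Vector) -> Vector:
    """Weight each past action by its similarity to the candidate item
    (dot product + softmax in this sketch)."""
    scores = [sum(h * c for h, c in zip(item, candidate)) for item in history]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    return [sum(w * item[d] for w, item in zip(weights, history))
            for d in range(len(candidate))]


history = [[1.0, 0.0], [0.0, 1.0]]  # e.g. one cooking video, one sports video
candidate = [5.0, 0.0]              # a new cooking video
print(mean_pool(history))                                             # [0.5, 0.5]
print([round(v, 3) for v in target_attention_pool(history, candidate)])  # [0.993, 0.007]
```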
SIM, which breaks down target attention into a retrieval stage (“general search unit”, GSU) and a pooling stage (“exact search unit”, ESU). Retrieval can be solved with either “hard search”, which keeps only items from the same category as the target, or “soft search”, which keeps the k closest items in an embedding space.
ETA, which creates binary hash signatures for all items and uses Hamming distance as a proxy for relevance: the smaller the Hamming distance, the more relevant the item. The trick is to use a hashing algorithm that is locality-sensitive, such that similar vectors end up with similar hash signatures. ETA uses SimHash.
SDIM, which uses hash collision directly as a proxy for retrieval relevance, that is, we simply retrieve all items from the user history that collide with the target in hash space.
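The hashing ideas behind ETA and SDIM can be sketched with random-hyperplane SimHash. Each bit records which side of a random hyperplane a vector falls on, so vectors pointing in similar directions share most bits. ETA-style retrieval ranks history items by Hamming distance to the target's signature. SDIM-style retrieval keeps only exact signature collisions. SDIM actually uses several hash rounds; a single signature is a simplification here, and the names, dimensions, and seed are hypothetical.

```python
import random
from typing import List

Vector = List[float]


def make_hyperplanes(dim: int, num_bits: int, seed: int = 0) -> List[Vector]:
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_bits)]


def simhash(vec: Vector, hyperplanes: List[Vector]) -> int:
    """One bit per random hyperplane: 1 if the vector lies on its positive side."""
    sig = 0
    for bit, h in enumerate(hyperplanes):
        if sum(a * b for a, b in zip(vec, h)) >= 0:
            sig |= 1 << bit
    return sig


def hamming(a: int, b: int) -> int:
    return bin(a ^ b).count("1")


def eta_retrieve(target: Vector, history: List[Vector], planes: List[Vector], k: int) -> List[int]:
    """ETA-style: the k history items with the smallest Hamming distance to the target."""
    t = simhash(target, planes)
    return sorted(range(len(history)), key=lambda i: hamming(t, simhash(history[i], planes)))[:k]


def sdim_retrieve(target: Vector, history: List[Vector], planes: List[Vector]) -> List[int]:
    """SDIM-style (single hash round): all history items whose signature collides."""
    t = simhash(target, planes)
    return [i for i, h in enumerate(history) if simhash(h, planes) == t]


planes = make_hyperplanes(dim=4, num_bits=16, seed=7)
target = [1.0, 0.5, -0.3, 2.0]
history = [[-v for v in target], list(target), [2 * v for v in target]]
print(sdim_retrieve(target, history, planes))      # [1, 2]
print(eta_retrieve(target, history, planes, k=1))  # [1]
```

Because only integer signatures need to be compared at request time, both variants replace a dot product per history item with a cheap XOR and popcount (ETA) or an equality check (SDIM).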
Lastly, in Target Attention Is All You Need: Modeling Extremely Long User Action Sequences in Recommender Systems, we explored Kuaishou’s TWIN family of architectures, one of the most advanced long-term history modeling solutions in the industry today. Unlike their predecessors, the TWIN family of ranking models leverages target attention in both the GSU and ESU. The trick that makes this work is to decompose the keys into two components, the “H” component reflecting inherent item properties and the “C” component reflecting user/item cross features. With this reformulation, we can then save compute by
caching the target attention scores coming from the H components, and
approximating the score contribution coming from the C components as a (computationally cheap) 1D bias term.
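The decomposition can be sketched as follows. This assumes that the projected "H" keys have already been computed per item id and cached, since they do not depend on the user. Each behavior's cross features are reduced to one scalar bias with a learned weight vector. The dictionary cache, the names, and the toy values are hypothetical, not Kuaishou's implementation.

```python
import math
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def twin_attention_logits(target_query: Vector, cached_item_keys: Dict[str, Vector],
                          history_ids: List[str], cross_features: List[Vector],
                          cross_weights: Vector) -> List[float]:
    """Attention logits split into an item-only part and a user/item part.

    H part: the projected key of a history item depends only on the item, so it is
    looked up from a precomputed cache instead of being recomputed per request.
    C part: each behavior's cross features are compressed to a single scalar bias.
    """
    scale = math.sqrt(len(target_query))
    return [dot(target_query, cached_item_keys[item_id]) / scale + dot(c, cross_weights)
            for item_id, c in zip(history_ids, cross_features)]


def top_k_behaviors(logits: List[float], k: int) -> List[int]:
    """GSU step: keep the k history positions with the highest logits."""
    return sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:k]


cache = {"video_a": [1.0, 0.0], "video_b": [0.0, 1.0]}
logits = twin_attention_logits([2.0, 0.0], cache, ["video_a", "video_b"],
                               cross_features=[[1.0], [0.0]], cross_weights=[0.5])
print(top_k_behaviors(logits, k=1))  # [0]
```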
TWIN-v2 takes this one step further by adding an additional clustering module that first aggregates all video ids into video clusters before passing these clusters into TWIN. The trick is to make the clustering hierarchical and keep clustering the cluster members until each cluster size is no more than a predetermined threshold. Practically, the authors achieve an average cluster size of 10, hence reducing the cardinality and thus sequence lengths by around an order of magnitude. In Kuaishou’s case, this appears to be exactly the shrinkage needed in order to be able to model life-long user action sequences (at least for now), given that the most active power users watch hundreds of thousands of videos in a year.
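The "keep splitting until every cluster is small enough" idea can be sketched as a recursive function. As a plainly named stand-in for a k-means step, each split sorts items along their highest-variance dimension and cuts at the median. Each cluster's centroid can then stand in for its members in the sequence. The threshold, the random data, and the function names are hypothetical.

```python
import random
from statistics import pvariance
from typing import List

Vector = List[float]


def hierarchical_clusters(items: List[Vector], max_size: int) -> List[List[int]]:
    """Recursively split item indices until every cluster has at most max_size members.

    Split rule (a stand-in for k-means): sort along the dimension with the highest
    variance and cut at the median."""
    if max_size < 1:
        raise ValueError("max_size must be >= 1")

    def split(indices: List[int]) -> List[List[int]]:
        if len(indices) <= max_size:
            return [indices]
        dim = max(range(len(items[0])),
                  key=lambda d: pvariance([items[i][d] for i in indices]))
        ordered = sorted(indices, key=lambda i: items[i][dim])
        half = len(ordered) // 2
        return split(ordered[:half]) + split(ordered[half:])

    return split(list(range(len(items))))


def centroid(items: List[Vector], members: List[int]) -> Vector:
    """Cluster representative that replaces its members in the action sequence."""
    return [sum(items[i][d] for i in members) / len(members) for d in range(len(items[0]))]


rng = random.Random(0)
history = [[rng.random(), rng.random()] for _ in range(100)]  # 100 watched videos
clusters = hierarchical_clusters(history, max_size=10)
print(len(clusters))  # 16: the sequence shrinks from 100 items to 16 clusters
```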
Recent developments
Embedding dimensional collapse is a relatively recent discovery in recommender systems. When feature interactions are generated inside the model (such as in DCN-like models), a single low-rank embedding table causes all other embedding tables to collapse into low rank as well, wasting a vast amount of the embedding parameters allocated to the model. The multi-embedding paradigm (having multiple copies of both embedding tables and interaction modules) appears to solve this problem, albeit in a rather brute-force manner. It will be interesting to see whether one could beat the multi-embedding paradigm with a few cleverly designed auxiliary loss functions instead. We’ve covered this topic in two issues.
Lately, I have been paying more attention to what is happening on the backend of ML systems. In Understanding ML Compilers: The Journey From Your Code Into the GPU, we learned the basics of PyTorch, CUDA, NVCC (CUDA’s compiler), PTX (an intermediate representation, or IR), PyTorch’s FX (another IR), and SASS (the low-level assembly language that NVIDIA GPUs actually execute). ML compilers are an indispensable part of modern AI workflows, bridging the gap between high-level frameworks and the increasingly diverse range of hardware platforms. One of the most promising research directions in this domain is the development of novel IRs, such as Google's MLIR, Apache TVM’s Relay, Microsoft's DistIR, or OpenAI’s Triton. A common theme across these works is a growing focus on optimizing computation graphs and enabling portability across hardware.
As the field progresses, the integration of intuitive ML frontends, flexible IRs, and efficient hardware accelerators will shape the next generation of ML systems. The key in “full-stack” ML engineering will be to achieve a good balance between adaptability and performance, creating a cohesive stack where every layer, from model architecture design to hardware execution, works seamlessly to deliver optimal results.
That’s a wrap for 2024
Thanks to all my readers for your interest and support. I’m sure that we’ll keep seeing new ML discoveries, improvements, and surprises in 2025. All the best!
-Sam.