Multi-head Latent Attention


How do we reduce the KV cache size, and how does this compare to alternative methods such as grouped-query attention or multi-query attention?

In multi-query attention, there is a single key and value shared across all attention heads, instead of a separate key and value for each head.

This way, the key-value cache we have to store is much smaller. But it compromises model performance, since we are reducing the number of parameters in the model.
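As a rough back-of-the-envelope sketch (the head counts and dimensions below are assumptions for illustration, not taken from any particular model), here is how the per-token KV-cache size compares between multi-head, grouped-query and multi-query attention:

```python
# Per-token KV-cache sizes (in number of floats).
# All dimensions here are illustrative assumptions.
n_h, d_head = 32, 128               # query heads, dim per head
n_groups = 8                        # KV groups for grouped-query attention

mha_cache = 2 * n_h * d_head        # multi-head: one k and one v per head
gqa_cache = 2 * n_groups * d_head   # grouped-query: one k, v per group of heads
mqa_cache = 2 * 1 * d_head          # multi-query: a single k, v shared by all heads

print(mha_cache, gqa_cache, mqa_cache)  # 8192 2048 256
```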

Multi-Head Latent Attention

How do we get the key and value

Let $x$ be the input. We get the key and value $k, v$ by using a fully-connected layer, i.e. by multiplying with dense matrices $W_k \in \mathcal{R}^{(n_h \times d_{head}) \times d_{model}}$ and $W_v \in \mathcal{R}^{(n_h \times d_{head}) \times d_{model}}$, where $n_h, d_{head}, d_{model}$ are the number of heads, the vector dimension of each head, and the model dimension respectively.

$$k = W_k x \\ v = W_v x$$
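A minimal PyTorch sketch of this full-rank projection (the dimension values are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Standard (full-rank) key/value projection: x -> k, v in one step.
n_h, d_head, d_model = 8, 64, 512   # assumed dimensions
L = 10                              # sequence length

x = torch.randn(L, d_model)
W_k = nn.Linear(d_model, n_h * d_head, bias=False)
W_v = nn.Linear(d_model, n_h * d_head, bias=False)

k = W_k(x)   # (L, n_h * d_head) -- this is what gets cached
v = W_v(x)   # (L, n_h * d_head) -- this too
```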

DeepSeek's Trick

Force this transformation from the input vector to keys and values to be low rank, i.e.

Instead of going directly from $d_{model} \rightarrow n_h \times d_{head}$, we go $d_{model} \rightarrow l_{dim}$ and then $l_{dim} \rightarrow n_h \times d_{head}$, where $l_{dim}$ is the dimension of the latent vector $l$ used on the way from $x \rightarrow k, v$. And instead of caching $k, v$, we cache the lower-dimensional latent vector $l$. Here $x \in \mathcal{R}^{L \times d_{model}}$.

How we do this mathematically

$$l_k = W^l x \\ k = W^k_l l_k$$

Similarly for the value:

$$l_v = W^l x \\ v = W^v_l l_v$$

This means that the big matrices $W_k \in \mathcal{R}^{(n_h \times d_{head}) \times d_{model}}$ and $W_v \in \mathcal{R}^{(n_h \times d_{head}) \times d_{model}}$ have each been decomposed into a product of lower-rank matrices.
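A minimal sketch of the factorised version in the same PyTorch style, assuming (as the shared $W^l$ above suggests) a single down-projection for both key and value; all dimensions are again illustrative:

```python
import torch
import torch.nn as nn

# Low-rank version: x -> l (cached) -> k, v (reconstructed on demand).
n_h, d_head, d_model, l_dim = 8, 64, 512, 128   # assumed dimensions
L = 10

x = torch.randn(L, d_model)

W_l   = nn.Linear(d_model, l_dim, bias=False)        # shared down-projection
W_k_l = nn.Linear(l_dim, n_h * d_head, bias=False)   # key up-projection
W_v_l = nn.Linear(l_dim, n_h * d_head, bias=False)   # value up-projection

l = W_l(x)      # (L, l_dim)        -- only this latent is cached
k = W_k_l(l)    # (L, n_h * d_head) -- recomputed from the cache when needed
v = W_v_l(l)    # (L, n_h * d_head)

# Per token we now cache l_dim floats instead of 2 * n_h * d_head floats.
```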

But just caching $l_k, l_v$ would mean that during inference we would have to spend extra compute to recover $k, v$ from $l_k, l_v$ by up-projecting.

Another clever trick here

Instead of up-projecting from the latent back to the actual key and value, we can merge the key up-projection into the query matrix, and merge the value up-projection into the output linear projection layer.
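A one-line sketch of why this merge works on the key side (the value side is analogous). Writing $W_q$ for the query projection (a name not used in the notes above, introduced here just for the derivation), the attention logit between a query token $x_q$ and a cached latent $l_k$ is

$$q^\top k = (W_q x_q)^\top (W^k_l\, l_k) = x_q^\top \left( W_q^\top W^k_l \right) l_k$$

so the product $W_q^\top W^k_l$ can be pre-computed once and applied directly to $x_q$, and $k$ never has to be materialised from $l_k$ at inference time; the same argument folds $W^v_l$ into the output projection.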

==============================

The reason low-rank compression is so effective is because there’s plenty of information overlap between what different attention heads need to know about. If we used low-rank compression on the key and value vectors of individual heads instead of all keys and values of all heads stacked together, the method would simply be equivalent to using a smaller head dimension to begin with and we would get no gain. Exploiting the fact that different heads need access to the same information is essential for the mechanism of multi-head latent attention.

Methods such as grouped-query attention exploit the possibility of the same overlap, but they do so ineffectively by forcing attention heads that are grouped together to all respond similarly to queries. In other words, information sharing becomes coupled to having identical behavior in some restricted sense, a clearly undesirable property. Low-rank compression, on the other hand, allows the same information to be used in very different ways by different heads. In theory, this could even have beneficial regularizing effects on training, and DeepSeek reports finding such effects in their technical reports.

I see this as one of those innovations that look obvious in retrospect but that require a good understanding of what attention heads are actually doing to come up with. Once you see the approach, it’s immediately obvious that it cannot be any worse than grouped-query attention and it’s also likely to be significantly better. However, coming up with the idea of trying this is another matter.