AI 101: What is XQuant?

Compute is no longer the big constraint for LLMs, but memory is. Explore how the new XQuant method and its XQuant-CL variant can cut memory use by up to 12 times

We want LLMs to be as fast and accurate as possible, so chasing high inference speed is an important and non-trivial job for developers. Since LLMs run on hardware, two things determine processing speed:

  1. Computation – how many math operations the hardware must perform.

  2. Memory – how much data needs to be moved in and out of memory.

The thing is that LLMs are more memory-bound than compute-bound. We are in an era when generating a token requires only small matrix-vector multiplications but involves loading large amounts of data from memory, and this process repeats constantly. For LLM inference, memory is a bigger issue than math (hardware compute), which is why so much research focuses on overcoming these memory limits.

One of them is a recent study from UC Berkeley, FuriosaAI, ICSI, and LBNL on XQuant – a method that can cut memory use by up to 12 times at the cost of a little extra compute. XQuant and its variation XQuant-CL bypass typical techniques like KV cache quantization, which usually cause an accuracy drop.

So today, starting from KV caching, we will follow the journey of the XQuant creators to see how this new technique actually works and what it is really capable of. Join us!

In today’s episode, we will cover:

  • KV cache and KV caching

  • KV cache compression and its limitations

  • What is XQuant?

  • How does XQuant work?

  • XQuant-CL: A better version

  • XQuant working with Group-Query Attention (GQA)

  • Performance gains

  • The benefits of XQuant and XQuant-CL

  • Not without limitations

  • Conclusion

  • Sources and further reading

KV cache and KV caching

Today we focus on how to deal with heavy memory use. For short texts, the model weights are the main memory load, while for longer texts, the bigger problem is the Key-Value (KV) cache. It stores the representations of the whole sequence for self-attention, helping the model track the full context, and it grows linearly with input length, which quickly becomes a bottleneck.
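To get a feel for the scale, here is a rough back-of-the-envelope estimate; the model dimensions below are illustrative assumptions, not figures from the XQuant paper:

```python
# Rough KV cache size estimate for a hypothetical decoder-only model.
# All dimensions below are illustrative assumptions.
n_layers = 32          # transformer layers
n_kv_heads = 8         # KV heads
head_dim = 128         # dimension per head
seq_len = 32_768       # cached tokens
bytes_per_elem = 2     # FP16

# 2x because both Keys and Values are stored
kv_cache_bytes = 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem
print(f"KV cache per sequence: {kv_cache_bytes / 1e9:.1f} GB")  # ~4.3 GB
```

And this is per sequence: serving many long-context requests at once multiplies the cost accordingly.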

One of the most popular methods is to compress the KV cache through quantization. But first of all, let’s look at what the KV cache exactly is.

In transformer models, every time a token is processed, the attention mechanism needs Queries (Q), Keys (K), and Values (V) to figure out how much each previous token should influence the current one. Each token is represented by these three vectors:

  • Key (K): Encodes what information a token represents.

  • Value (V): Encodes the actual content to be passed along if that token is selected as relevant.

  • Query (Q): Encodes the features used to compute similarity scores with Keys (K).

The model’s attention mechanism compares a token’s Query with all previous Keys. The resulting similarity scores are used to weight the corresponding Values and to produce a context-aware representation of the token.
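As a minimal illustration of this mechanic, here is a plain-NumPy sketch of a single attention step for one head; the standard 1/√d scaling and softmax are included, while masking, multiple heads, and positional encodings are left out:

```python
import numpy as np

def attention_step(q, K, V):
    """Compute one token's context-aware representation.

    q: (d,)   Query of the current token
    K: (t, d) Keys of all tokens seen so far
    V: (t, d) Values of the same tokens
    """
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)           # similarity of q with every Key
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()              # softmax over the cached tokens
    return weights @ V                    # weighted sum of Values

# Toy example: 5 cached tokens, head dimension 8
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=8), rng.normal(size=(5, 8)), rng.normal(size=(5, 8))
print(attention_step(q, K, V).shape)      # (8,)
```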

Without caching, the model would recompute all K and V matrices from scratch for every new token during autoregressive text generation. This is redundant and computationally expensive, since the cost grows with the length of the sequence.

Overall, the KV cache is a memory structure that stores the previously computed Keys and Values, so they don’t need to be recalculated at each step of decoding.

And what is KV caching?

KV caching is the process of using the KV cache during inference. Instead of recalculating attention against the entire sequence of tokens every time, the model only does new work for the new tokens:

  • It computes and stores the Keys and Values for the first tokens in GPU memory.

  • Then, at each step, the model retrieves stored Keys and Values, adds the new token’s Keys and Values, and computes attention only for the new token’s Query.

Image Credit: Mastering LLM Techniques: Inference Optimization, NVIDIA blog

In this case, the cache grows with each step and attention is computed much faster.
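Here is a toy sketch of that decode loop; the projection matrices are random stand-ins for a real model’s attention weights:

```python
import numpy as np

d_model, d_head = 16, 16
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))

K_cache, V_cache = [], []            # grows by one row per generated token

def decode_step(x_new):
    """x_new: (d_model,) hidden state of the newly generated token."""
    # Only the new token's K/V are computed; older ones come from the cache.
    K_cache.append(x_new @ W_k)
    V_cache.append(x_new @ W_v)
    q = x_new @ W_q
    K, V = np.stack(K_cache), np.stack(V_cache)
    scores = K @ q / np.sqrt(d_head)
    w = np.exp(scores - scores.max()); w /= w.sum()
    return w @ V                      # attention output for the new token only

for _ in range(4):                    # pretend we generate 4 tokens
    out = decode_step(rng.normal(size=d_model))
print(len(K_cache))                   # 4 cached Key vectors
```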

As we have already mentioned, the KV cache grows linearly with input length, making memory usage explode. There are two common ways to compress the KV cache.

KV cache compression and its limitations

The first method is quantization. Compressing the KV cache using quantization means representing it with fewer bits. Thanks to this, more tokens can be cached, enabling longer context windows and reducing memory use. However, when the quantization goes too low, like 2–3 bits, model accuracy drops sharply.
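For intuition, here is a toy symmetric per-tensor quantizer of the kind applied to KV caches; real systems usually use per-channel or per-group scales, which this sketch omits:

```python
import numpy as np

def quantize(x, n_bits):
    """Symmetric per-tensor quantization to n_bits (toy version)."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.abs(x).max() / qmax
    q = np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

x = np.random.default_rng(0).normal(size=(1024,)).astype(np.float32)
for bits in (8, 4, 2):
    q, s = quantize(x, bits)
    err = np.abs(x - dequantize(q, s)).mean()
    print(f"{bits}-bit mean abs error: {err:.4f}")  # error grows sharply at low bit widths
```

Even in this toy setting, the reconstruction error climbs quickly once the bit width drops to 2–3 bits, which mirrors the accuracy drop described above.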

Another approach is to shrink the KV cache using low-rank decomposition, compressing the KV cache into smaller spaces. This means breaking large matrices into smaller, compressed forms. However, this method also comes at an accuracy cost: it risks throwing away important information, and it is mathematically heavy because of the compression and decompression that happen across different ranks.
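The low-rank idea can be illustrated with a truncated SVD; this is not the specific decomposition any particular method uses, just a sketch of the trade-off (fewer stored numbers, some lost information):

```python
import numpy as np

rng = np.random.default_rng(0)
K = rng.normal(size=(4096, 128))             # pretend these are 4096 cached Keys

U, s, Vt = np.linalg.svd(K, full_matrices=False)
rank = 32                                     # keep only the top-32 directions
K_low = U[:, :rank] * s[:rank]                # (4096, 32) -- what we would store
K_approx = K_low @ Vt[:rank]                  # reconstructed Keys

# Random noise has no low-rank structure, so the error here is large;
# real K/V matrices are more structured, but some information is always lost.
compression = K.size / (K_low.size + Vt[:rank].size)
print(f"~{compression:.1f}x fewer numbers, relative error "
      f"{np.linalg.norm(K - K_approx) / np.linalg.norm(K):.2f}")
```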

Another interesting idea came from OpenMachine and is called Slim Attention. It stores only Keys and uses math tricks to recover Values from them. But this requires unstable matrix inverses and also doesn’t work well with Rotary Position Embeddings (RoPE) or Grouped-Query Attention (GQA).

So all these methods point to a bigger issue:

  • GPUs will keep getting faster at computation, but memory growth will lag behind;

  • Plus, accuracy drops with methods like KV cache quantization or low-rank decomposition.

That’s why we need something that saves more memory, preserves accuracy, and is easier to compress. And here is one of the freshest solutions.

What is XQuant?

XQuant is a new method proposed by a group of researchers from UC Berkeley, FuriosaAI, ICSI, and LBNL that trades some extra computation for much less memory usage when running LLMs, breaking what the researchers call the Memory Wall. The idea they follow: GPUs can theoretically do huge amounts of math, but they can’t feed themselves data fast enough, so the logical solution is to reduce memory operations, even if it costs additional compute.

Instead of storing the usual KV cache, XQuant quantizes and stores only the layer input activations, called X. Then, when needed, it rematerializes (recalculates) Keys and Values from X on the fly during inference. Rather than keeping everything in memory, the method throws away some data and recomputes it later when needed. Let’s look at the inner workings of this process. It’s fascinating!
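Before we get into the details, here is a minimal sketch of the core idea, assuming per-layer projection matrices W_k and W_v and reusing a toy quantizer; the paper’s actual quantization scheme and the XQuant-CL cross-layer trick are not reproduced here:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 64, 64
W_k, W_v = rng.normal(size=(d_model, d_head)), rng.normal(size=(d_model, d_head))

def quantize(x, n_bits=4):
    """Toy symmetric quantizer; stands in for the paper's scheme."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.abs(x).max() / qmax
    return np.round(x / scale).clip(-qmax - 1, qmax).astype(np.int8), scale

X_cache = []                                   # quantized layer inputs, one per token

def cache_token(x):
    X_cache.append(quantize(x))                # store X, not K and V

def rematerialize():
    # Recompute K and V from the cached X on the fly: extra matmuls, less memory.
    X = np.stack([q.astype(np.float32) * s for q, s in X_cache])
    return X @ W_k, X @ W_v

for _ in range(8):
    cache_token(rng.normal(size=d_model))
K, V = rematerialize()
print(K.shape, V.shape)                        # (8, 64) (8, 64)
```

In a standard multi-head attention setup where X, K, and V have the same width, caching one tensor instead of two roughly halves memory even before quantization; the reported savings of up to 12 times come from additionally quantizing X to very low bit widths.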
