AI 101: Rethinking Causal Attention

How Causal Attention with Lookahead Keys (CASTLE) and future-aware causal masks reshape the strict left-to-right order of autoregressive models, plus a look at attention through the lens of cause-effect relationships.

Causal attention is the technique that powers many of today’s most impactful AI models. If a model is an autoregressive transformer, causal attention comes built into it. It’s the rule that made LLMs stable, scalable, and effective: predict the future one step at a time from the past. In the causal attention mechanism, each token attends only to previous tokens, never looking forward.

This sounds like the natural order, but what if future tokens have much to “say” at the current moment? As we move toward models that reason at a higher level, we also need them to capture the full global context.

Today we are diving into an interesting idea that gives a clue to how we can expand causal attention – it’s about giving tokens the opportunity to attend to information hidden in future tokens, not only the past. ByteDance’s Causal Attention with Lookahead Keys (CASTLE) supports this idea by dynamically updating the keys as more tokens are processed. Another interesting approach from the University of Sydney and Shanghai Jiao Tong University, called future-aware causal masks, shows that context from the future works very well for vision tasks in Vision-Language Models (VLMs).

We’ll cover the changes that come with these two approaches and also explore how “causal” is being rethought in another sense – bringing in real cause-effect reasoning.

Let’s start from the basics!

In today’s episode, we will cover:

  • Causal attention: The basics

  • What is CASTLE?

    • Changing the standard causal attention workflow

    • Mechanisms under the hood

    • Results of CASTLE

    • Advantages vs. disadvantages

  • Causal attention and Vision-Language Models

    • Future-aware causal masks

    • How does Light Future Aware Attention work?

  • Conclusion

  • Sources and further reading

Causal attention: The basics

A strong backbone of modern generative AI is the “collaboration” between transformers, which provide the underlying architecture, and autoregression, which defines the strategy for generating text.

Autoregressive generation means a model produces text one token at a time, left to right, and each new token is predicted based only on the tokens that came before. Autoregressive models have proven themselves to be effective and reliable thanks to the following aspects:

  • Autoregression mirrors how humans naturally produce and interpret sentences, one word at a time.

  • The model learns by predicting the next token during training, which aligns perfectly with how it will be used during generation. This makes training and inference consistent and stable.

  • Also, by conditioning on everything generated so far, the model keeps its output contextually consistent.

  • Transformers in autoregressive mode train efficiently in parallel: during training, all next-token predictions for a sequence are computed in a single pass (see the short sketch after this list).

  • And finally, the main advantage that led to many huge AI breakthroughs – autoregressive models scale extremely well with more data and parameters.
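To see what “training in parallel” means concretely, here is a minimal sketch (our own toy example in PyTorch, not code from any paper): each position’s training target is simply the next token, so a single forward pass yields a loss over every position of the sequence at once.

```python
import torch
import torch.nn.functional as F

# Toy illustration of parallel next-token training (teacher forcing).
vocab_size, seq_len = 100, 8
token_ids = torch.randint(0, vocab_size, (1, seq_len))   # one training sequence

inputs  = token_ids[:, :-1]   # what the model would see (unused here; logits are faked below)
targets = token_ids[:, 1:]    # each position's label is simply the next token

# Stand-in for a causal transformer: any model mapping (batch, seq) -> (batch, seq, vocab).
# We fake its output with random logits just to show the shapes and the loss.
logits = torch.randn(1, seq_len - 1, vocab_size)

# One cross-entropy over all positions = every next-token prediction trained in parallel.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(loss.item())
```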

What does this autoregressive strategy look like under the hood?

Autoregressive transformers use the standard causal attention mechanism, where each token can attend only to past and present tokens, never future ones. For example, when predicting the 5th word, the model attends only to words 1–4.

This is enforced with a causal mask, which blocks access to tokens that come later in the sequence.

Image Credit: “Causal Attention with Lookahead Keys” paper
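To make the mask itself concrete, here is a minimal sketch (a toy example in PyTorch, not code from the paper). The mask is a lower-triangular matrix: row i marks which positions token i is allowed to attend to.

```python
import torch

# Causal mask for a 5-token sequence: row i may attend only to columns <= i.
seq_len = 5
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
# The fourth row is the query that predicts the 5th word: it can see words 1-4 and nothing later.
```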

Here is how the standard causal attention mechanism works, step by step:

  • Each token has a query, key, and value (QKV).

  • When generating the next token, the model uses the new query to compare against all the keys it has so far.

  • The comparison produces attention scores, which are turned into weights.

  • Those weights are then applied to the values, giving the output for the next token.

  • Each key only represents the information available up to its own position in the sequence. Keys never get updated later.

  • At inference time, the model generates one token at a time, feeding each new token back into itself.

In this setup, each token’s queries, keys, and values (QKV) are computed once from the token’s representation and remain fixed, so the token encodes only information from earlier tokens.
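Putting those steps together, here is a hedged single-head sketch in PyTorch (real models add learned projection matrices for Q, K, and V, multiple heads, and more; this only shows the flow from queries and keys to scores, weights, and outputs).

```python
import math
import torch

def causal_attention(Q, K, V):
    # Compare each query against all keys, scaled by sqrt(d_k).
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # The causal mask blocks attention to future positions.
    mask = torch.tril(torch.ones(scores.shape[-2:], dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))
    # Scores become weights, which are applied to the values.
    weights = torch.softmax(scores, dim=-1)
    return weights @ V

seq_len, d_model = 6, 16
x = torch.randn(seq_len, d_model)   # toy token representations
out = causal_attention(x, x, x)     # identity "projections" just for brevity
print(out.shape)                    # torch.Size([6, 16])
```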

This is a well-established mechanism that has worked for years, but what if important information shows up later in a sentence and the model misses it? In some cases, the causal mask can limit a model’s ability to understand the bigger picture of a sentence or passage.

Do you remember BERT (Bidirectional Encoder Representations from Transformers) models? They process tokens from both sides – left-to-right and right-to-left – and this approach develops a richer understanding of word meaning and makes models more context-aware.

So what if, in a way, we allowed autoregressive models a glimpse of future tokens?

Several research groups have come to the same conclusion: standard causal attention needs rethinking, and giving a model the opportunity to attend to future tokens and context turns out not to be a bad choice.

A recent paper from ByteDance Seed inspired us to take a deeper look at this idea and explore where else this concept has been successfully applied. But first, let’s start with ByteDance’s new method – Causal Attention with Lookahead Keys (CASTLE).

What is CASTLE?

Changing the Standard Causal Attention Workflow

As we said, in standard causal attention, queries, keys, and values (QKV) are fixed once they are created. The CASTLE approach targets exactly this point: it updates the keys (K) as more tokens are processed. These updated keys are called lookahead keys.
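Before going further, here is a deliberately naive toy sketch in PyTorch to convey the contrast. This is not CASTLE’s actual, efficient formulation (we get to that below); it only illustrates the idea that keys of earlier positions can be refreshed as more tokens are processed, while never using anything beyond the tokens seen so far.

```python
import math
import torch

def attention(Q, K, V, causal=True):
    # Basic single-head scaled dot-product attention; optionally causal.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    if causal:
        mask = torch.tril(torch.ones(scores.shape[-2:], dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

seq_len, d_model = 6, 16
x = torch.randn(seq_len, d_model)              # toy token representations

# Standard causal attention: every key is computed once and never revisited.
frozen_out = attention(x, x, x, causal=True)

# Naive "lookahead keys" picture (illustrative only): at step t, keys for positions
# 0..t are re-derived from a pass over the currently visible prefix, so earlier keys
# can absorb information from tokens that arrived after them. They still never see
# tokens beyond the prefix, which preserves the autoregressive rule.
outputs = []
for t in range(seq_len):
    prefix = x[: t + 1]
    refreshed_keys = attention(prefix, prefix, prefix, causal=False)  # hypothetical key update
    q_t = x[t : t + 1]                                                # the current query
    outputs.append(attention(q_t, refreshed_keys, prefix, causal=False)[0])
lookahead_out = torch.stack(outputs)
print(frozen_out.shape, lookahead_out.shape)   # both torch.Size([6, 16])
```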

Below we will discuss how CASTLE works, how a similar idea is applied in Vision-Language Models, and another angle on “causal” attention – real cause-effect reasoning.

