Causal attention is the technique that powers many of today's most impactful AI models. It comes built into every autoregressive transformer. It is the rule that made LLMs stable, scalable, and effective: predict the future one step at a time from the past. In the causal attention mechanism, each token attends only to itself and the tokens before it, never to the ones that come after.

This sounds like the perfect order, but what if future tokens have a lot to “say” about the current moment? As we move toward models that reason at a higher level, we also need them to capture the full global context.

Today we are diving into an interesting idea that gives a clue to how we can expand causal attention – it’s about giving tokens the opportunity to attend to information hidden in future tokens, not only the past. ByteDance’s Causal Attention with Lookahead Keys (CASTLE) supports this idea by dynamically updating the keys as more tokens are processed. Another interesting approach from the University of Sydney and Shanghai Jiao Tong University, called future-aware causal masks, shows that context from the future works very well for vision tasks in Vision-Language Models (VLMs).

We’ll cover the changes that come with these two approaches and also explore how “causal” is being rethought in another sense – bringing in real cause-effect reasoning.

Let’s start from the basics!

In today’s episode, we will cover:

  • Causal attention: The basics

  • What is CASTLE?

    • Changing the standard causal attention workflow

    • Mechanisms under the hood

    • Results of CASTLE

    • Advantages vs. disadvantages

  • Causal attention and Vision-Language Models

    • Future-aware causal masks

    • How does Light Future Aware Attention work?

    • Benefits and possible issues

  • Other methods: Causal as cause-effect relationships

  • Conclusion

  • Sources and further reading

Causal attention: The basics

Much of modern generative AI rests on a “collaboration” of two things: transformers, which provide the underlying architecture, and autoregression, which defines the strategy for generating text.

Autoregressive generation means a model produces text one token at a time, left to right, and each new token is predicted based only on the tokens that came before. Autoregressive models have proven themselves to be effective and reliable thanks to the following aspects:

  • Autoregression mirrors how humans naturally produce and interpret sentences: one word at a time.

  • The model learns by predicting the next token during training, which aligns perfectly with how it will be used during generation. This makes training and inference consistent and stable.

  • Also, by conditioning on everything generated so far, the model keeps its output contextually consistent.

  • Transformers in autoregressive mode train efficiently in parallel: with a causal mask, they predict the next token at every position of a training sequence at once.

  • And finally, the main advantage that led to many huge AI breakthroughs – autoregressive models scale extremely well with more data and parameters.

What does this autoregressive strategy look like from the inside?

Autoregressive transformers use the standard causal attention mechanism, in which each token can attend only to itself and earlier tokens, never to later ones. For example, when predicting the 5th word, the model attends only to words 1–4.

This is enforced with a causal mask, which blocks access to tokens that come later in the sequence.
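To make the mask concrete, here is a minimal Python sketch that builds a causal mask for n tokens using plain lists and no framework. Row i belongs to the query of token i, and True marks the positions that query is allowed to see.

```python
def causal_mask(n):
    """Return an n x n boolean grid where mask[i][j] is True if token i may attend to token j."""
    # Token i sees itself and every earlier token (j <= i), never a later one (j > i).
    return [[j <= i for j in range(n)] for i in range(n)]


for row in causal_mask(5):
    print("".join("x" if allowed else "." for allowed in row))
```

With 0-based indices, row 3 allows columns 0 to 3 only. Row 3 is the 4th token, whose output is used to predict the 5th, so this matches the example above.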

Image Credit: “Causal Attention with Lookahead Keys” paper

Here is how the standard causal attention mechanism works, step by step:

  • Each token has a query, key, and value (QKV).

  • When generating the next token, the model compares the newest query against all the keys computed so far.

  • The comparison produces attention scores, which are turned into weights.

  • Those weights are then applied to the values, giving the output for the next token.

  • Each key only represents the information available up to its own position in the sequence. Keys never get updated later.

  • At inference time, the model generates one token at a time, feeding each new token back into itself.

In this setup, each token’s query, key, and value (QKV) are computed once from the token’s representation and remain fixed, so each key and value encodes only information from that token and the tokens before it.
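The steps above can be sketched as a single attention head in plain Python. This is an illustrative toy built on our own assumptions: random, untrained projection matrices, one head, and no batching. It is not any particular model's code. It shows the two ways the same mechanism runs:

- a training-style `forward()` over the whole sequence;
- an inference-style `step()` that appends each token's key and value to a cache and never touches them again.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(W, x):
    # W is a list of rows, so this computes W @ x
    return [dot(row, x) for row in W]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values):
    """Compare one query with the given keys, turn scores into weights, and mix the values."""
    scale = math.sqrt(len(q))
    weights = softmax([dot(q, k) / scale for k in keys])
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


class CausalHead:
    """Single attention head with random (untrained) projections, for illustration only."""

    def __init__(self, d, seed=0):
        rng = random.Random(seed)

        def rand_matrix():
            return [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)] for _ in range(d)]

        self.Wq, self.Wk, self.Wv = rand_matrix(), rand_matrix(), rand_matrix()
        self.k_cache, self.v_cache = [], []

    def step(self, x):
        """Inference: process one new token embedding and return its attention output."""
        q = matvec(self.Wq, x)
        # The new key and value are computed once and appended; they are never updated later.
        self.k_cache.append(matvec(self.Wk, x))
        self.v_cache.append(matvec(self.Wv, x))
        return attend(q, self.k_cache, self.v_cache)

    def forward(self, xs):
        """Training-style pass over a whole sequence: position t sees only positions 0..t."""
        K = [matvec(self.Wk, x) for x in xs]
        V = [matvec(self.Wv, x) for x in xs]
        # Slicing to [: t + 1] is equivalent to masking later positions with -inf before softmax.
        return [attend(matvec(self.Wq, x), K[: t + 1], V[: t + 1]) for t, x in enumerate(xs)]
```

Cached keys and values never change. As a result, feeding tokens one at a time through `step()` gives the same outputs as one `forward()` pass, and changing a later token cannot alter the outputs at earlier positions.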

This mechanism has worked well for years, but what if important information appears later in a sentence and the model misses it? In some cases, the causal mask can limit a model’s ability to understand the bigger picture of a sentence or passage.

Do you remember BERT (Bidirectional Encoder Representations from Transformers)? It lets every token attend to context on both sides, left and right, at the same time. This bidirectional view gives a deeper understanding of context and word meaning, making models more context-aware.

So what if, in a way, we allowed a glimpse of future tokens for autoregressive models?

Several research groups have concluded that standard causal attention is worth rethinking. Giving a model a controlled way to attend to future tokens and context has turned out to be a promising choice.

A recent paper from ByteDance Seed inspired us to take a deeper look at this idea and explore where else it has been applied successfully. But first, let’s start with ByteDance’s new method: Causal Attention with Lookahead Keys (CASTLE).

What is CASTLE?

Changing the Standard Causal Attention Workflow

As we said, in standard causal attention, queries, keys, and values (QKV) are fixed once they are created. CASTLE changes exactly this: it updates the keys (K) as more tokens are processed. These updated keys are called lookahead keys:

  • They belong to earlier tokens.

  • But they integrate information from tokens that appear later.

For example, once the model has read 4 tokens and is generating the 5th, the key for token 1 already carries information about tokens 1–4. By the time the model generates the 7th token, that key has expanded to cover tokens 1–6.

Image Credit: “Causal Attention with Lookahead Keys” paper

As context unfolds, the representations of earlier tokens can grow richer, carrying hints from the words that follow. It works like a continuously updated memory for each token.

One important note: CASTLE still respects the autoregressive rule:

  • At step (t+1), the updated key for token s only includes information up to token t.

  • No token ever attends to information from the future.

  • Queries and values are also limited to the context seen so far.

(Note: t is the index of the latest token the model has processed, so step t+1 generates token t+1. s is any earlier token in the sequence, s ≤ t.)

A fair question: why are only the keys updated? There are good reasons.

  • Queries (Q) are only used once, when generating the next token. Updating past queries would have no effect.

  • Keys (K), however, are reused by all future queries. Updating them improves the quality of attention scores for every later step.

  • Updating values (V) could also help, but it would be more computationally expensive.

CASTLE focuses on keys because they allow a clean mathematical trick that keeps training efficient.

Mechanisms Under the Hood

CASTLE uses a hybrid attention mechanism: each token has both a causal key (static) and a lookahead key (dynamic).

Causal keys work just like the keys in the standard causal attention we explained above. Lookahead keys are the new addition. To compute them, CASTLE follows this mechanism:

  • Each token s gathers information from tokens that came after it (s+1 … t). For example, while the model generates token 6, the lookahead key of token 3 also includes information from tokens 4 and 5, so it reflects tokens 1–5 overall.

  • A mask ensures no token ever attends into the future (beyond t).

  • The model doesn’t have to absorb every later token. A softmax would force the weights over later tokens to sum to one. CASTLE uses a sigmoid instead, which scores each later token independently, so a token may take in information from some later tokens and ignore others.

  • Because lookahead keys keep being refreshed as the sequence grows, the lookahead key of the same token differs between step t and step t+1.
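Here is a minimal sketch of how lookahead keys could be computed at a given step, based only on the description above. Everything here is an assumption for illustration:

- `qU`, `kU`, and `vU` stand for hypothetical per-token projections used for lookahead;
- indices are 0-based;
- `t` is the index of the latest processed token.

The paper's exact parametrization may differ.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lookahead_keys(qU, kU, vU, t):
    """Lookahead keys of tokens 0..t after the model has processed tokens 0..t (0-based).

    qU, kU, vU are per-token vectors from hypothetical lookahead projections.
    Token s aggregates only tokens s+1..t, so nothing beyond the current step is used.
    """
    scale = math.sqrt(len(qU[0]))
    keys = []
    for s in range(t + 1):
        u = [0.0] * len(vU[0])
        for j in range(s + 1, t + 1):  # later tokens only, bounded by the current step t
            # Independent gate in (0, 1) instead of a softmax: token s can take a lot,
            # a little, or almost nothing from each later token.
            gate = sigmoid(dot(qU[s], kU[j]) / scale)
            u = [ui + gate * vj for ui, vj in zip(u, vU[j])]
        keys.append(u)
    return keys


def rand_vectors(rng, n, d):
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


rng = random.Random(0)
n, d = 6, 3
qU, kU, vU = rand_vectors(rng, n, d), rand_vectors(rng, n, d), rand_vectors(rng, n, d)
U4 = lookahead_keys(qU, kU, vU, 4)  # after tokens 0..4
U5 = lookahead_keys(qU, kU, vU, 5)  # after tokens 0..5: earlier keys are refreshed
print(U4[2] != U5[2])  # the key of token 2 changed once token 5 arrived
print(U5[5])           # the newest token has no later tokens, so its key is all zeros
```

Only tokens up to the current step enter each sum. Changing a token that arrives later cannot affect the keys computed at an earlier step.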

Here is how CASTLE combines causal and lookahead keys:

Image Credit: “Causal Attention with Lookahead Keys” paper

  • When generating the next token, the query from the current token is compared against both causal and lookahead keys.

  • This produces two sets of scores: one reflecting static past information, the other reflecting updated context.

  • Before the lookahead scores are merged with the causal scores, they pass through the SiLU (Sigmoid Linear Unit) function, which acts as a gate. This keeps noisy, irrelevant, or outdated past tokens from dominating and regulates how much each past token should be forgotten.

  • Then scores from lookahead and causal keys are combined and the result is passed through softmax (the function that turns a vector of numbers into a probability distribution) to ensure the model produces valid attention weights.

  • Finally, the resulting attention weights are applied to the values, producing the output for this step.
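The combination steps above can be sketched as one decoding step. The inputs (query, causal keys, lookahead keys, values) are taken as given vectors.

One assumption matters here. We combine the two score sets by subtracting the SiLU-gated lookahead scores from the causal scores, which matches the “forgetting” role described above. Check the paper for the exact formula.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def silu(x):
    # SiLU(x) = x * sigmoid(x), written to avoid overflow for large |x|
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def castle_attention_step(q, causal_keys, lookahead_keys, values):
    """One decoding step: q is the current query; each list has one entry per past token s <= t."""
    scale = math.sqrt(len(q))
    causal_scores = [dot(q, k) / scale for k in causal_keys]        # static past information
    lookahead_scores = [dot(q, u) / scale for u in lookahead_keys]  # refreshed context
    # Assumed combination: subtract the SiLU-gated lookahead score, so a large value
    # pushes the corresponding past token down ("forgets" it) before the softmax.
    combined = [c - silu(l) for c, l in zip(causal_scores, lookahead_scores)]
    weights = softmax(combined)  # valid attention weights that sum to 1
    out = [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]
    return out, weights
```

With all lookahead keys set to zero, SiLU(0) = 0 and the step reduces to standard causal attention. A lookahead key that aligns strongly with the query lowers that token's weight.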

We described the recurrent version of CASTLE, where lookahead keys are updated step by step. However, CASTLE can be rewritten in a parallel formulation that gives the same result while making large-scale pretraining feasible. This form recasts the recurrent updates as masked matrix multiplications and exploits a low-rank decomposition to cut compute and memory costs. As a result, the model computes outputs for the whole sequence in one shot, without explicitly building and storing a separate set of lookahead keys for every step.
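The core algebra behind a parallel form can be checked on a toy example. The sketch below is our own derivation from the definitions above, not the paper's kernel.

The lookahead score of query t on token s is a sum over tokens s < j ≤ t. That sum factors into two masked matrices: gates `A` and query-value products `B`. One product, `B @ A^T`, then gives the lookahead scores for every step at once, without materializing a separate set of lookahead keys per step.

The paper's low-rank decomposition and memory savings are not reproduced here, and plain Python loops are used only for readability.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def scores_recurrent(q, qU, kU, vU):
    """Reference: at every step t, build each lookahead key u_s (s <= t) explicitly, then score it."""
    n, scale = len(q), math.sqrt(len(q[0]))
    S = [[0.0] * n for _ in range(n)]
    for t in range(n):
        for s in range(t + 1):
            u = [0.0] * len(vU[0])
            for j in range(s + 1, t + 1):
                g = sigmoid(dot(qU[s], kU[j]) / scale)
                u = [a + g * b for a, b in zip(u, vU[j])]
            S[t][s] = dot(q[t], u) / scale
    return S


def scores_parallel(q, qU, kU, vU):
    """Same scores for all steps from two masked matrices and one product: S = B @ A^T."""
    n, scale = len(q), math.sqrt(len(q[0]))
    # A[s][j]: sigmoid gate of token s on token j, kept only for later tokens (j > s)
    A = [[sigmoid(dot(qU[s], kU[j]) / scale) if j > s else 0.0 for j in range(n)] for s in range(n)]
    # B[t][j]: query t against lookahead value j, kept only for tokens already seen (j <= t)
    B = [[dot(q[t], vU[j]) if j <= t else 0.0 for j in range(n)] for t in range(n)]
    # Entry (t, s) sums over s < j <= t: exactly the tokens lookahead key s holds at step t
    return [[sum(B[t][j] * A[s][j] for j in range(n)) / scale if s <= t else 0.0
             for s in range(n)] for t in range(n)]
```

Both functions return the same n x n matrix of lookahead scores. Only the parallel one avoids building per-step keys, which is what makes this style of formulation amenable to batched matrix kernels.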

For efficient inference, the researchers introduced a UQ-KV cache. It is similar to the standard KV cache but also holds the information needed for lookahead keys, so these keys are not recomputed from scratch: CASTLE updates them recursively, reusing cached results.
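Here is a hypothetical sketch of the lookahead part of such a cache; the class and method names are ours, not the paper's. It relies on one observation from the definitions above. When a new token arrives, each cached lookahead key needs only one extra gated term. Computing that term requires the cached lookahead queries plus the new token's own key and value. A from-scratch reference function is included so the two can be compared.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


class LookaheadCache:
    """Hypothetical sketch of the lookahead part of a UQ-KV-style cache.

    For every past token it keeps the lookahead query qU_s and the current lookahead key u_s.
    The usual causal keys and values would live in a standard KV cache alongside it.
    """

    def __init__(self):
        self.qU = []  # lookahead queries of past tokens
        self.U = []   # current lookahead keys of past tokens

    def append(self, qU_new, kU_new, vU_new):
        scale = math.sqrt(len(qU_new))
        # Recursive update: each cached key absorbs the new token once, with its own gate.
        # The new token's kU and vU are used here and need not be cached afterwards.
        for s in range(len(self.U)):
            g = sigmoid(dot(self.qU[s], kU_new) / scale)
            self.U[s] = [a + g * b for a, b in zip(self.U[s], vU_new)]
        # The newest token has no later tokens yet, so its lookahead key starts at zero.
        self.qU.append(qU_new)
        self.U.append([0.0] * len(vU_new))
        return [u[:] for u in self.U]


def lookahead_keys_from_scratch(qU, kU, vU, t):
    """Reference: u_s = sum over j in s+1..t of sigmoid(qU_s . kU_j / sqrt(d)) * vU_j."""
    scale = math.sqrt(len(qU[0]))
    keys = []
    for s in range(t + 1):
        u = [0.0] * len(vU[0])
        for j in range(s + 1, t + 1):
            g = sigmoid(dot(qU[s], kU[j]) / scale)
            u = [a + g * b for a, b in zip(u, vU[j])]
        keys.append(u)
    return keys
```

Each `append()` costs one gated update per cached token, instead of recomputing every lookahead key over all later tokens.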

Also, as with normal transformers, CASTLE can be extended to multi-head attention. Each head computes its own version of causal and lookahead attention, and the results are concatenated and projected back into the model’s hidden dimension.
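The multi-head pattern itself is the same as in any transformer. The sketch below does four things:

- splits the model dimension into heads;
- runs a per-head function on each slice;
- concatenates the results;
- projects them back with `W_o`.

For brevity, we assume the per-head function is plain causal attention. A CASTLE head would plug in at the same place, with its own lookahead projections.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_head(Q, K, V):
    """Stand-in per-head attention (plain causal). A CASTLE head would also use lookahead keys."""
    scale = math.sqrt(len(Q[0]))
    out = []
    for t, q in enumerate(Q):
        w = softmax([sum(a * b for a, b in zip(q, K[s])) / scale for s in range(t + 1)])
        out.append([sum(w[s] * V[s][c] for s in range(t + 1)) for c in range(len(V[0]))])
    return out


def multi_head(Q, K, V, n_heads, W_o, head_fn=causal_head):
    """Split d_model into n_heads slices, run head_fn per slice, concatenate, project with W_o."""
    d_model = len(Q[0])
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    size = d_model // n_heads
    per_head = []
    for h in range(n_heads):
        cols = slice(h * size, (h + 1) * size)
        per_head.append(head_fn([q[cols] for q in Q], [k[cols] for k in K], [v[cols] for v in V]))
    # Concatenate the heads' outputs token by token, then map back to d_model with W_o.
    concat = [[x for h in range(n_heads) for x in per_head[h][t]] for t in range(len(Q))]
    return [[sum(w * x for w, x in zip(row, c)) for row in W_o] for c in concat]
```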

Results of CASTLE

When tested on language modeling benchmarks, CASTLE shows consistent improvements over standard causal attention. First of all, it reduces training and validation loss and perplexity at every scale:

  • On the small model (160M parameters), the gains are modest: perplexity drops from 16.41 to 16.32.

  • As models scale up, the gains grow clearer: for the medium model (350M), perplexity falls from 14.00 to 13.67, and for the large model (750M), from 12.27 to 11.84.

  • At the XL scale (1.3B parameters), CASTLE cuts perplexity from 11.31 to 10.92, a gain comparable to that of the large model.

Image Credit: “Causal Attention with Lookahead Keys” paper
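A quick bit of arithmetic puts these numbers in perspective. The sketch below uses only the figures quoted above. It computes the absolute and relative perplexity drops. It also computes the implied loss reduction, since perplexity is the exponential of the mean cross-entropy loss (in nats).

```python
import math

# (baseline, CASTLE) perplexity, as quoted above
reported = {
    "S (160M)": (16.41, 16.32),
    "M (350M)": (14.00, 13.67),
    "L (750M)": (12.27, 11.84),
    "XL (1.3B)": (11.31, 10.92),
}


def compare(baseline, castle):
    """Absolute and relative perplexity drop, plus the implied drop in mean loss (nats/token)."""
    absolute = baseline - castle
    relative = absolute / baseline
    # Perplexity = exp(mean cross-entropy), so the loss gap is the log of the perplexity ratio.
    loss_gap = math.log(baseline / castle)
    return absolute, relative, loss_gap


for name, (base, ours) in reported.items():
    a, r, g = compare(base, ours)
    print(f"{name:10s} -{a:.2f} ppl  ({r:.2%})  loss -{g:.4f}")
```

The relative gain grows from about 0.5% (160M) to about 2.4% (350M) and 3.5% (750M). At 1.3B it stays at a similar level, about 3.4%.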

On downstream benchmarks, CASTLE models also surpass their baselines:

  • On ARC-C (reasoning benchmark) scores climb from 33.79% to 35.32% (0-shot) and from 35.58% to 39.08% (5-shot).

  • On BoolQ, accuracy rises from 61.07% to 62.6% in the XL model, indicating stronger reading comprehension.

  • On Winogrande, which tests commonsense reasoning, accuracy improves from 54.06% to 56.59% (0-shot) and from 52.72% to 58.33% (5-shot).

Image Credit: “Causal Attention with Lookahead Keys” paper

So CASTLE improves performance on a variety of downstream tasks, not just text prediction, with stronger reasoning and generalization. Importantly, the effect is particularly visible in larger models, which have the capacity to fully capture richer, evolving context.

Advantages vs. Disadvantages

Now, let’s sum up what makes the CASTLE attention mechanism stand out.

  • CASTLE keeps models autoregressive but gives earlier tokens a view of the context that follows them, up to the current step.

  • It balances stability from the frozen causal keys with adaptability from the evolving lookahead keys.

  • Lower perplexity and loss → the model predicts text more accurately.

  • CASTLE models achieve higher accuracy, especially in reasoning and commonsense tasks.

  • A parallel formulation allows CASTLE to be trained with complexity similar to standard attention, which makes it practical for large-scale pretraining.

  • CASTLE models don’t require more parameters than the baseline transformer.

But there are also factors that limit the CASTLE attention mechanism’s effectiveness:

  • Additional computational cost compared to standard attention.

  • CASTLE is harder to implement, because it needs special optimizations such as the UQ-KV cache mechanism.

  • Only small gains for small models.

  • Its behavior on multi-trillion-token training, used now in frontier models, isn’t yet fully validated.

Either way, CASTLE broadens our view of causal attention for autoregressive models, introducing new rules for using information from later tokens without breaking the old ones. But it isn’t the first approach to explore this direction. An earlier study by researchers from the University of Sydney and Shanghai Jiao Tong University examined how causal attention should work in Vision-Language Models (VLMs).

Causal Attention and Vision-Language Models

Standard causal attention works well for text because text is sequential. VLMs, however, also process visual information, which is holistic: different parts of an image often need to be seen together. Researchers from the University of Sydney and Shanghai Jiao Tong University noticed a tension between causal attention and vision encoders. First, the vision encoder compresses the whole image into a set of tokens, so each vision token already carries global image context. Then causal attention forces these vision tokens to behave like text tokens, which limits how much useful image information can be shared across tokens.

The researchers found that relaxing the causal mask for vision tokens, so they can use information from future image regions, makes models perform better on vision-language reasoning tasks. At the same time, they argue that text still needs strict causality. So they proposed future-aware causal masks, which take both points into account.

Future-aware Causal Masks

They designed three specialized masks:

  • Full Future Mask (Mf): Visual tokens can attend to all future tokens (visual + text). It helps with temporal multi-image tasks, like navigation or state changes across image sequences.

  • Visual-to-Visual Mask (Mv2v): Visual tokens can attend to future visual tokens only, not text. This one improves visual relation tasks, including detecting object changes or relationships between frames.

  • Visual-to-Textual Mask (Mv2t): Visual tokens can attend only to future text tokens. This mask helps in text-rich image tasks (OCR-VQA, TextVQA), where visual patches need to connect to words embedded in images.
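The three masks differ only in which future positions a visual token may see. Here is a minimal sketch that builds them for a mixed sequence of visual ("v") and text ("t") tokens. The function and mode names are ours. As described above, text tokens always keep the plain causal mask.

```python
def future_aware_mask(token_types, mode="causal"):
    """Build mask[i][j] (True = token i may attend to token j) for a mixed sequence.

    token_types: list of "v" (visual) or "t" (text).
    mode: "causal" (standard), "full" (Mf), "v2v" (Mv2v) or "v2t" (Mv2t).
    """
    if mode not in ("causal", "full", "v2v", "v2t"):
        raise ValueError(f"unknown mode: {mode}")
    n = len(token_types)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if j <= i:
                allowed = True                     # the usual causal part is always kept
            elif token_types[i] != "v" or mode == "causal":
                allowed = False                    # text tokens never look ahead
            elif mode == "full":
                allowed = True                     # Mf: any future token
            elif mode == "v2v":
                allowed = token_types[j] == "v"    # Mv2v: future visual tokens only
            else:
                allowed = token_types[j] == "t"    # Mv2t: future text tokens only
            mask[i][j] = allowed
    return mask
```

Take the token types `["t", "v", "v", "t", "t"]`. In `"v2t"` mode, token 1 (visual) can see tokens 3 and 4 (text) but not token 2 (visual). Token 0 (text) still sees only itself.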

Letting vision tokens attend to the full future context sounds ideal, but it comes at the cost of much slower inference. The slowdown is worst during autoregressive decoding, where tokens are generated one by one. So the researchers proposed a compromise: Light Future Aware Attention.

Image Credit: “Rethinking Causal Mask Attention for Vision-Language Inference” paper

How does Light Future Aware Attention work?

VLMs are usually run in two phases:

  • Prefill stage, when they process the whole input at once.

  • Decoding stage, when models generate output tokens step by step.

In Light Future Aware Attention, future access happens only in the prefill stage. This means that future information is compressed into earlier tokens, and by the time decoding starts, the model can stick to the normal causal mask without extra overhead. Here is how it happens step-by-step:

Image Credit: “Rethinking Causal Mask Attention for Vision-Language Inference” paper

  • In autoregressive models, the very first tokens act as “sinks” that gather a lot of attention (the attention sink phenomenon). The method turns this phenomenon into an advantage.

  • During the prefill stage, the model applies a kernel pooling operation that summarizes the future context into a small representation.

  • These compressed future representations are merged back into earlier tokens (the prefix), and these past tokens now carry hints about the future.

  • When the model generates outputs, it still uses a strictly causal left-to-right mask. As the past tokens already hold some future context, the model benefits without breaking autoregression.
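The paper's exact kernel pooling and merging rules are not spelled out here, so the following is only a toy sketch of the idea under our own assumptions:

- the hidden states after a prefix are compressed by average pooling, a stand-in for kernel pooling, into as many summary vectors as there are prefix tokens;
- each summary is added to one prefix token with a hypothetical weight `alpha`;
- the prefix tokens play the role of the early “sink” tokens.

Decoding afterwards would use an ordinary causal mask.

```python
def adaptive_avg_pool(vectors, out_len):
    """Average contiguous chunks of the token axis down to out_len vectors (stand-in for kernel pooling)."""
    n = len(vectors)
    pooled = []
    for k in range(out_len):
        start = (k * n) // out_len
        end = max(((k + 1) * n) // out_len, start + 1)  # keep every chunk non-empty
        chunk = vectors[start:end]
        pooled.append([sum(col) / len(chunk) for col in zip(*chunk)])
    return pooled


def merge_future_into_prefix(hidden, prefix_len, alpha=0.5):
    """Prefill-only step (hypothetical): compress the tokens after the prefix and add
    one summary vector to each prefix token. Decoding then uses a plain causal mask."""
    merged = [row[:] for row in hidden]  # leave the input untouched
    future = hidden[prefix_len:]
    if prefix_len == 0 or not future:
        return merged
    summaries = adaptive_avg_pool(future, prefix_len)  # small compressed view of the future
    for s in range(prefix_len):
        merged[s] = [h + alpha * f for h, f in zip(hidden[s], summaries[s])]
    return merged
```

The merge happens once, during prefill. It adds nothing to the per-token cost of decoding, which is the point of the “light” variant.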

Now let’s look at how effective Light Future Aware Attention is.

Benefits and possible issues

  • Compared to the Full Future Mask, Light Future Aware Attention gives a 2–3× speedup while keeping most of the accuracy gains (just take a look at the results below).

Image Credit: “Rethinking Causal Mask Attention for Vision-Language Inference” paper

  • In general, Light Future Aware Attention keeps the inference fast, while preserving the benefits of rich context.

  • Models perform better on vision tasks, like temporal reasoning, visual relation inference, and text-rich QA.

  • The mask used during decoding remains strictly causal, so generation stays consistent with standard autoregressive decoding.

As for the limitations of future aware attention, there are several important points:

  • Like CASTLE, it is a complex approach that requires additional tricks, such as kernel pooling and prefix merging.

  • Choosing the right mask (Full, Visual-to-Visual, Visual-to-Textual) and deciding whether to merge adds manual tuning.

  • Text-dominant or retrieval-heavy tasks don’t improve (but, honestly, future-aware masks were designed primarily for vision tasks, so this may not be a serious issue).

  • Future aware attention is still an approximation: By compressing future context into prefix tokens, the model doesn’t have full access to raw future tokens.

Despite these issues, this method shows that VLMs really do benefit from extending standard causal attention to future-aware attention, as long as accuracy and inference cost stay in balance.

Other methods: Causal as cause-effect relationships

As we are expanding the understanding of causal attention today, we’ll add that “causal” is not always about “one-by-one”, “left-to-right”, and “step-by-step”.

If we look more broadly, “causal” also relates to Causal AI, which emphasizes capturing cause-effect relationships in data. Another line of research explores how to upgrade causal attention in this literal sense: paying attention to the right cause-effect patterns. Let’s briefly look at some of these methods too:

  • Causal Attention Tuning (CAT) is a way to train LLMs to pay attention to what causes what, instead of following surface-level patterns and spurious correlations in the data. CAT directly injects causal priors into the attention maps during training, guiding attention toward causes. First, it uses human and LLM-generated signals to build a causal graph of large-scale “causal hints” for the training data.

    The second part is causal constraint attention training. CAT adds an extra re-attention loss that ensures causal tokens get at least as much attention as non-causal ones. This is done by comparing average attention weights across all layers and heads, then nudging them toward the causal adjacency matrix. The final training loss combines the usual next-token prediction loss with this causal loss, so the model learns to focus on the right causal structure.

Image Credit: CAT original paper
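As a rough illustration of CAT's causal constraint loss, here is a sketch under our own assumptions:

- the attention maps are averaged over layers and heads;
- the causal adjacency matrix is given as 0/1 entries;
- the “at least as much attention” condition is enforced with a simple hinge penalty.

The paper's exact loss and weighting may differ, and `lam` is a hypothetical hyperparameter.

```python
def average_attention(attn):
    """attn[layer][head] is an n x n attention map; return their mean over all layers and heads."""
    maps = [m for layer in attn for m in layer]
    n = len(maps[0])
    return [[sum(m[i][j] for m in maps) / len(maps) for j in range(n)] for i in range(n)]


def re_attention_loss(avg_attn, causal_adj):
    """Hinge penalty (assumed form): for each query position i, penalize the amount by which
    the mean attention on non-causal tokens exceeds the mean attention on causal tokens."""
    penalties = []
    for i, row in enumerate(avg_attn):
        visible = range(i + 1)  # the causal mask only lets position i see tokens 0..i
        causal = [row[j] for j in visible if causal_adj[i][j]]
        other = [row[j] for j in visible if not causal_adj[i][j]]
        if causal and other:
            gap = sum(other) / len(other) - sum(causal) / len(causal)
            penalties.append(max(0.0, gap))
    return sum(penalties) / len(penalties) if penalties else 0.0


def cat_total_loss(lm_loss, attn, causal_adj, lam=0.1):
    # Usual next-token loss plus the weighted causal constraint; lam is a hypothetical weight.
    return lm_loss + lam * re_attention_loss(average_attention(attn), causal_adj)
```

The penalty is zero as soon as causal tokens receive at least as much average attention as the others. In that case the total reduces to the ordinary next-token loss.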

  • The Learning to Focus (LeaF) method, on the other hand, identifies confounder tokens (distracting but non-causal) by comparing a teacher model and a student model. These confounding tokens are pruned during distillation, so the student learns to attend only to meaningful signals.

  • AutoHFormer is a Transformer for time-series forecasting designed to balance causality, efficiency, and multi-scale pattern recognition. It splits sequences into blocks for parallel predictions, then refines them inside each block.

    Its key techniques are:

    • Dynamic windowed attention, which uses learnable sliding windows with exponential decay to keep computation sub-quadratic and strictly causal.

    • Adaptive temporal encoding, which blends sinusoidal patterns for short-term dynamics with learnable decay for long-term trends.

  • The paper “Transforming Causality: Transformer-Based Temporal Causal Discovery with Prior Knowledge Integration” proposes a framework that uses a multi-layer Transformer to learn causal graphs and time lags from time-series data. It has two main parts:

    • The causality-aware forecaster: It’s a Transformer model trained to predict future values, also picking up on which variables influence each other. If experts already know that certain connections don’t exist, they can apply attention masks so the model avoids learning those false links.

    • The causal graph extractor: After training, it looks at how sensitive predictions are to each input via gradients and uses this to build a causal graph that shows the relationships and time lags between variables. To reduce false connections, it adds a human-in-the-loop step, so users can review the graph, correct mistakes, and feed those corrections back into the model.

Image Credit: “Transforming Causality: Transformer-Based Temporal Causal Discovery with Prior Knowledge Integration” paper
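Returning to AutoHFormer's dynamic windowed attention from the list above, here is a minimal sketch of the idea. Two assumptions apply. First, the window size and decay rate are fixed numbers rather than learned parameters. Second, the decay is applied as a distance penalty on the attention logits. The sketch keeps attention strictly causal, and its cost grows with n times the window size rather than n squared.

```python
import math


def windowed_decay_attention(q, k, v, window, decay):
    """Causal attention over a sliding window with an exponential-decay bias.

    window and decay are fixed numbers here; in AutoHFormer they are learnable.
    Cost is O(n * window * d) instead of O(n^2 * d).
    """
    n, scale = len(q), math.sqrt(len(q[0]))
    out = []
    for t in range(n):
        positions = range(max(0, t - window + 1), t + 1)  # recent past only, never the future
        # Multiplying a weight by exp(-decay * distance) equals adding -decay * distance to its logit.
        logits = [sum(a * b for a, b in zip(q[t], k[s])) / scale - decay * (t - s) for s in positions]
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        z = sum(weights)
        out.append([sum(w * v[s][c] for w, s in zip(weights, positions)) / z for c in range(len(v[0]))])
    return out
```

With a window of 1, each position simply returns its own value. Larger windows with a positive decay favor the most recent steps while still mixing in a bounded amount of history.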

Conclusion

Causal attention has been the backbone of autoregressive transformers, but it doesn’t have to stay frozen in its original form. From CASTLE’s lookahead keys to future-aware masks in VLMs, we’re seeing creative ways to expand the model’s global context while keeping the structure intact. Letting models borrow a glimpse of the future can improve reasoning, comprehension, and performance across tasks.

By carefully relaxing the rules, researchers are finding that models can become more context-aware without losing the reliability that made them powerful in the first place. Attention keeps evolving together with the models, and small shifts like these may lead to the next breakthroughs in language and multimodal models.

Also, don’t forget the other side of causal attention: cause-effect reasoning, which makes models aware of what is happening and why, further broadening their global picture. Fascinating developments!

Sources and further reading

Resources from Turing Post
