llama : fix non-causal mask for gemma 3 #12615

Merged
merged 1 commit into ggml-org:master from xsn/llama_set_attn_type on Mar 29, 2025

Conversation

@ngxson ngxson commented Mar 27, 2025

Fix #12433

Related to this comment: #12181 (comment)


For example, suppose the KV cache has 10 slots, 3 tokens are already in the cache, and we are about to process 4 more tokens.

Causal mask (with llama_set_causal_attn(true)):

xxxx------
xxxxx-----
xxxxxx----
xxxxxxx---

For gemma 3, what we want is (with llama_set_causal_attn(false)):

xxxxxxx---
xxxxxxx---
xxxxxxx---
xxxxxxx---

To visualize the mask, insert the code below inside llm_graph_input_attn_kv_unified::set_input():

printf("\n\n\n\n");
printf("self_kq_mask.shape %lld, %lld\n", self_kq_mask->ne[0], self_kq_mask->ne[1]);
printf("n_tokens: %lld, n_kv: %lld\n", n_tokens, n_kv);
for (int row = 0; row < n_row; row++) {
    for (int i = 0; i < std::min(max_tok_display, (int)self_kq_mask->ne[0]); i++) {
        printf(data[row*self_kq_mask->ne[0] + i] == 0.0f ? "x" : ".");
    }
    printf("\n");
}
printf("\n\n\n\n"); // GGML_ABORT("test");

@ngxson ngxson requested a review from ggerganov March 27, 2025 17:10
@@ -403,12 +403,14 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
+        if (cparams.use_past_tokens()) {
@ggerganov ggerganov Mar 27, 2025

I think there is an alternative solution that does not require introducing a new attention type:

When causal_attn == false, all tokens in the current batch attend to each other, and also to everything that is before them in the KV cache.

So this check here has to be removed and the 2 branches can be merged into a single implementation. It would require an extra loop over the batch of tokens when causal_attn == false, to find the token with the minimum position in the batch.

It might be a bit more complicated to implement, but I think it would be the better option, since it will not involve changing the interface.
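
A rough sketch of the extra loop described above (hypothetical; as the follow-up below notes, it turned out not to be needed):

// Hypothetical fragment inside set_input(): find the smallest position in the
// current batch; a non-causal mask could then also expose every KV cell whose
// position lies before this minimum (i.e. all past tokens).
llama_pos pos_batch_min = ubatch->pos[0];
for (uint32_t i = 1; i < ubatch->n_tokens; i++) {
    pos_batch_min = std::min(pos_batch_min, ubatch->pos[i]);
}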

@ggerganov ggerganov (Member)

> It would require an extra loop over the batch of tokens when causal_attn == false, to find the token with the minimum position in the batch.

Actually, this is probably not necessary - sorry for the confusion.

The main point is that the second use-case in the OP is not relevant:

Non-causal mask:
---xxxx---
---xxxx---
---xxxx---
---xxxx---

This is never used. For example, when running embedding models, the entire sequence has to be passed as a single ubatch, and the KV cache is not used.

WDYT?

@ngxson ngxson Mar 28, 2025

Yeah, you're right @ggerganov. On second thought, the ---xxxx--- case is never actually used, since we always clear the KV cache before running non-causal embeddings so that the full KV cache is available (or we don't use the cache at all, as you said).

So yeah, I agree that the API introduced in this PR is redundant; I will extend the non-causal mask to include past tokens instead.

Btw, I also realized that my current logic may have produced this wrong mask:

xxxxxxxxxx
xxxxxxxxxx
xxxxxxxxxx
xxxxxxxxxx

@ngxson ngxson Mar 28, 2025

@ggerganov sorry for pinging again, but after taking a deeper look, I now have a doubt: should we force non-causal attention to use the KV cache?

My current idea is to make non-causal attention always use the KV cache, so that we have a unified way of producing the mask. Before, I thought that llama_set_causal_attn(ctx.lctx, false) only modified the mask, but it turns out it also prevents the batch from observing tokens already in the cache.

@ggerganov ggerganov Mar 28, 2025

We currently have 2 main paths:

  • llm_graph_input_attn_no_cache::set_input
    Used by embedding models like BERT that don't require a KV cache. In this case, we always have causal_attn == false. There is also handling of the case causal_attn == true, which technically allows you to run non-embedding models (e.g. llama) with causal attention, but none of the examples uses this, because it is very slow (i.e. you have to process the entire context for each token).
    Anyway, this path does not require any changes.

  • llm_graph_input_attn_kv_unified::set_input
    Here we have 2 cases:

    • causal_attn == true: uses the KV cache (does not need any changes)
    • causal_attn == false: does not use the KV cache (basically the same as the "no_cache" path above). This case here has to be updated to start using the KV cache. If I am not mistaken, you can simply copy-paste the first case (i.e. causal_attn == true) and remove the position checks.

Let me know if you have doubts - I'm also not 100% sure that the above is entirely correct.
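
A hedged, self-contained sketch of that idea (not the actual set_input() code; cell_pos, cell_same_seq and tok_pos are stand-ins for the KV-cache and batch bookkeeping, and SWA, padding and ALiBi are ignored):

#include <cmath>
#include <vector>

// For each (query token i, KV cell j): the cell is visible if it is occupied and
// belongs to the query's sequence; the position check applies only when causal.
static void fill_kq_mask(std::vector<float> & mask,               // row-major [n_tokens x n_kv]
                         const std::vector<int>  & cell_pos,      // position per KV cell, -1 if empty (stand-in)
                         const std::vector<bool> & cell_same_seq, // cell belongs to the query's sequence (stand-in)
                         const std::vector<int>  & tok_pos,       // position per batch token (stand-in)
                         bool causal) {
    const size_t n_kv     = cell_pos.size();
    const size_t n_tokens = tok_pos.size();
    mask.assign(n_tokens * n_kv, -INFINITY);
    for (size_t i = 0; i < n_tokens; i++) {
        for (size_t j = 0; j < n_kv; j++) {
            if (cell_pos[j] < 0 || !cell_same_seq[j]) {
                continue; // empty cell or different sequence: stays masked
            }
            if (!causal || cell_pos[j] <= tok_pos[i]) {
                mask[i*n_kv + j] = 0.0f; // 0.0f = attend ("x"), -INFINITY = masked ("-")
            }
        }
    }
}

With causal == false and a cache holding 3 past tokens plus the 4-token batch, this reproduces the xxxxxxx--- rows from the PR description; with causal == true it reproduces the staircase mask.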

@ngxson ngxson Mar 28, 2025

> llm_graph_input_attn_kv_unified::set_input
>
> causal_attn == false: does not use the KV cache (basically the same as the "no_cache" path above). This case here has to be updated to start using the KV cache. If I am not mistaken, you can simply copy-paste the first case (i.e. causal_attn == true) and remove the position checks.

Yes, this is what I discovered. I have also included the code I used to visualize the mask in this PR's description.

The problem is that in this case, the mask layout (i.e. the tensor shape) seems to be different. Using the visualization code above, it prints out this mask for a batch of 3 tokens:

self_kq_mask.shape 32, 64
n_tokens: 3
xxxxxxxxx...........
....................
....................
....................

So I expect that somewhere in the cgraph we reshape this tensor into a 3x3 matrix (though I haven't found that code yet).

Another possibility is that the current code on the master branch could be wrong. What I expected to see was:

self_kq_mask.shape 32, 64
n_tokens: 3
xxx.................
xxx.................
xxx.................
....................

Could you clarify what the shape of self_kq_mask is? Thanks!

@ggerganov ggerganov (Member)

Yes, on master this is broken. The only example that currently uses this path is Gemma vision, and nothing else should go there.

@ngxson ngxson force-pushed the xsn/llama_set_attn_type branch from 6a7eea7 to bed4c73 on March 28, 2025 21:14
@ngxson ngxson changed the title from "llama : add llama_set_attn_type API" to "llama : fix non-causal mask for gemma 3" on Mar 28, 2025
@ngxson ngxson requested a review from ggerganov March 28, 2025 21:15
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com./ggml-org/llama.cpp/pull/12615
@ngxson ngxson Mar 28, 2025

I'm wondering, would it be a good idea to group some debugging tools / functions into a header-only file called llama-debug.h? It could be included during development and removed when people want to merge code to master.

@ggerganov ggerganov (Member)

Yes, sounds good. We can have debugging code remain and guard it with environment variables similar to GGML_SCHED_DEBUG.
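
A minimal sketch of what such an environment-variable guard could look like (hypothetical: llama-debug.h and LLAMA_DEBUG_KQ_MASK do not exist; GGML_SCHED_DEBUG is only the pattern being imitated):

// llama-debug.h (hypothetical, header-only)
#pragma once

#include <cstdlib>

// Returns true if the (hypothetical) LLAMA_DEBUG_KQ_MASK environment variable is
// set to a non-zero value, mirroring how GGML_SCHED_DEBUG gates scheduler logs.
static inline bool llama_debug_kq_mask_enabled() {
    static const bool enabled = [] {
        const char * env = std::getenv("LLAMA_DEBUG_KQ_MASK");
        return env != nullptr && std::atoi(env) != 0;
    }();
    return enabled;
}

// Usage at the call site:
//     if (llama_debug_kq_mask_enabled()) { /* print the mask as in the PR description */ }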

@ngxson ngxson merged commit af6ae1e into ggml-org:master Mar 29, 2025
48 checks passed
Successfully merging this pull request may close these issues.

Eval bug: Gemma3 <unused32> spam