llama : fix non-causal mask for gemma 3 #12615
Conversation
src/llama-graph.cpp (outdated)

@@ -403,12 +403,14 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
+        if (cparams.use_past_tokens()) {
I think there is an alternative solution that does not require introducing a new attention type:

When causal_attn == false, all tokens in the current batch attend to each other and also to everything that is before them in the KV cache.

So this check here has to be removed and the 2 branches can be merged into a single implementation. It would require an extra loop over the batch of tokens when causal_attn == false, to find the token with the minimum position in the batch.

It might be a bit more complicated to implement, but I think it would be the better option, since it does not involve changing the interface.
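For illustration, a minimal sketch of the extra "minimum position" loop described above, assuming the per-token positions are exposed on the internal llama_ubatch type as pos/n_tokens (as in llama.cpp); this is not code from the PR:

#include <algorithm>
#include <climits>
#include <cstdint>

// Illustrative only: find the minimum position among the tokens of the
// current ubatch. KV cells at positions below this value are "before the batch".
static int32_t ubatch_min_pos(const llama_ubatch * ubatch) {
    int32_t min_pos = INT32_MAX;
    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
        min_pos = std::min(min_pos, (int32_t) ubatch->pos[i]);
    }
    return min_pos;
}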
> It would require an extra loop over the batch of tokens when causal_attn == false, to find the token with the minimum position in the batch.

Actually, this is probably not necessary - sorry for the confusion.

The main point is that the second use-case in the OP is not relevant:

Non-causal mask:

---xxxx---
---xxxx---
---xxxx---
---xxxx---

This is never used. For example, when running embedding models, the entire sequence has to be passed as a single ubatch, and the KV cache is not used.

WDYT?
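As a side note, a hedged sketch of how that embedding flow looks from the public API (the function name embed_sequence is illustrative, and exact signatures may differ between llama.cpp versions): the whole sequence goes in as a single batch with causal attention disabled, so there are never "past" tokens to attend to.

#include "llama.h"
#include <vector>

// Illustrative sketch: embed one sequence with non-causal attention.
// Assumes `ctx` was created with embeddings enabled.
static const float * embed_sequence(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_set_causal_attn(ctx, false); // non-causal mask for the whole batch
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    if (llama_decode(ctx, batch) != 0) {
        return nullptr; // decode failed
    }
    return llama_get_embeddings_seq(ctx, 0); // pooled embedding for sequence 0
}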
Yeah, you're right @ggerganov. On second thought, the case ---xxxx--- is never actually used, since we always clear the KV cache before running non-causal embeddings in order to use the full KV cache (or don't use the cache at all, as you said).

So yeah, I agree that the API introduced in this PR is redundant; I will extend the non-causal mask to include tokens in the past instead.

Btw, I also realized that my current logic may have produced this wrong mask:
xxxxxxxxxx
xxxxxxxxxx
xxxxxxxxxx
xxxxxxxxxx
@ggerganov sorry for pinging again, but after having a deeper look, I'm now having doubts: should we force non-causal attention to use the KV cache?

My current idea is to make non-causal attention always use the KV cache, so that we have a unified way of producing the mask. Before, I thought that llama_set_causal_attn(ctx.lctx, false) only modified the mask, but it turns out it also prevents the batch from observing tokens in the cache.
We currently have 2 main paths:

- llm_graph_input_attn_no_cache::set_input
  Used by embedding models like BERT that don't require a KV cache. In this case, we always have causal_attn == false. There is also handling of the case causal_attn == true, which technically allows you to run non-embedding models (e.g. llama) with causal attention, but none of the examples uses this because it is very slow (i.e. you have to process the entire context for each token). Anyway, this path does not require any changes.
- llm_graph_input_attn_kv_unified::set_input
  Here we have 2 cases:
  - causal_attn == true: uses the KV cache (does not need any changes)
  - causal_attn == false: does not use the KV cache (basically the same as the "no_cache" path above). This case has to be updated to start using the KV cache. If I am not mistaken, you can simply copy-paste the first case (i.e. causal_attn == true) and remove the position checks.

Let me know if you have doubts - I'm also not 100% sure that the above is entirely correct.
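For concreteness, a rough sketch of what the updated causal_attn == false branch could look like under the suggestion in the last bullet above (this is not the PR's actual code; the kv_self/cells/has_seq_id names and the flat data indexing are simplifications of llama.cpp's internals, and INFINITY needs <cmath>):

// Rough sketch: every token of the ubatch may attend to every KV cell that
// belongs to the same sequence - no position check, unlike the causal branch.
float * data = (float *) self_kq_mask->data;

for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
    const llama_seq_id seq_id = ubatch->seq_id[i][0];
    for (uint32_t j = 0; j < n_kv; ++j) {
        float f = -INFINITY;              // masked out by default
        if (kv_self->cells[j].has_seq_id(seq_id)) {
            f = 0.0f;                     // visible: same sequence, causality not enforced
        }
        data[i*n_kv + j] = f;
    }
}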
> llm_graph_input_attn_kv_unified::set_input
> causal_attn == false: does not use the KV cache (basically the same as the "no_cache" path above). This case has to be updated to start using the KV cache. If I am not mistaken, you can simply copy-paste the first case (i.e. causal_attn == true) and remove the position checks.

Yes, this is what I discovered. I also included the code I used to visualize the mask in this PR's description.

The problem is that in this case, the mask layout seems to be different (i.e. the tensor shape). Using the visualization code above, it prints out this mask for a batch of 3 tokens:

self_kq_mask.shape 32, 64
n_tokens: 3
xxxxxxxxx...........
....................
....................
....................

So I expect that somewhere in the cgraph we reshape this tensor into a 3x3 matrix (though I still haven't found this code yet).

Another possibility is that the current code on the master branch could be wrong. What I expected to see was:

self_kq_mask.shape 32, 64
n_tokens: 3
xxx.................
xxx.................
xxx.................
....................

Could you clarify what the shape of self_kq_mask is? Thanks!
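On the shape question: a hedged sketch of how such a padded mask tensor is typically allocated (the padding value below is a placeholder; llama.cpp uses a constant named GGML_KQ_MASK_PAD internally), which would explain why the printed second dimension is larger than n_tokens:

#include "ggml.h"

// Illustrative only: one column per KV cell, and the row count is padded up
// to a backend-friendly multiple rather than being exactly n_tokens.
static ggml_tensor * alloc_kq_mask_sketch(ggml_context * ctx0, int64_t n_kv, int64_t n_tokens) {
    const int64_t kq_mask_pad = 32; // placeholder, not asserted to match llama.cpp
    return ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
            n_kv,                             // columns: KV cells
            GGML_PAD(n_tokens, kq_mask_pad)); // rows: tokens, padded
}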
Yes, on master this is broken. The only example that currently uses this path is Gemma vision; nothing else should go there.
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com./ggml-org/llama.cpp/pull/12615
I'm thinking, would it be a good idea to group some debugging tools/functions into a header called llama-debug.h (header-only)? It could be included during development and removed when people want to merge code to master.
Yes, sounds good. We can keep the debugging code and guard it with environment variables, similar to GGML_SCHED_DEBUG.
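A minimal sketch of that kind of guard (the variable name LLAMA_GRAPH_DEBUG is purely illustrative, chosen by analogy with GGML_SCHED_DEBUG):

#include <cstdlib>

// Illustrative only: read the environment variable once and cache the result.
static bool llama_graph_debug_enabled() {
    static const bool enabled = std::getenv("LLAMA_GRAPH_DEBUG") != nullptr;
    return enabled;
}

// usage inside set_input():
//     if (llama_graph_debug_enabled()) {
//         // dump the KQ mask here
//     }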
Fix #12433
Related to this comment: #12181 (comment)
For example, suppose we have a cache of 10 tokens, 3 tokens already inside the cache, and we are about to process 4 more tokens.

Causal mask (with llama_set_causal_attn(true)):

xxxx------
xxxxx-----
xxxxxx----
xxxxxxx---

For gemma 3, what we want is (with llama_set_causal_attn(false)):

xxxxxxx---
xxxxxxx---
xxxxxxx---
xxxxxxx---

To visualize the mask, insert the code below inside llm_graph_input_attn_kv_unified::set_input():
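The visualization snippet itself is not reproduced in this extract; below is a hedged sketch of what such code could look like, assuming it runs at the end of llm_graph_input_attn_kv_unified::set_input() after the mask has been written and that the tensor data is host-visible ('x' = visible, '.' = masked out):

// Illustrative sketch, not the snippet from the PR description.
// Requires <cstdio> and <cmath>.
{
    const int64_t n_cols = self_kq_mask->ne[0]; // KV cells
    const int64_t n_rows = self_kq_mask->ne[1]; // (padded) tokens
    const float * data   = (const float *) self_kq_mask->data;

    printf("self_kq_mask.shape %lld, %lld\n", (long long) n_cols, (long long) n_rows);
    printf("n_tokens: %d\n", (int) ubatch->n_tokens);

    for (int64_t r = 0; r < n_rows; ++r) {
        for (int64_t c = 0; c < n_cols; ++c) {
            putchar(data[r*n_cols + c] == -INFINITY ? '.' : 'x');
        }
        putchar('\n');
    }
}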