Skip to content

Commit 8e752a7

Browse files
authored
llama : add check for KV cache shifts (#10401)
ggml-ci
1 parent a88ad00 commit 8e752a7

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

common/common.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) {
875875
return iparams;
876876
}
877877

878+
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
879+
LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
880+
llama_free_model(model);
881+
return iparams;
882+
}
883+
878884
if (!params.control_vectors.empty()) {
879885
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
880886
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);

include/llama.h

+3
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,9 @@ extern "C" {
667667
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
668668
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
669669

670+
// Check if the context supports KV cache shifting
671+
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
672+
670673
//
671674
// State / sessions
672675
//

src/llama.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -18213,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
1821318213

1821418214
// apply K-shift if needed
1821518215
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
18216-
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
18216+
if (!llama_kv_cache_can_shift(&lctx)) {
1821718217
GGML_ABORT("Deepseek2 does not support K-shift");
1821818218
}
1821918219

@@ -20462,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
2046220462
llama_kv_cache_update_internal(*ctx);
2046320463
}
2046420464

20465+
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
20466+
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
20467+
}
20468+
2046520469
// deprecated
2046620470
size_t llama_get_state_size(struct llama_context * ctx) {
2046720471
return llama_state_get_size(ctx);

0 commit comments

Comments
 (0)