File tree 3 files changed +14
-1
lines changed
3 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) {
875
875
return iparams;
876
876
}
877
877
878
+ if (params.ctx_shift && !llama_kv_cache_can_shift (lctx)) {
879
+ LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)\n", __func__);
880
+ llama_free_model (model);
881
+ return iparams;
882
+ }
883
+
878
884
if (!params.control_vectors .empty ()) {
879
885
if (params.control_vector_layer_start <= 0 ) params.control_vector_layer_start = 1 ;
880
886
if (params.control_vector_layer_end <= 0 ) params.control_vector_layer_end = llama_n_layer (model);
Original file line number Diff line number Diff line change @@ -667,6 +667,9 @@ extern "C" {
667
667
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
668
668
LLAMA_API void llama_kv_cache_update (struct llama_context * ctx);
669
669
670
+ // Check if the context supports KV cache shifting
671
+ LLAMA_API bool llama_kv_cache_can_shift (struct llama_context * ctx);
672
+
670
673
//
671
674
// State / sessions
672
675
//
Original file line number Diff line number Diff line change @@ -18213,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
18213
18213
18214
18214
// apply K-shift if needed
18215
18215
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
18216
- if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
18216
+ if (!llama_kv_cache_can_shift(& lctx)) {
18217
18217
GGML_ABORT("The current context does not support K-shift");
18218
18218
}
18219
18219
@@ -20462,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
20462
20462
llama_kv_cache_update_internal(*ctx);
20463
20463
}
20464
20464
20465
+ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
20466
+ return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
20467
+ }
20468
+
20465
20469
// deprecated
20466
20470
size_t llama_get_state_size(struct llama_context * ctx) {
20467
20471
return llama_state_get_size(ctx);
You can’t perform that action at this time.
0 commit comments