
Commit d18a79e

llama_batch_ext_init with ctx
1 parent 1434c2c commit d18a79e
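In short, every batch constructor now takes the llama_context as its first argument and derives the batch capacity from it, instead of receiving explicit token/sequence counts at each call site. A minimal before/after sketch, assuming this branch's llama_batch_ext API and the llama_batch_ext_ptr wrapper (header names and parameter names are inferred from the diff below, not confirmed):

    #include "llama.h"      // assumed to declare the llama_batch_ext API on this branch
    #include "llama-cpp.h"  // assumed to declare the llama_batch_ext_ptr RAII wrapper

    // Sketch only: illustrates the call-shape change, not a complete program.
    static void sketch(llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
        // before: llama_batch_ext * batch = llama_batch_ext_init(/*n_tokens_max=*/512, /*n_seq_max=*/1);
        // after:  the capacity comes from the context
        llama_batch_ext * batch = llama_batch_ext_init(ctx);

        // before: llama_batch_ext_ptr::init_from_text(tokens, n_tokens, 0, 0, true);
        // after:  the context is threaded through as the first argument
        auto prompt = llama_batch_ext_ptr::init_from_text(ctx, tokens, n_tokens,
                                                           /*pos0=*/0, /*seq_id=*/0, /*output_last=*/true);
        llama_decode_ext(ctx, prompt.get());

        llama_batch_ext_free(batch);  // assumed counterpart of llama_batch_ext_init
    }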

41 files changed: +124 -113 lines changed


common/common.cpp (+2, -2)

@@ -1016,7 +1016,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     }
 
     if (llama_model_has_encoder(model)) {
-        auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), tmp.size(), 0, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true);
         llama_encode_ext(lctx, batch.get());
         llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
         if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1026,7 +1026,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         tmp.push_back(decoder_start_token_id);
     }
     if (llama_model_has_decoder(model)) {
-        auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
         llama_decode_ext(lctx, batch.get());
     }
     llama_kv_self_clear(lctx);
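For encoder-decoder models, the warm-up above now builds both batches from lctx. A hedged, self-contained sketch of that flow under the same API (uses the headers from the sketch near the top of this page; the BOS fallback is an assumption, not shown in the hunk):

    // Hedged sketch of the warm-up flow above (not a verbatim copy of common.cpp).
    static void warmup_sketch(llama_model * model, llama_context * lctx, std::vector<llama_token> tmp) {
        if (llama_model_has_encoder(model)) {
            auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true);
            llama_encode_ext(lctx, batch.get());

            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
            if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
                // assumption: fall back to BOS when the model defines no decoder start token
                decoder_start_token_id = llama_vocab_bos(llama_model_get_vocab(model));
            }
            tmp.clear();
            tmp.push_back(decoder_start_token_id);
        }
        if (llama_model_has_decoder(model)) {
            auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true);
            llama_decode_ext(lctx, batch.get());
        }
        llama_kv_self_clear(lctx);  // drop the warm-up tokens from the KV cache
    }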

common/speculative.cpp (+1, -1)

@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx = */ ctx_dft,
         /* .smpl = */ nullptr,
-        /* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
+        /* .batch = */ llama_batch_ext_ptr(ctx_dft),
         /* .prompt = */ {},
     };

examples/batched-bench/batched-bench.cpp (+1, -1)

@@ -59,7 +59,7 @@ int main(int argc, char ** argv) {
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
-    llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {

examples/batched/batched.cpp (+1, -1)

@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
    // we use this object to submit token data for decoding
-    llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     std::vector<llama_seq_id> seq_ids(n_parallel, 0);
     for (int32_t i = 0; i < n_parallel; ++i) {

examples/cvector-generator/cvector-generator.cpp (+1, -1)

@@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_self_clear(ctx);
-    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
     if (llama_decode_ext(ctx, batch.get())) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

examples/embedding/embedding.cpp (+1, -1)

@@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     // count number of embeddings
     int n_embd_count = 0;

examples/eval-callback/eval-callback.cpp (+1, -1)

@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
 
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
 
-    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;

examples/gritlm/gritlm.cpp (+2, -2)

@@ -14,7 +14,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
+    llama_batch_ext_ptr batch(ctx);
 
     for (uint64_t i = 0; i < sentences.size(); i++) {
         batch.clear();
@@ -105,7 +105,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
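The gritlm hunks also show the intended reuse pattern for the wrapper: construct one llama_batch_ext_ptr from the context, then clear and refill it per sentence. A hedged sketch of that loop, combining the clear(), add_text(), and set_output_last() calls visible elsewhere in this diff (argument names are assumed from usage, not confirmed):

    // Hedged sketch: one batch object, reused across inputs. The parameter order of
    // add_text(token, pos, seq_id, output) is inferred from the qwen2vl-cli hunk below.
    static void encode_all(llama_context * ctx, const std::vector<std::vector<llama_token>> & sents) {
        llama_batch_ext_ptr batch(ctx);  // sized from the context, as introduced by this commit

        for (size_t i = 0; i < sents.size(); i++) {
            batch.clear();  // reuse the same allocation for every sentence
            for (size_t j = 0; j < sents[i].size(); j++) {
                batch.add_text(sents[i][j], (llama_pos) j, /*seq_id=*/0, /*output=*/false);
            }
            llama_batch_ext_set_output_last(batch.get());  // request output only for the last token
            llama_decode_ext(ctx, batch.get());
        }
    }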

examples/imatrix/imatrix.cpp (+1, -1)

@@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+        llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;

examples/infill/infill.cpp (+1, -1)

@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
 
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
+                auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true);
                 if (llama_decode_ext(ctx, batch.get())) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;

examples/llama-bench/llama-bench.cpp (+2, -2)

@@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
        for (int i = 1; i < n_tokens; i++) {
            tokens[i] = std::rand() % n_vocab;
        }
-        auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), n_tokens, n_past + n_processed, 0, true);
        llama_decode_ext(ctx, batch.get());
        n_processed += n_tokens;
    }
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
    llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
    for (int i = 0; i < n_gen; i++) {
-        auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(ctx, &token, 1, n_past + i, 0, true);
        llama_decode_ext(ctx, batch.get());
        llama_synchronize(ctx);
        token = std::rand() % n_vocab;

examples/llama.android/llama/src/main/cpp/llama-android.cpp (+3, -2)

@@ -273,8 +273,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
 extern "C"
 JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
+Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jlong context_pointer) {
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    llama_batch_ext * batch = llama_batch_ext_init(context);
 
     return reinterpret_cast<jlong>(batch);
 }

examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt (+2, -2)

@@ -45,7 +45,7 @@ class LLamaAndroid {
     private external fun free_context(context: Long)
     private external fun backend_init(numa: Boolean)
     private external fun backend_free()
-    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+    private external fun new_batch(context: Long): Long
     private external fun free_batch(batch: Long)
     private external fun new_sampler(): Long
     private external fun free_sampler(sampler: Long)
@@ -102,7 +102,7 @@
            val context = new_context(model)
            if (context == 0L) throw IllegalStateException("new_context() failed")
 
-            val batch = new_batch(512, 0, 1)
+            val batch = new_batch(context)
            if (batch == 0L) throw IllegalStateException("new_batch() failed")
 
            val sampler = new_sampler()

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (+1, -1)

@@ -26,7 +26,7 @@ actor LlamaContext {
        self.model = model
        self.context = context
        self.tokens_list = []
-        self.batch = llama_batch_ext_init(512, 1)
+        self.batch = llama_batch_ext_init(context)
        self.temporary_invalid_cchars = []
        let sparams = llama_sampler_chain_default_params()
        self.sampling = llama_sampler_chain_init(sparams)

examples/llava/gemma3-cli.cpp (+3, -2)

@@ -74,7 +74,7 @@ struct gemma3_context {
        lctx = llama_init.context.get();
        vocab = llama_model_get_vocab(model);
        n_threads = params.cpuparams.n_threads;
-        batch.reset(llama_batch_ext_init(params.n_batch, 1));
+        batch.reset(llama_batch_ext_init(lctx));
        init_clip_model(params);
    }
 
@@ -147,7 +147,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
    int64_t t1 = ggml_time_ms();
    eval_text(ctx, "<start_of_image>");
    llama_set_causal_attn(ctx.lctx, false);
-    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
+    llama_batch_ext_ptr batch_img = llama_batch_ext_ptr::init_from_embd(
+        ctx.lctx, image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0);
    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
        LOG_ERR("failed to decode image\n");
        return 1;

examples/llava/llava-cli.cpp (+1, -1)

@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
-        auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true);
        if (llama_decode_ext(ctx_llama, batch.get())) {
            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
            return false;

examples/llava/llava.cpp (+1, -1)

@@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
            n_eval = n_batch;
        }
        float * embd = image_embed->embed+i*n_embd;
-        auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0);
+        auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, embd, n_eval, n_embd, 0, 0);
        if (llama_decode_ext(ctx_llama, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;

examples/llava/minicpmv-cli.cpp (+1, -1)

@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
-        auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true);
        if (llama_decode_ext(ctx_llama, batch.get())) {
            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
            return false;

examples/llava/qwen2vl-cli.cpp (+4, -5)

@@ -67,8 +67,8 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
        memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
 
        float * batch_embd = image_embed->embed+i*n_embd;
-        auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0);
-        llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval * 4);
+        const llama_pos * pos = batch_mrope_pos.data();
+        auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, 0);
 
        if (llama_decode_ext(ctx_llama, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
@@ -97,12 +97,11 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
            pos[j] = *st_pos_id + (j % n_eval);
        }
 
-        llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
+        llama_batch_ext_ptr batch(ctx_llama);
        for (int j = 0; j < n_eval; j++) {
            llama_token token = tokens[i + j];
-            batch.add_text(token, 0, 0, false); // position is set in the next step
+            batch.add_text(token, *st_pos_id + i + j, 0, false);
        }
-        llama_batch_ext_set_pos(batch.get(), pos.data(), pos.size());
        llama_batch_ext_set_output_last(batch.get());
 
        if (llama_decode_ext(ctx_llama, batch.get())) {
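The qwen2vl change is slightly larger: besides taking the context, init_from_embd now accepts the position array directly, replacing the separate llama_batch_ext_set_pos call (the batch_mrope_pos buffer in the hunk above holds n_eval * 4 positions). A hedged sketch of the new call shape, wrapped in a hypothetical helper:

    // Hypothetical helper illustrating init_from_embd(ctx, embd, n_tokens, n_embd, pos, seq_id).
    static bool eval_image_chunk(llama_context * ctx_llama, float * batch_embd,
                                 int n_eval, int n_embd,
                                 const std::vector<llama_pos> & batch_mrope_pos) {
        const llama_pos * pos = batch_mrope_pos.data();  // n_eval * 4 entries, as in the hunk above
        // before: init_from_embd(batch_embd, n_eval, n_embd, 0, 0) + llama_batch_ext_set_pos(...)
        // after:  the context and the positions are passed in one call
        auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, /*seq_id=*/0);
        return llama_decode_ext(ctx_llama, batch.get()) == 0;  // decode returns 0 on success
    }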

examples/lookahead/lookahead.cpp (+3, -3)

@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
    const auto t_enc_start = ggml_time_us();
 
    // eval the prompt
-    auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
-    auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
+    auto batch0 = llama_batch_ext_ptr::init_from_text(ctx, inp.data(), n_input - 1, 0, 0, true);
+    auto batch1 = llama_batch_ext_ptr::init_from_text(ctx, &inp.back(), 1, n_input - 1, 0, true);
    llama_decode_ext(ctx, batch0.get());
    llama_decode_ext(ctx, batch1.get());
 
@@ -117,7 +117,7 @@
    // seq_id == 0 : the current input token
    // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
    // seq_id [W + 1, W + G] : verification n-grams
-    llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
    // target model sampling context
    struct common_sampler * smpl = common_sampler_init(model, params.sampling);

examples/lookup/lookup.cpp (+3, -3)

@@ -92,8 +92,8 @@ int main(int argc, char ** argv){
 
    const auto t_enc_start = ggml_time_us();
 
-    auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
-    auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
+    auto batch0 = llama_batch_ext_ptr::init_from_text(ctx, inp.data(), n_input - 1, 0, 0, true);
+    auto batch1 = llama_batch_ext_ptr::init_from_text(ctx, &inp.back(), 1, n_input - 1, 0, true);
    llama_decode_ext(ctx, batch0.get());
    llama_decode_ext(ctx, batch1.get());
 
@@ -111,7 +111,7 @@
 
    std::vector<llama_token> draft;
 
-    llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(params.n_ctx, 1));
+    llama_batch_ext_ptr batch_tgt(ctx);
 
    // debug
    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);

examples/main/main.cpp (+2, -2)

@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
        int enc_input_size = embd_inp.size();
        llama_token * enc_input_buf = embd_inp.data();
 
-        auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(ctx, enc_input_buf, enc_input_size, 0, 0, true);
        if (llama_decode_ext(ctx, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
@@ -669,7 +669,7 @@
 
                LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
+                auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true);
                if (llama_decode_ext(ctx, batch.get())) {
                    LOG_ERR("%s : failed to eval\n", __func__);
                    return 1;

examples/parallel/parallel.cpp (+1, -1)

@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
 
    // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
    // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 1));
+    llama_batch_ext_ptr batch(ctx);
 
    int32_t n_total_prompt = 0;
    int32_t n_total_gen    = 0;

examples/passkey/passkey.cpp (+1, -1)

@@ -123,7 +123,7 @@ int main(int argc, char ** argv) {
    LOG_INF("prompt tokens: %d\n", n_tokens_all);
    //LOG_INF("prompt: %s\n", params.prompt.c_str());
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
+    llama_batch_ext_ptr batch(ctx);
 
    int n_past = 0;

examples/perplexity/perplexity.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
363363
// clear the KV cache
364364
llama_kv_self_clear(ctx);
365365

366-
llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
366+
llama_batch_ext_ptr batch(ctx);
367367

368368
for (int j = 0; j < num_batches; ++j) {
369369
const int batch_start = start + j * n_batch;
@@ -501,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
501501
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
502502
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
503503

504-
llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));
504+
llama_batch_ext_ptr batch(ctx);
505505

506506
std::vector<float> logits;
507507
if (num_batches > 1) {
@@ -830,7 +830,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
830830
const int max_tasks_per_batch = 32;
831831
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
832832

833-
llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));
833+
llama_batch_ext_ptr batch(ctx);
834834

835835
std::vector<float> tok_logits(n_vocab);
836836
// TODO: this could be made smaller; it's currently the worst-case size
@@ -1112,7 +1112,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
11121112
const int max_tasks_per_batch = 128;
11131113
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
11141114

1115-
llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));
1115+
llama_batch_ext_ptr batch(ctx);
11161116

11171117
std::vector<float> tok_logits(n_vocab);
11181118
// TODO: this could be made smaller; it's currently the worst-case size
@@ -1465,7 +1465,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
14651465
const int max_tasks_per_batch = 32;
14661466
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
14671467

1468-
llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));
1468+
llama_batch_ext_ptr batch(ctx);
14691469

14701470
std::vector<float> tok_logits(n_vocab);
14711471
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1730,7 +1730,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
17301730
// clear the KV cache
17311731
llama_kv_self_clear(ctx);
17321732

1733-
llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
1733+
llama_batch_ext_ptr batch(ctx);
17341734

17351735
for (int j = 0; j < num_batches; ++j) {
17361736
const int batch_start = start + j * n_batch;

examples/retrieval/retrieval.cpp (+2, -2)

@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {
 
    // initialize batch
    const int n_chunks = chunks.size();
-    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
    // allocate output
    const int n_embd = llama_model_n_embd(model);
@@ -253,7 +253,7 @@
        chunks[i].tokens.clear();
    }
 
-    llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
+    llama_batch_ext * query_batch = llama_batch_ext_init(ctx);
 
    // start loop, receive query and return top k similar chunks based on cosine similarity
    std::string query;
