@@ -363,7 +363,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
+        llama_batch_ext_ptr batch(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -501,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
     GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> logits;
     if (num_batches > 1) {
@@ -830,7 +830,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1112,7 +1112,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1465,7 +1465,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> tok_logits(n_vocab);
     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1730,7 +1730,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
+        llama_batch_ext_ptr batch(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
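Across all six call sites the change is the same: the explicit capacity arguments to llama_batch_ext_init (n_batch or n_ctx tokens, plus a per-tool sequence count such as 1, 2, 4 or max_seq) are dropped in favour of constructing llama_batch_ext_ptr directly from the llama_context. The following is a minimal sketch of how such a context-based constructor could plausibly behave; it assumes the capacity comes from llama_n_batch()/llama_n_seq_max() and that a llama_batch_ext_free() counterpart exists, none of which is confirmed by this diff.

```cpp
// Sketch only: one plausible reading of llama_batch_ext_ptr(ctx), inferred from
// the arguments the removed llama_batch_ext_init() calls used to pass.
// The sizing logic and the free-function name are assumptions, not PR source.
struct llama_batch_ext_ptr_sketch {
    llama_batch_ext * ptr = nullptr;

    // Hypothetical: derive token capacity and max sequence count from the
    // context, so call sites stop hard-coding values like (n_batch, 1) or (n_ctx, 4).
    explicit llama_batch_ext_ptr_sketch(llama_context * ctx)
        : ptr(llama_batch_ext_init(llama_n_batch(ctx), llama_n_seq_max(ctx))) {}

    ~llama_batch_ext_ptr_sketch() { llama_batch_ext_free(ptr); } // assumed free function

    // non-copyable, like a typical owning smart pointer
    llama_batch_ext_ptr_sketch(const llama_batch_ext_ptr_sketch &) = delete;
    llama_batch_ext_ptr_sketch & operator=(const llama_batch_ext_ptr_sketch &) = delete;

    llama_batch_ext * get() const { return ptr; }
};
```

If the constructor works roughly like this, centralizing the sizing in one place keeps each tool (perplexity, hellaswag, winogrande, multiple-choice, KL-divergence) from guessing its own worst-case batch dimensions at the call site.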