Skip to content

Commit b989365

Browse files
committed
beam_search_callback_state struct can be heavy and referenced in callback.
1 parent 40c9403 commit b989365

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

examples/beam_search/beam_search.cpp

+10-11
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
4343
// Put here anything you want back in beam_search_callback().
4444
struct beam_search_callback_state {
4545
llama_context* ctx;
46-
std::vector<llama_token>* response;
46+
std::vector<llama_token> response;
4747
};
4848

4949
bool is_at_eos(beam_search_callback_state, llama_token const* tokens, size_t const n_tokens) {
@@ -56,28 +56,28 @@ bool is_at_eos(beam_search_callback_state, llama_token const* tokens, size_t con
5656
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
5757
// This is also called when the stop condition is met.
5858
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
59-
void beam_search_callback(void* callback_state, llama_beams_state beams_state) {
60-
auto const state = *static_cast<beam_search_callback_state*>(callback_state);
59+
void beam_search_callback(void* callback_state_ptr, llama_beams_state beams_state) {
60+
auto& callback_state = *static_cast<beam_search_callback_state*>(callback_state_ptr);
6161
// Mark beams as EOS as needed.
6262
for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
6363
llama_beam_view& beam_view = beams_state.beam_views[i];
64-
if (!beam_view.eos && is_at_eos(state, beam_view.tokens, beam_view.n_tokens)) {
64+
if (!beam_view.eos && is_at_eos(callback_state, beam_view.tokens, beam_view.n_tokens)) {
6565
beam_view.eos = true;
6666
}
6767
}
6868
printf(","); // Show progress
6969
if (size_t const n = beams_state.common_prefix_length) {
70-
state.response->resize(state.response->size() + n);
70+
callback_state.response.resize(callback_state.response.size() + n);
7171
assert(0u < beams_state.n_beams);
7272
llama_token const* tokens = beams_state.beam_views[0].tokens;
73-
std::copy(tokens, tokens + n, state.response->end() - n);
73+
std::copy(tokens, tokens + n, callback_state.response.end() - n);
7474
printf("%lu", n);
7575
}
7676
fflush(stdout);
7777
#if 1 // DEBUG: print current beams for this iteration
78-
std::cout << "\n\nCurrent beams:\n";
78+
std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
7979
for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
80-
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
80+
std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_state.ctx,beams_state.beam_views[i]} << std::endl;
8181
}
8282
#endif
8383
}
@@ -168,14 +168,13 @@ int main(int argc, char ** argv)
168168
}
169169
n_past += tokens_list.size();
170170

171-
std::vector<llama_token> response;
172-
beam_search_callback_state callback_state{ctx, &response};
171+
beam_search_callback_state callback_state{ctx, {}};
173172
size_t const beam_width = static_cast<size_t>(params.n_beams);
174173
int const n_predict = 256;
175174
llama_beam_search(ctx, beam_search_callback, &callback_state, beam_width, n_past, n_predict, params.n_threads);
176175

177176
printf("\n\n");
178-
for (llama_token const token_id : response) {
177+
for (llama_token const token_id : callback_state.response) {
179178
printf("%s", llama_token_to_str(ctx,token_id));
180179
}
181180
printf("\n");

0 commit comments

Comments
 (0)