@@ -43,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
43
43
// Put here anything you want back in beam_search_callback().
44
44
struct beam_search_callback_state {
45
45
llama_context* ctx;
46
- std::vector<llama_token>* response;
46
+ std::vector<llama_token> response;
47
47
};
48
48
49
49
bool is_at_eos (beam_search_callback_state, llama_token const * tokens, size_t const n_tokens) {
@@ -56,28 +56,28 @@ bool is_at_eos(beam_search_callback_state, llama_token const* tokens, size_t con
56
56
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
57
57
// This is also called when the stop condition is met.
58
58
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
59
- void beam_search_callback (void * callback_state , llama_beams_state beams_state) {
60
- auto const state = *static_cast <beam_search_callback_state*>(callback_state );
59
+ void beam_search_callback (void * callback_state_ptr , llama_beams_state beams_state) {
60
+ auto & callback_state = *static_cast <beam_search_callback_state*>(callback_state_ptr );
61
61
// Mark beams as EOS as needed.
62
62
for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
63
63
llama_beam_view& beam_view = beams_state.beam_views [i];
64
- if (!beam_view.eos && is_at_eos (state , beam_view.tokens , beam_view.n_tokens )) {
64
+ if (!beam_view.eos && is_at_eos (callback_state , beam_view.tokens , beam_view.n_tokens )) {
65
65
beam_view.eos = true ;
66
66
}
67
67
}
68
68
printf (" ," ); // Show progress
69
69
if (size_t const n = beams_state.common_prefix_length ) {
70
- state .response -> resize (state .response -> size () + n);
70
+ callback_state .response . resize (callback_state .response . size () + n);
71
71
assert (0u < beams_state.n_beams );
72
72
llama_token const * tokens = beams_state.beam_views [0 ].tokens ;
73
- std::copy (tokens, tokens + n, state .response -> end () - n);
73
+ std::copy (tokens, tokens + n, callback_state .response . end () - n);
74
74
printf (" %lu" , n);
75
75
}
76
76
fflush (stdout);
77
77
#if 1 // DEBUG: print current beams for this iteration
78
- std::cout << " \n\n Current beams:\n " ;
78
+ std::cout << " \n\n Current beams (last_call= " << beams_state. last_call << " ) :\n " ;
79
79
for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
80
- std::cout << " beams[" <<i<<" ]: " << ostream_beam_view{state .ctx ,beams_state.beam_views [i]} << std::endl;
80
+ std::cout << " beams[" <<i<<" ]: " << ostream_beam_view{callback_state .ctx ,beams_state.beam_views [i]} << std::endl;
81
81
}
82
82
#endif
83
83
}
@@ -168,14 +168,13 @@ int main(int argc, char ** argv)
168
168
}
169
169
n_past += tokens_list.size ();
170
170
171
- std::vector<llama_token> response;
172
- beam_search_callback_state callback_state{ctx, &response};
171
+ beam_search_callback_state callback_state{ctx, {}};
173
172
size_t const beam_width = static_cast <size_t >(params.n_beams );
174
173
int const n_predict = 256 ;
175
174
llama_beam_search (ctx, beam_search_callback, &callback_state, beam_width, n_past, n_predict, params.n_threads );
176
175
177
176
printf (" \n\n " );
178
- for (llama_token const token_id : response) {
177
+ for (llama_token const token_id : callback_state. response ) {
179
178
printf (" %s" , llama_token_to_str (ctx,token_id));
180
179
}
181
180
printf (" \n " );
0 commit comments