Commit 61a66f2

llama : improve infill support
ggml-ci
1 parent 3dc48fe

File tree: 10 files changed (+563, -424 lines)

common/arg.cpp (+111, -137)

Large diffs are not rendered by default.

common/common.cpp (+17, -1)

@@ -12,6 +12,7 @@
 
 #include <algorithm>
 #include <cinttypes>
+#include <climits>
 #include <cmath>
 #include <codecvt>
 #include <cstdarg>
@@ -23,10 +24,10 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <thread>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include <thread>
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include <sys/types.h>
@@ -400,6 +401,21 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
 // String utils
 //
 
+std::string string_format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
 std::vector<std::string> string_split(std::string input, char separator) {
     std::vector<std::string> parts;
     size_t separator_pos = input.find(separator);
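
For reference, a minimal usage sketch of the new helper (hypothetical call site, not part of the diff; the format arguments are made up):

    // printf-style formatting into a std::string
    std::string info = string_format("n_ctx = %d, model = %s", 4096, "llama");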

common/common.h (+16, -3)

@@ -352,15 +352,28 @@ void gpt_init();
 
 std::string gpt_params_get_system_info(const gpt_params & params);
 
-bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
-bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
-void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
+bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
+bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
+void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
 bool set_process_priority(enum ggml_sched_priority prio);
 
 //
 // String utils
 //
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#endif
+
+LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
+std::string string_format(const char * fmt, ...);
+
 std::vector<std::string> string_split(std::string input, char separator);
 
 std::string string_strip(const std::string & str);
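
A rough illustration of what the new format attribute buys (not part of the commit): with LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) applied to string_format, GCC and Clang type-check the variadic arguments against the format string, so a mismatched call should produce a -Wformat warning:

    // %d does not match a string argument -> expected to warn at compile time
    std::string bad = string_format("n_tokens = %d", "not-an-int");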

examples/infill/infill.cpp (+7, -7)

@@ -205,11 +205,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
     std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
 
-    GGML_ASSERT(llama_token_prefix(model) >= 0);
-    GGML_ASSERT(llama_token_suffix(model) >= 0);
+    GGML_ASSERT(llama_token_fim_pre(model) >= 0);
+    GGML_ASSERT(llama_token_fim_suf(model) >= 0);
 
-    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
-    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
+    inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
+    inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
 
     embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
     embd_end = params.spm_infill ? inp_pfx : inp_sfx;
@@ -218,7 +218,7 @@
     }
     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
 
-    const llama_token middle_token = llama_token_middle(model);
+    const llama_token middle_token = llama_token_fim_mid(model);
     if (middle_token >= 0) {
         embd_inp.push_back(middle_token);
     }
@@ -508,8 +508,8 @@
            std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
            std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
 
-           inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
-           inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
+           inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
+           inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
 
            embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
            embd_end = params.spm_infill ? inp_pfx : inp_sfx;
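
To make the token layout concrete, here is a minimal sketch of the fill-in-the-middle prompt assembled above (illustrative; `prefix` and `suffix` stand for the user-provided std::string inputs, non-SPM ordering shown):

    // layout: [FIM_PRE] prefix [FIM_SUF] suffix [FIM_MID] -> the model generates the middle part
    std::vector<llama_token> fim;
    fim.push_back(llama_token_fim_pre(model));
    for (llama_token t : ::llama_tokenize(ctx, prefix, false)) fim.push_back(t);
    fim.push_back(llama_token_fim_suf(model));
    for (llama_token t : ::llama_tokenize(ctx, suffix, false)) fim.push_back(t);
    fim.push_back(llama_token_fim_mid(model));

With --spm-infill the prefix and suffix blocks are swapped, as in the diff above.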

examples/server/README.md (+1, -1)

@@ -526,7 +526,7 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
 - `input_prefix`: Set the prefix of the code to infill.
 - `input_suffix`: Set the suffix of the code to infill.
 
-It also accepts all the options of `/completion` except `stream` and `prompt`.
+It also accepts all the options of `/completion`.
 
 ### **GET** `/props`: Get server global properties.
 

examples/server/server.cpp (+84, -75)

@@ -753,12 +753,7 @@ struct server_context {
         metrics.init();
     }
 
-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
-        // TODO: currently, we tokenize using special tokens by default
-        //       this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
-        //       but it's better compared to completely ignoring ChatML and other chat templates
-        const bool TMP_FORCE_SPECIAL = true;
-
+    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
         // If `add_bos` is true, we only add BOS, when json_prompt is a string,
         // or the first element of the json_prompt array is a string.
         std::vector<llama_token> prompt_tokens;
@@ -771,10 +766,10 @@
 
                 std::vector<llama_token> p;
                 if (first) {
-                    p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+                    p = ::llama_tokenize(ctx, s, add_special, parse_special);
                     first = false;
                 } else {
-                    p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
+                    p = ::llama_tokenize(ctx, s, false, parse_special);
                 }
 
                 prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -788,7 +783,7 @@
             }
         } else {
             auto s = json_prompt.template get<std::string>();
-            prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+            prompt_tokens = ::llama_tokenize(ctx, s, add_special, parse_special);
         }
 
         return prompt_tokens;
@@ -1215,7 +1210,7 @@
                     slot.params.n_predict, n_ctx_train);
         }
 
-        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
+        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
 
        return slot.has_next_token; // continue
    }
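
As an aside on the new parse_special flag (an illustrative sketch, not from the diff): the final argument of common's ::llama_tokenize decides whether special-token text in the input is mapped to its special token id or tokenized as plain characters; infill prefix/suffix and rerank inputs are now tokenized with it disabled, so user text that happens to contain such markup stays literal:

    // assuming the loaded vocab defines a token whose text is "<|fim_prefix|>"
    std::vector<llama_token> as_special = ::llama_tokenize(ctx, "<|fim_prefix|>", false, true);  // single special-token id
    std::vector<llama_token> as_text    = ::llama_tokenize(ctx, "<|fim_prefix|>", false, false); // ordinary text tokens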
@@ -1483,9 +1478,8 @@
         if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
             data["index"] = 0;
             create_task(data, false, nullptr);
-        }
-        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
-        else if (prompt.is_array()) {
+        } else if (prompt.is_array()) {
+            // otherwise, it's a multiple-prompt task, we break it into smaller tasks
             std::vector<json> prompts = prompt;
             if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                 // prompts[0] is the question
@@ -1510,9 +1504,8 @@
                 }
             }
         }
-    }
-    // invalid case
-    else {
+    } else {
+        // invalid case
         throw std::runtime_error(error_msg);
     }
 
@@ -1971,70 +1964,69 @@
                    slot.t_start_process_prompt = ggml_time_us();
                    slot.t_start_generation = 0;
 
-                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
-                       const bool add_bos = llama_add_bos_token(model);
-                       bool suff_rm_leading_spc = true;
-                       if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
-                           params.input_suffix.erase(0, 1);
-                           suff_rm_leading_spc = false;
-                       }
-
-                       auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                       auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                       const int space_token = 29871; // TODO: this should not be hardcoded
-                       if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                           suffix_tokens.erase(suffix_tokens.begin());
-                       }
-
-                       prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                       suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
-
-                       auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                       auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
-                       if (add_bos) {
-                           embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                       }
-                       embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
-                       const llama_token middle_token = llama_token_middle(model);
-                       if (middle_token >= 0) {
-                           embd_inp.push_back(middle_token);
-                       }
-
-                       prompt_tokens = embd_inp;
-                   } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                       // require slot.prompt to be array of 2 strings
-                       if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                           SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                           slot.release();
-                           send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                           continue;
-                       }
-
-                       // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                       prompt_tokens.clear();
-                       prompt_tokens.push_back(llama_token_bos(model));
-                       {
-                           const auto part = tokenize(slot.prompt[0], false);
-                           prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                       }
-                       prompt_tokens.push_back(llama_token_eos(model));
-                       prompt_tokens.push_back(llama_token_sep(model));
-                       {
-                           const auto part = tokenize(slot.prompt[1], false);
-                           prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                       }
-                       prompt_tokens.push_back(llama_token_eos(model));
-                   } else {
-                       prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+                   switch (slot.cmpl_type) {
+                       case SERVER_TASK_CMPL_TYPE_NORMAL:
+                       case SERVER_TASK_CMPL_TYPE_EMBEDDING:
+                           {
+                               prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                           } break;
+                       case SERVER_TASK_CMPL_TYPE_RERANK:
+                           {
+                               // require slot.prompt to be array of 2 strings
+                               if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                                   SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                                   slot.release();
+                                   send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                                   continue;
+                               }
+
+                               // prompt: [BOS]query[EOS][SEP]doc[EOS]
+                               prompt_tokens.clear();
+                               prompt_tokens.push_back(llama_token_bos(model));
+                               {
+                                   const auto part = tokenize(slot.prompt[0], false, false);
+                                   prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                               }
+                               prompt_tokens.push_back(llama_token_eos(model));
+                               prompt_tokens.push_back(llama_token_sep(model));
+                               {
+                                   const auto part = tokenize(slot.prompt[1], false, false);
+                                   prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                               }
+                               prompt_tokens.push_back(llama_token_eos(model));
+                           } break;
+                       case SERVER_TASK_CMPL_TYPE_INFILL:
+                           {
+                               auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
+                               auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+
+                               prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
+                               suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+
+                               auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
+                               auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+
+                               if (llama_add_bos_token(model)) {
+                                   embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                               }
+
+                               embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+                               embd_inp.push_back(llama_token_fim_mid(model));
+
+                               prompt_tokens = std::move(embd_inp);
+                           } break;
                    }
 
                    slot.n_past = 0;
                    slot.n_prompt_tokens = prompt_tokens.size();
 
                    SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
+                   // print prompt tokens:
+                   for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                       SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                   }
+
                    // empty prompt passed -> release the slot and send empty response
                    if (prompt_tokens.empty()) {
                        SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
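
For reference, the rerank layout noted in the comment above ([BOS]query[EOS][SEP]doc[EOS]) can be reproduced with the public token accessors; a minimal sketch, where `query` and `doc` are placeholder std::string values:

    std::vector<llama_token> rr;
    rr.push_back(llama_token_bos(model));
    for (llama_token t : ::llama_tokenize(ctx, query, false)) rr.push_back(t);
    rr.push_back(llama_token_eos(model));
    rr.push_back(llama_token_sep(model));
    for (llama_token t : ::llama_tokenize(ctx, doc, false)) rr.push_back(t);
    rr.push_back(llama_token_eos(model));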
@@ -2924,7 +2916,23 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
     };
 
-    const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        std::string err;
+        if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "prefix token is missing. ";
+        }
+        if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "suffix token is missing. ";
+        }
+        if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "middle token is missing. ";
+        }
+
+        if (!err.empty()) {
+            res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         json data = json::parse(req.body);
         return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
@@ -3010,7 +3018,8 @@
         if (body.count("content") != 0) {
             const bool add_special = json_value(body, "add_special", false);
             const bool with_pieces = json_value(body, "with_pieces", false);
-            std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
+
+            std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
 
             if (with_pieces) {
                 for (const auto& token : tokens) {
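
A client-side sketch of the same capability check the updated /infill handler performs (illustrative; `model_supports_infill` is a hypothetical helper, not part of the commit):

    // true only if the vocab defines the prefix, suffix and middle FIM tokens
    static bool model_supports_infill(const struct llama_model * model) {
        return llama_token_fim_pre(model) != LLAMA_TOKEN_NULL &&
               llama_token_fim_suf(model) != LLAMA_TOKEN_NULL &&
               llama_token_fim_mid(model) != LLAMA_TOKEN_NULL;
    }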

include/llama.h (+12, -5)

@@ -896,6 +896,7 @@ extern "C" {
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
     LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
     LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@@ -904,11 +905,17 @@
     LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
     LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
 
-    // Codellama infill tokens
-    LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
-    LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
-    LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
-    LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
+    // infill tokens
+    DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
+
+    LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
 
     //
     // Tokenization
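
Migration sketch for callers of the deprecated accessors (illustrative; the deprecated functions should now emit a deprecation warning on compilers that support it):

    // old: llama_token_prefix / llama_token_middle / llama_token_suffix
    // new:
    const llama_token t_pre = llama_token_fim_pre(model);
    const llama_token t_mid = llama_token_fim_mid(model);
    const llama_token t_suf = llama_token_fim_suf(model);
    // llama_token_fim_pad / _rep / _sep expose additional FIM-related special tokens,
    // any of which may be LLAMA_TOKEN_NULL if the vocab does not define them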
