Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add token healing example #7028

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 25 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true;
}
if (arg == "-th" || arg == "--token-healing") {
if (++i >= argc) {
invalid_param = true;
return true;
}
sparams.token_healing_enabled = true;
auto & th_type = sparams.token_healing_type;
auto & th_n_rollback = sparams.token_healing_n_rollback;
std::string value(argv[i]);
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
sparams.token_healing_enabled = false;
}
} else { invalid_param = true; }
return true;
}
if (arg == "--override-kv") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -j SCHEMA, --json-schema SCHEMA\n");
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
printf(" -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
printf(" Token healing type. (default: 0, disabled)\n");
printf(" 1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
printf(" --cfg-negative-prompt PROMPT\n");
printf(" negative prompt to use for guidance. (default: empty)\n");
printf(" --cfg-negative-prompt-file FNAME\n");
Expand Down
145 changes: 142 additions & 3 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,109 @@
#include "sampling.h"
#include <random>

//
// Token healing (internal)
//

static bool startswith(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) != std::string::npos;
}

static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
return true;
}
}
return false;
}

static std::vector<llama_token> token_healing_find_prefix(
const llama_context * ctx_main,
const std::string & prefix,
const bool include_partial_prefix) {
// Example: prefix=" world" -> " world", " worldwide", ...
// If `include_partial_prefix`, include also: " w", " wo", ...
std::vector<llama_token> candidates;
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
if (startswith(token, prefix) ||
(include_partial_prefix && startswith(prefix, token))) {
candidates.push_back(token_id);
}
}
return candidates;
}

//
// Token healing (external)
//

std::string llama_token_healing_prepare(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove,
int * n_removed) {
if (n_removed != nullptr) {
*n_removed = 0;
}
if (tokens.empty()) {
return "";
}

const llama_model * model = llama_get_model(ctx_main);
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
const int n_ctx = tokens.size();
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
int removed = 0;
std::string prefix;
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
// and stop early if a special token is encountered
while (removed < max_to_remove) {
const llama_token next_token_id = tokens[n_ctx - removed - 1];
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
break;
}
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
break;
}
removed += 1;
prefix = new_prefix;
}
if (removed == 0) { // E.g. if the last token is a special token
return "";
}
// If constrained decoding would give back the original prompt, there is no need to modify the context
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
if (removed == 1 && candidates.size() == 1) {
LOG("token_healing: nothing to heal\n");
return "";
}
// Finalize outputs
if (n_removed != nullptr) {
*n_removed = removed;
}
tokens.resize(n_ctx - removed);
return prefix;
}

void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
ctx_sampling->token_healing_prefix = prefix;
}

//
// Sampling
//

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

Expand Down Expand Up @@ -62,6 +165,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}

ctx->token_healing_prefix.clear();

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
}
Expand Down Expand Up @@ -119,7 +224,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
}

std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
std::string result = "(Token healing) -> CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
Expand Down Expand Up @@ -297,8 +402,27 @@ static llama_token_data_array llama_sampling_prepare_impl(

cur.clear();

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
// Constrain tokens based on the remaining token healing prefix (if any)
const auto & th_type = params.token_healing_type;
const auto & th_prefix = ctx_sampling->token_healing_prefix;
if (params.token_healing_enabled && !th_prefix.empty()) {
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);

LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
for (const llama_token token_id : th_candidates) {
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
}

// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
for (const llama_token token_id: th_candidates) {
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
} else {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };
Expand Down Expand Up @@ -361,4 +485,19 @@ void llama_sampling_accept(
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}

if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
std::string & th_prefix = ctx_sampling->token_healing_prefix;
if (!th_prefix.empty()) {
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
if (new_token_piece.size() < th_prefix.size()) {
// Shift prefix constraint (for multi step token healing)
th_prefix = th_prefix.substr(new_token_piece.size());
} else {
// Prefix has been generated => no more constrained generation
th_prefix.clear();
LOG("token_healing: done\n");
}
}
}
}
28 changes: 28 additions & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
TEMPERATURE = 't'
};

enum class llama_token_healing_type : uint8_t {
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
};

// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
Expand Down Expand Up @@ -62,6 +69,10 @@ typedef struct llama_sampling_params {

std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;

llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
bool token_healing_enabled = false;
int token_healing_n_rollback = -1; // number of tokens to roll back
} llama_sampling_params;

// general sampler context
Expand All @@ -78,6 +89,8 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;

std::string token_healing_prefix;

// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
Expand Down Expand Up @@ -152,3 +165,18 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);

//
// Token healing
//

// Roll back `tokens` for constrained generation according to the token healing
// strategy. Returns the prefix for constrained generation.
std::string llama_token_healing_prepare(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove = -1,
int * n_removed = nullptr);

void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ else()
add_subdirectory(retrieval)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(simple-token-healing)
add_subdirectory(passkey)
add_subdirectory(speculative)
add_subdirectory(lookahead)
Expand Down
44 changes: 41 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,17 @@ int main(int argc, char ** argv) {
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());

if (sparams.token_healing_enabled && (params.instruct || params.chatml || !params.input_suffix.empty())) {
sparams.token_healing_enabled = false;
LOG("token_healing: disabled due to custom suffix");
}
std::string token_healing_prefix;
int token_healing_n_removed = 0;
if (!params.interactive_first && sparams.token_healing_enabled) {
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
sparams.token_healing_n_rollback, &token_healing_n_removed);
}

// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(model));
Expand All @@ -283,7 +294,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());

original_prompt_len = original_inp.size();
original_prompt_len = original_inp.size() - token_healing_n_removed;
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
LOG("guidance_offset: %s", log_tostr(guidance_offset));
Expand Down Expand Up @@ -499,6 +510,7 @@ int main(int argc, char ** argv) {
int n_consumed = 0;
int n_session_consumed = 0;
int n_past_guidance = 0;
int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix

std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
Expand All @@ -520,6 +532,7 @@ int main(int argc, char ** argv) {
}

struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
Expand Down Expand Up @@ -734,7 +747,16 @@ int main(int argc, char ** argv) {
if (input_echo && display) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());

// Suppress printing while generating token healing prefix (only for interactive mode; kinda hacky...)
if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
printf("%s", token_str.substr(n_bytes_to_skip).c_str());
n_bytes_to_skip = 0;
} else if (n_bytes_to_skip > 0) {
n_bytes_to_skip -= token_str.size();
} else {
printf("%s", token_str.c_str());
}

if (embd.size() > 1) {
input_tokens.push_back(id);
Expand Down Expand Up @@ -813,6 +835,7 @@ int main(int argc, char ** argv) {
}
}

token_healing_n_removed = 0;
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");

Expand Down Expand Up @@ -896,13 +919,24 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
}

if (sparams.token_healing_enabled) {
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
const int n_new_tokens = embd_inp.size() - original_size;
const int max_to_remove = sparams.token_healing_n_rollback < 0
? n_new_tokens
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
max_to_remove, &token_healing_n_removed);
n_bytes_to_skip = token_healing_prefix.size();
}

for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
output_ss << llama_token_to_piece(ctx, token);
}

n_remain -= line_inp.size();
n_remain -= line_inp.size() + token_healing_n_removed;
LOG("n_remain: %d\n", n_remain);
} else {
LOG("empty line, passing control back\n");
Expand All @@ -914,6 +948,10 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
llama_sampling_reset(ctx_sampling);
if (token_healing_n_removed > 0) {
// Set new prefix after an interaction
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
}
}
is_interacting = false;
}
Expand Down
11 changes: 11 additions & 0 deletions examples/simple-token-healing/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
set(TARGET simple-token-healing)
add_executable(${TARGET} simple-token-healing.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

set(TARGET simple-token-healing-1)
add_executable(${TARGET} simple-token-healing-1.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)