[Serving] Prefix Cache #2295
cpp/serve/request.h
Outdated
/*!
* \brief The flag if request data is pinned in KVCache. Used for system prompt cache.
*/
bool pinned = false;

move this to generation config

moved and added test for pinned system prompt.
cpp/serve/prefix_cache.h
Outdated
/*!
* \brief The matched result from prefix cache.
*/
struct MatchedResult {

PrefixCacheMatchedResult, given this namespace is serve and MatchedResult is a bit generic

renamed
@@ -115,6 +121,9 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
// Speculative models shift left the input tokens by 1 when base model has committed tokens.
// Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens.
for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {
if (!model_id) {

model_id != 0 ?

Actually, here it does not matter if model_id == 0. The prefix cache only cares about the base model (model_id == 0), so we only update if model_id == 0. However, this will change soon, once the input data logic is refactored to use an offset instead.
python/mlc_llm/serve/request.py
Outdated
@@ -33,6 +33,9 @@ class Request(Object):
The JSON string of the default generation config.
When a field in the input generation_config is not defined,
we use the value in the default generation config.

pinned : bool

let us avoid exposing pinned for now to the user side
cpp/serve/engine.cc
Outdated
@@ -97,6 +97,10 @@ class EngineImpl : public Engine {
return TResult::Error(engine_config_res.UnwrapErr());
}
EngineConfig engine_config = engine_config_res.Unwrap();
n->estate_->prefix_cache = PrefixCache::Init(
engine_config->max_total_sequence_length / engine_config->kv_cache_page_size,
engine_config->kv_cache_page_size * 16, engine_config->max_num_sequence,

comment where 16 comes from

Agree, we should avoid magic numbers; use a named constant
* tokens. If the request state entry is not added to KVCache yet, this method will add/fork the
* request in the KVCache, depending on the matching result from prefix cache.
* \param estate The engine state.
* \param[out] input The prefill input to be matched and updated.

Suggested change:
- * \param[out] input The prefill input to be matched and updated.
+ * \param[in,out] input The prefill input to be matched and updated.
there are also a few other places that need this update
@@ -50,8 +64,25 @@ void ProcessFinishedRequestStateEntries(std::vector<RequestStateEntry> finished_
// So we mark the parent entry as finished.
rstate->entries[parent_idx]->status = RequestStateStatus::kFinished;
// Remove the request state entry from all the models.
RemoveRequestFromModel(estate, rstate->entries[parent_idx]->mstates[0]->internal_id, models);
estate->id_manager.RecycleId(rstate->entries[parent_idx]->mstates[0]->internal_id);
if (estate->prefix_cache->HasSequence(rstate->entries[parent_idx]->mstates[0]->internal_id)) {

This is very similar to the changes in line 32; it would be better to extract it into a function or lambda

making it a named function
Have these ideas already been considered?
@kripper Thanks for the suggestions, but I think these are not part of this PR. Although this PR does aim to improve multi-round chat, it does so in a different way from AttentionStore.
Ok. I will leave the feature request there.
cpp/serve/prefix_cache.h
Outdated
* \param lazy The flag if the sequence should be removed lazily or intermediary.
* \throw Error if the given sequence id is not valid.
*/
virtual void RecycleSequence(int64_t seq_id, PackedFunc callback, bool lazy = true) = 0;

add a signature to callback; in this case TypedPackedFunc is better
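
To make the lazy flag and the callback contract concrete, here is a minimal Python sketch. It is not the PR's implementation: `ToyPrefixCache`, its fields, and the oldest-first eviction order are assumptions made for illustration. It only shows the idea that a lazily recycled sequence stays around until capacity is needed, and that the callback has an explicit signature (here `Callable[[], None]`, standing in for a typed function on the C++ side).

```python
from collections import OrderedDict
from typing import Callable


class ToyPrefixCache:
    """Hypothetical toy model of lazy vs. immediate sequence recycling."""

    def __init__(self, max_num_seqs: int) -> None:
        self.max_num_seqs = max_num_seqs
        self.active = set()
        # seq_id -> removal callback, oldest first (assumed eviction order).
        self.recycled = OrderedDict()

    def insert_sequence(self, seq_id: int) -> None:
        # Evict lazily recycled sequences only when capacity is needed.
        while len(self.active) + len(self.recycled) >= self.max_num_seqs and self.recycled:
            _, callback = self.recycled.popitem(last=False)
            callback()
        if len(self.active) >= self.max_num_seqs:
            raise RuntimeError("no free sequence slot")
        self.active.add(seq_id)

    def recycle_sequence(
        self, seq_id: int, callback: Callable[[], None], lazy: bool = True
    ) -> None:
        if seq_id not in self.active:
            raise ValueError(f"invalid sequence id {seq_id}")
        self.active.remove(seq_id)
        if lazy:
            # Keep the data for possible reuse; run the callback on eviction.
            self.recycled[seq_id] = callback
        else:
            callback()


removed = []
cache = ToyPrefixCache(max_num_seqs=2)
cache.insert_sequence(0)
cache.insert_sequence(1)
cache.recycle_sequence(0, lambda: removed.append(0))   # lazy: nothing removed yet
removed_after_lazy = list(removed)
cache.insert_sequence(2)                                # needs a slot, evicts seq 0
cache.recycle_sequence(1, lambda: removed.append(1), lazy=False)
```

In this sketch the callback is where the caller would release the sequence's KV cache data, which is why giving it a fixed signature makes the contract easier to check.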
cpp/serve/prefix_cache.h
Outdated
* \param sliding_window_size The sliding window size, -1 for disabled sliding window.
* \param attention_sink_size The attention sink position for sliding window.
*/
static PrefixCache Init(size_t num_pages, size_t page_size, size_t num_seqs,

Init => Create. In this case, perhaps we can change it to a constructor

naming:
- num_seqs => max_num_seqs
- num_pages => max_num_pages
}),
/*lazy=*/true);
}
// If the request is pinned, do nothing over the prefix cache and KVCache. Let the data be

We should not orphan the data. Instead, add a new kind of state (besides active etc.), e.g. SystemKeepAlive
cpp/json_ffi/openai_api_protocol.h
Outdated
@@ -148,6 +148,7 @@ class ChatCompletionRequest {
std::optional<std::string> user = std::nullopt;
bool ignore_eos = false;
// RequestResponseFormat response_format; //TODO: implement this
bool pinned = false;

Let us not expose pinned for now in the JSON FFI as it is not necessary
cpp/json_ffi/json_ffi_engine.cc
Outdated
@@ -84,7 +84,7 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
request.top_logprobs, request.logit_bias, request.seed,
request.ignore_eos, request.max_tokens, std::move(stop_strs),
conv_template_.stop_token_ids, /*response_format=*/std::nullopt,
this->default_generation_cfg_json_str_);
request.pinned, this->default_generation_cfg_json_str_);

JSON FFI should not expose pinned for now
cpp/serve/config.h
Outdated
@@ -50,6 +50,7 @@ class GenerationConfigNode : public Object {
std::vector<int> stop_token_ids;

ResponseFormat response_format;
bool pinned = false;

Let us add a sub structure, DebugConfig, which contains debug related options that should not be exposed to the end point.
Debug config should include two fields:
- ignore_eos
- pin_system_prompt
python/mlc_llm/serve/config.py
Outdated
@@ -93,6 +93,11 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes

response_format : ResponseFormat
The response format of the generation output.

pinned : bool

add a sub-dataclass debug_config here, to include
- pin_system_prompt
- ignore_eos
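
A minimal sketch of what such a sub-dataclass could look like on the Python side. The field names `pin_system_prompt` and `ignore_eos` come from the review comment; the trimmed `GenerationConfig` stand-in, the placeholder `max_tokens` default, and the choice to make `debug_config` optional are assumptions for illustration, not the code that was merged.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DebugConfig:
    """Debug-only options that should not be exposed at the endpoint."""

    ignore_eos: bool = False
    pin_system_prompt: bool = False


@dataclass
class GenerationConfig:
    """Trimmed stand-in for the real class; only two fields are shown."""

    max_tokens: Optional[int] = 128  # placeholder default for this sketch
    debug_config: Optional[DebugConfig] = None


default_cfg = GenerationConfig()
pinned_cfg = GenerationConfig(debug_config=DebugConfig(pin_system_prompt=True))
```

Grouping the flags this way keeps the user-facing fields of the generation config separate from options that only tests and internal tooling should set.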
cpp/serve/prefix_cache.cc
Outdated
/*!
* \brief The core data structure radix tree.
*/
PagedRadixTree radix_tree;

Code style: class members must end with an underscore, e.g. PagedRadixTree radix_tree_; this is for better readability
NVTXScopedRange nvtx_scope("EngineAction postproc");
std::vector<RequestStateEntry> finished_rsentries;
finished_rsentries.reserve(requests.size());

Array<RequestStreamOutput> callback_delta_outputs;
callback_delta_outputs.reserve(requests.size());

for (Request request : requests) {
RequestState rstate = estate->GetRequestState(request);
for (const RequestStateEntry& rsentry : rstate->entries) {

lift to a sub function
* \param estate The engine state.
* \param[out] input The prefill input to be matched and updated.
*/
void MatchPrefixCache(EngineState estate, PrefillInput& input) final {

pass by *
class PrefixCacheObj : public Object {
public:
/*!
* \brief Insert a new tokenized sequence into Prefix Cache.

more comments

This function updates the PrefixCache state; callers must create a MatchAndReusePrefixCache function that takes the result and performs the related update
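
The split described here (the prefix cache only reports a match, and a separate caller-side function applies it to the KV cache) can be sketched roughly as follows. `ToyKVCache`, its methods, and the `(fork_from_seq_id, offset)` tuple are hypothetical stand-ins, not the engine's real interfaces.

```python
from typing import Dict, List, Tuple


class ToyKVCache:
    """Hypothetical stand-in for the engine KV cache: seq_id -> cached token ids."""

    def __init__(self) -> None:
        self.seqs: Dict[int, List[int]] = {}

    def add_sequence(self, seq_id: int) -> None:
        self.seqs[seq_id] = []

    def fork_sequence(self, parent_id: int, child_id: int, fork_pos: int) -> None:
        # The child reuses the parent's first fork_pos tokens.
        self.seqs[child_id] = self.seqs[parent_id][:fork_pos]


def match_and_reuse_prefix_cache(
    kv: ToyKVCache, new_seq_id: int, tokens: List[int], matched: Tuple[int, int]
) -> List[int]:
    """Apply a match result to the KV cache and return the tokens left to prefill."""
    fork_from_seq_id, offset = matched
    if fork_from_seq_id == -1:
        # No reusable prefix: start a fresh sequence and prefill everything.
        kv.add_sequence(new_seq_id)
        return tokens
    kv.fork_sequence(fork_from_seq_id, new_seq_id, offset)
    return tokens[offset:]


kv = ToyKVCache()
kv.seqs[0] = [1, 2, 3, 4]
remaining = match_and_reuse_prefix_cache(kv, 1, [1, 2, 3, 7], (0, 3))
fresh = match_and_reuse_prefix_cache(kv, 2, [5, 6], (-1, 0))
```

Keeping the KV cache update outside the prefix cache means the cache's own bookkeeping can be tested without a model or a real KV cache attached.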
cpp/serve/prefix_cache.h
Outdated
* \brief The parent sequence ID to fork in KVCache. The default value if -1, which means no
* forking operation needed.
*/
int64_t parent_seq_id = -1;

fork_from_seq_id
cpp/serve/prefix_cache.h
Outdated
*/
int64_t parent_seq_id = -1;
/*!
* \brief The matched prefix offset, which should be skipped when prefilling.

The matched prefix offset, which can be used to guide how to fork the parent seq
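
As a rough illustration of how the fork source and the matched offset relate, the sketch below finds the longest cached prefix with a plain linear scan. The PR itself uses a paged radix tree, so this is a simplified substitute and not the actual algorithm; `matched_offset` and `match_prefix` are hypothetical names, while `PrefixCacheMatchedResult` and `fork_from_seq_id` follow the reviewers' naming suggestions.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class PrefixCacheMatchedResult:
    fork_from_seq_id: int = -1  # -1 means there is nothing to fork from
    matched_offset: int = 0     # number of leading tokens already cached


def match_prefix(cached: Dict[int, List[int]], tokens: List[int]) -> PrefixCacheMatchedResult:
    """Linear-scan longest-common-prefix match over cached sequences."""
    best = PrefixCacheMatchedResult()
    for seq_id, cached_tokens in cached.items():
        n = 0
        for a, b in zip(cached_tokens, tokens):
            if a != b:
                break
            n += 1
        if n > best.matched_offset:
            best = PrefixCacheMatchedResult(seq_id, n)
    return best


cached = {0: [1, 2, 3, 4], 1: [1, 2, 9]}
hit = match_prefix(cached, [1, 2, 3, 7])
miss = match_prefix(cached, [5])
```

Here `hit` points at sequence 0 with an offset of 3, so only the token after position 3 would still need prefilling, while `miss` keeps the defaults and signals that no fork is needed.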
* \param estate The engine state.
* \param[out] input The prefill input to be matched and updated.
*/
void MatchPrefixCache(EngineState estate, PrefillInput& input) final {

MatchAndReusePrefixCache
cpp/serve/prefix_cache.cc
Outdated
* \param sliding_window_size The sliding window size, -1 for disabled sliding window.
* \param attention_sink_size The attention sink position for sliding window.
*/
explicit PrefixCacheImpl(size_t num_pages, size_t page_size, size_t num_seqs,

prefix_cache_max_num_seqs
80f5d03 to 74fcee5
cpp/json_ffi/json_ffi_engine.cc
Outdated
@@ -78,13 +78,19 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
}
}

bool pinned_system_prompt = false;

we should pass debug_config into serve::DebugConfig directly, instead of passing it value by value
cpp/serve/config.h
Outdated
@@ -50,6 +55,7 @@ class GenerationConfigNode : public Object {
std::vector<int> stop_token_ids;

ResponseFormat response_format;
DebugConfig debug_config = {false};

Make DebugConfig optional
@@ -30,6 +30,11 @@ struct ResponseFormat {
Optional<String> schema = NullOpt;
};

/*! \brief The debug configuration of a request. */
struct DebugConfig {
bool pinned_system_prompt = false;

Make DebugConfig an Object
cpp/serve/radix_tree.cc
Outdated
SequenceIDNode* raw_pool_;
/*! \brief The sequence ID node pool. */
/*! \brief The size of each node pool block. */
static constexpr size_t NODE_BLOCK_SIZE_ = 64;

kNodeBlockSize
This PR introduces a prefix cache into the serving engine to manage prefixes and accelerate the prefill process.
cc: @tqchen @MasterJH5574 @Ubospica