
[Serving] Prefix Cache #2295

Merged
merged 12 commits into from
May 21, 2024

Conversation

cyx-6
Copy link
Contributor

@cyx-6 cyx-6 commented May 7, 2024

This PR introduces the prefix cache into the serving engine, to manage prefixes and accelerate the prefill process.

cc: @tqchen @MasterJH5574 @Ubospica

/*!
 * \brief Whether the request data is pinned in the KVCache. Used for the system prompt cache.
 */
bool pinned = false;
Contributor

move this to generation config

Contributor Author

moved, and added a test for the pinned system prompt.

tqchen
tqchen previously requested changes May 14, 2024
cpp/serve/request_state.h Show resolved Hide resolved
/*!
* \brief The matched result from prefix cache.
*/
struct MatchedResult {
Contributor

PrefixCacheMatchedResult, given this namespace is serve and MatchedResult is a bit generic

Contributor Author

renamed

cpp/serve/radix_tree.cc Show resolved Hide resolved
cpp/serve/radix_tree.cc Show resolved Hide resolved
cpp/serve/radix_tree.cc Outdated Show resolved Hide resolved
cpp/serve/radix_tree.cc Outdated Show resolved Hide resolved
@@ -115,6 +121,9 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
// Speculative models shift left the input tokens by 1 when base model has committed tokens.
// Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens.
for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {
if (!model_id) {
Contributor

model_id != 0 ?

Contributor Author

Actually, it does not matter here whether model_id == 0. The prefix cache only cares about the base model (model_id == 0), so we only update when model_id == 0. However, this will change soon when the input data logic is refactored to use an offset instead.

@@ -33,6 +33,9 @@ class Request(Object):
The JSON string of the default generation config.
When a field in the input generation_config is not defined,
we use the value in the default generation config.

pinned : bool
Contributor

let us avoid exposing pinned for now to the user side

@@ -97,6 +97,10 @@ class EngineImpl : public Engine {
return TResult::Error(engine_config_res.UnwrapErr());
}
EngineConfig engine_config = engine_config_res.Unwrap();
n->estate_->prefix_cache = PrefixCache::Init(
engine_config->max_total_sequence_length / engine_config->kv_cache_page_size,
engine_config->kv_cache_page_size * 16, engine_config->max_num_sequence,
Member

comment where 16 comes from

Contributor

@tqchen tqchen May 18, 2024

Agree, we should avoid magic numbers; use a named constant.

* tokens. If the request state entry is not added to KVCache yet, this method will add/fork the
* request in the KVCache, depending on the matching result from prefix cache.
* \param estate The engine state.
* \param[out] input The prefill input to be matched and updated.
Member

Suggested change
* \param[out] input The prefill input to be matched and updated.
* \param[in,out] input The prefill input to be matched and updated.

there are also a few other places that need updates

@@ -50,8 +64,25 @@ void ProcessFinishedRequestStateEntries(std::vector<RequestStateEntry> finished_
// So we mark the parent entry as finished.
rstate->entries[parent_idx]->status = RequestStateStatus::kFinished;
// Remove the request state entry from all the models.
RemoveRequestFromModel(estate, rstate->entries[parent_idx]->mstates[0]->internal_id, models);
estate->id_manager.RecycleId(rstate->entries[parent_idx]->mstates[0]->internal_id);
if (estate->prefix_cache->HasSequence(rstate->entries[parent_idx]->mstates[0]->internal_id)) {
Member

This is very similar to the changes in line 32; it's better to extract it into a function or lambda

Contributor

making it a named function

@kripper
Contributor

kripper commented May 16, 2024

Are these ideas already considered?
See: #2353

@tqchen tqchen dismissed their stale review May 16, 2024 21:57

comments addressed

@cyx-6
Contributor Author

cyx-6 commented May 17, 2024

@kripper Thanks for the suggestions, but I think these are not part of this PR. Although this PR does aim to improve multi-round chat, it does so in a different way from AttentionStore.

@kripper
Contributor

kripper commented May 17, 2024

@kripper Thanks for the suggestions, but I think these are not part of this PR. Although this PR does aim to improve multi-round chat, it does so in a different way from AttentionStore.

Ok. I will leave the feature request there.

* \param lazy The flag for whether the sequence should be removed lazily or immediately.
* \throw Error if the given sequence id is not valid.
*/
virtual void RecycleSequence(int64_t seq_id, PackedFunc callback, bool lazy = true) = 0;
Contributor

add a signature to the callback; in this case TypedPackedFunc is better

* \param sliding_window_size The sliding window size, -1 for disabled sliding window.
* \param attention_sink_size The attention sink position for sliding window.
*/
static PrefixCache Init(size_t num_pages, size_t page_size, size_t num_seqs,
Contributor

Init => Create. In this case, perhaps we can change it to a constructor

Contributor

naming:

  • num_seqs => max_num_seqs
  • num_pages => max_num_pages

}),
/*lazy=*/true);
}
// If the request is pinned, do nothing over the prefix cache and KVCache. Let the data be
Contributor

We should not make the data an orphan; instead, add a new kind of state (besides active etc.), e.g. SystemKeepAlive

@@ -148,6 +148,7 @@ class ChatCompletionRequest {
std::optional<std::string> user = std::nullopt;
bool ignore_eos = false;
// RequestResponseFormat response_format; //TODO: implement this
bool pinned = false;
Contributor

Let us not expose pinned for now in the JSON FFI as it is not necessary

@@ -84,7 +84,7 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
request.top_logprobs, request.logit_bias, request.seed,
request.ignore_eos, request.max_tokens, std::move(stop_strs),
conv_template_.stop_token_ids, /*response_format=*/std::nullopt,
this->default_generation_cfg_json_str_);
request.pinned, this->default_generation_cfg_json_str_);
Contributor

The JSON FFI should not expose pinned for now

@@ -50,6 +50,7 @@ class GenerationConfigNode : public Object {
std::vector<int> stop_token_ids;

ResponseFormat response_format;
bool pinned = false;
Contributor

Let us add a sub-structure, DebugConfig, which contains debug-related options that should not be exposed to the endpoint.

Debug config should include two fields:

  • ignore_eos
  • pin_system_prompt

@@ -93,6 +93,11 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes

response_format : ResponseFormat
The response format of the generation output.

pinned : bool
Contributor

add a sub-dataclass debug_config here to include:

  • pin_system_prompt
  • ignore_eos

/*!
* \brief The core data structure radix tree.
*/
PagedRadixTree radix_tree;
Contributor

Code style: class members must end with an underscore, i.e. PagedRadixTree radix_tree_; this is for better readability

NVTXScopedRange nvtx_scope("EngineAction postproc");
std::vector<RequestStateEntry> finished_rsentries;
finished_rsentries.reserve(requests.size());

Array<RequestStreamOutput> callback_delta_outputs;
callback_delta_outputs.reserve(requests.size());

for (Request request : requests) {
RequestState rstate = estate->GetRequestState(request);
for (const RequestStateEntry& rsentry : rstate->entries) {
Contributor

lift this into a sub-function

* \param estate The engine state.
* \param[out] input The prefill input to be matched and updated.
*/
void MatchPrefixCache(EngineState estate, PrefillInput& input) final {
Contributor

pass by *

class PrefixCacheObj : public Object {
public:
/*!
* \brief Insert a new tokenized sequence into Prefix Cache.
Contributor

more comments

Contributor

This function updates the PrefixCache state; users should create a MatchAndReusePrefixCache function that takes the result and performs the related updates

* \brief The parent sequence ID to fork in KVCache. The default value is -1, which means no
* forking operation is needed.
*/
int64_t parent_seq_id = -1;
Contributor

fork_from_seq_id

*/
int64_t parent_seq_id = -1;
/*!
* \brief The matched prefix offset, which should be skipped when prefilling.
Contributor

The matched prefix offset, which can be used to guide how to fork the parent seq

* \param estate The engine state.
* \param[out] input The prefill input to be matched and updated.
*/
void MatchPrefixCache(EngineState estate, PrefillInput& input) final {
Contributor

MatchAndReusePrefixCache

* \param sliding_window_size The sliding window size, -1 for disabled sliding window.
* \param attention_sink_size The attention sink position for sliding window.
*/
explicit PrefixCacheImpl(size_t num_pages, size_t page_size, size_t num_seqs,
Contributor

prefix_cache_max_num_seqs

cyx-6 added 3 commits May 18, 2024 21:43
This PR introduces the prefix cache into the serving engine, to manage prefixes and accelerate the prefill process.
@cyx-6 cyx-6 force-pushed the prefix-cache-2 branch 2 times, most recently from 80f5d03 to 74fcee5 Compare May 18, 2024 23:08
@@ -78,13 +78,19 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
}
}

bool pinned_system_prompt = false;
Contributor

we should pass debug_config into serve::DebugConfig directly, instead of passing each field by value

@@ -50,6 +55,7 @@ class GenerationConfigNode : public Object {
std::vector<int> stop_token_ids;

ResponseFormat response_format;
DebugConfig debug_config = {false};
Contributor

Make DebugConfig optional

@@ -30,6 +30,11 @@ struct ResponseFormat {
Optional<String> schema = NullOpt;
};

/*! \brief The debug configuration of a request. */
struct DebugConfig {
bool pinned_system_prompt = false;
Contributor

Make DebugConfig an Object

/*! \brief The sequence ID node pool. */
SequenceIDNode* raw_pool_;
/*! \brief The size of each node pool block. */
static constexpr size_t NODE_BLOCK_SIZE_ = 64;
Contributor

kNodeBlockSize

@tqchen tqchen merged commit 5444fd5 into mlc-ai:main May 21, 2024
1 of 2 checks passed