Merge pull request #3527 from zhaonx/dev: "add support for vllm api stop parameter"

hiyouga committed May 6, 2024 · 2 parents d6ca785 + 17d0005 · commit bcf7ec5

Showing 3 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/llmtuner/api/chat.py

@@ -103,6 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )

     prompt_length, response_length = 0, 0
@@ -155,6 +156,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
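With both call sites forwarding `request.stop`, an OpenAI-style client can pass stop strings through the chat completions endpoint. A minimal client-side sketch, not part of this commit; the host, port, endpoint path, and model name here are assumptions for illustration:

    # Illustrative client call; host, port, and model name are assumptions.
    import requests

    payload = {
        "model": "llama3",
        "messages": [{"role": "user", "content": "Count from 1 to 10."}],
        "max_tokens": 64,
        "stop": ["5"],  # generation halts once a stop string appears
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])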
3 changes: 2 additions & 1 deletion src/llmtuner/api/protocol.py

@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None


 class ChatCompletionResponseChoice(BaseModel):
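The annotation `Union[Optional[str], List[str]]` flattens to `Optional[Union[str, List[str]]]`, so the field accepts `None`, a single string, or a list of strings, matching the shape of the OpenAI API's `stop` parameter. A standalone sketch of the validation behavior (demo class, not the project's actual model):

    # Standalone demo of the accepted shapes for the new field.
    from typing import List, Optional, Union

    from pydantic import BaseModel

    class StopDemo(BaseModel):
        stop: Union[Optional[str], List[str]] = None

    print(StopDemo().stop)                        # None (default)
    print(StopDemo(stop="###").stop)              # "###"
    print(StopDemo(stop=["\n\n", "User:"]).stop)  # ["\n\n", "User:"]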
2 changes: 2 additions & 0 deletions src/llmtuner/chat/vllm_engine.py

@@ -96,6 +96,7 @@ async def _generate(
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)

         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -122,6 +123,7 @@
             top_k=generating_args["top_k"],
             use_beam_search=generating_args["num_beams"] > 1,
             length_penalty=generating_args["length_penalty"],
+            stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
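The popped `stop` value is passed straight into vLLM's `SamplingParams`, which accepts stop strings alongside `stop_token_ids`; whichever matches first ends generation. A standalone sketch of the same parameters used outside this engine (model path and concrete values are illustrative, not from this commit):

    # Illustrative use of SamplingParams with stop strings; the model path
    # and token id are assumptions for the example.
    from vllm import LLM, SamplingParams

    params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=128,
        stop=["\n\n"],       # stop on matching strings (the field added here)
        stop_token_ids=[2],  # and/or on explicit token ids, e.g. EOS
    )
    llm = LLM(model="meta-llama/Llama-2-7b-hf")
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)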
