feat: support vision models from xinference #4094

Merged: 5 commits, May 7, 2024
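In short, the PR lets the Xinference provider pass image content through as OpenAI-style multi-part user messages when a model advertises vision support. As an illustration only (not part of the diff), a minimal sketch of the payload shape the updated _convert_prompt_message_to_dict produces for a user message mixing text and an image; the image URL and detail value are placeholders:

# Hypothetical example payload; the field names mirror the converter added in this PR,
# the concrete values are made up for illustration.
user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is shown in this picture?"},
        {
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,iVBORw0KGgo...",  # or a plain http(s) URL
                "detail": "high",  # taken from ImagePromptMessageContent.detail.value
            },
        },
    ],
}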
Changes from all commits
api/core/model_runtime/model_providers/xinference/llm/llm.py (93 changes: 63 additions, 30 deletions)
@@ -28,7 +28,10 @@
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
@@ -61,8 +64,8 @@


class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
@@ -99,7 +102,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
@@ -111,10 +114,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')

if extra_param.support_function_call:
credentials['support_function_call'] = True

if extra_param.support_vision:
credentials['support_vision'] = True

if extra_param.context_length:
credentials['context_length'] = extra_param.context_length

@@ -135,7 +141,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
"""
return self._num_tokens_from_messages(prompt_messages, tools)

def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
is_completion_model: bool = False) -> int:
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
@@ -155,7 +161,7 @@ def tokens(text: str):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item.text
text += item['text']

value = text

@@ -191,7 +197,7 @@ def tokens(text: str):
num_tokens += self._num_tokens_for_tools(tools)

return num_tokens

def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@@ -234,7 +240,7 @@ def tokens(text: str):
num_tokens += tokens(required_field)

return num_tokens

def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
"""
convert prompt message to text
@@ -260,7 +266,26 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
@@ -277,7 +302,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")

return message_dict

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
Expand Down Expand Up @@ -338,8 +363,18 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')



features = []

support_function_call = credentials.get('support_function_call', False)
if support_function_call:
features.append(ModelFeature.TOOL_CALL)

support_vision = credentials.get('support_vision', False)
if support_vision:
features.append(ModelFeature.VISION)

context_length = credentials.get('context_length', 2048)

entity = AIModelEntity(
@@ -349,33 +384,31 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={
features=features,
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
parameter_rules=rules
)

return entity
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],

def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
tools: list[PromptMessageTool] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
generate text from LLM

see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`

extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')

if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]

@@ -408,11 +441,11 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
'function': helper.dump_model(tool)
} for tool in tools
]

vision = credentials.get('support_vision', False)
if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
resp = client.chat.completions.create(
model=credentials['model_uid'],
messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
stream=stream,
user=user,
**generate_config,
@@ -497,7 +530,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m
"""
if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response")

assistant_message = resp.choices[0].message

# convert tool call to assistant message tool call
@@ -527,7 +560,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m
)

return response

def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Iterator[ChatCompletionChunk]) -> Generator:
@@ -544,7 +577,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes

if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
continue

# check if there is a tool call in the response
function_call = None
tool_calls = []
@@ -573,9 +606,9 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])

usage = self._calc_response_usage(model=model, credentials=credentials,
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
@@ -608,7 +641,7 @@ def _handle_completion_generate_response(self, model: str, credentials: dict, pr
"""
if len(resp.choices) == 0:
raise InvokeServerUnavailableError("Empty response")

assistant_message = resp.choices[0].text

# transform assistant message to prompt message
@@ -670,9 +703,9 @@ def _handle_completion_stream_response(self, model: str, credentials: dict, prom
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
)
usage = self._calc_response_usage(model=model, credentials=credentials,
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
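For orientation (an illustration, not part of the PR): once a vision-capable model is configured, the _generate path above ends up sending these multi-part messages through the OpenAI-compatible chat endpoint that Xinference exposes. A rough usage sketch, assuming the server's /v1 endpoint and a vision-capable model UID; the URL, API key, and UID below are placeholders:

from openai import OpenAI

# Assumption: Xinference serves an OpenAI-compatible API under <server_url>/v1;
# the api_key is typically unused for a local deployment.
client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="my-vision-model-uid",  # placeholder model_uid
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "low"}},
            ],
        }
    ],
    stream=False,
)
print(resp.choices[0].message.content)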
@@ -14,13 +14,15 @@ class XinferenceModelExtraParameter:
max_tokens: int = 512
context_length: int = 2048
support_function_call: bool = False
support_vision: bool = False

def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
support_function_call: bool, max_tokens: int, context_length: int) -> None:
def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
self.support_vision = support_vision
self.max_tokens = max_tokens
self.context_length = context_length

@@ -71,7 +73,7 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')

response_json = response.json()

model_format = response_json.get('model_format', 'ggmlv3')
@@ -87,17 +89,19 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
model_handle_type = 'chat'
else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

support_function_call = 'tools' in model_ability
support_vision = 'vision' in model_ability
max_tokens = response_json.get('max_tokens', 512)

context_length = response_json.get('context_length', 2048)

return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability,
support_function_call=support_function_call,
support_vision=support_vision,
max_tokens=max_tokens,
context_length=context_length
)
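A small illustration of the detection logic above, using a made-up model description rather than a live Xinference response; only the fields the helper actually reads are shown:

# Hypothetical model description JSON for a chat model with vision ability.
response_json = {
    "model_format": "pytorch",
    "model_ability": ["chat", "vision"],
    "max_tokens": 512,
    "context_length": 8192,
}

model_ability = response_json.get("model_ability", [])
support_function_call = "tools" in model_ability   # False for this sample
support_vision = "vision" in model_ability         # True, so credentials gain support_vision
                                                    # and the schema gains ModelFeature.VISION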