Skip to content

Commit

Permalink
update model_provider jina to support custom url and model
Browse files Browse the repository at this point in the history
  • Loading branch information
Gimling committed May 7, 2024
1 parent 6f19115 commit 703967f
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 1 deletion.
75 changes: 75 additions & 0 deletions api/core/model_runtime/model_providers/azure_openai/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,81 @@ class AzureBaseModel(BaseModel):
)
)
),
AzureBaseModel(
base_model_name='gpt-4-turbo-2024-04-09',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
),
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
),
],
pricing=PriceConfig(
input=0.01,
output=0.03,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='gpt-4-1106-preview',
entity=AIModelEntity(
Expand Down
30 changes: 30 additions & 0 deletions api/core/model_runtime/model_providers/jina/jina.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ supported_model_types:
- rerank
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
Expand All @@ -29,3 +30,32 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: base_url
label:
zh_Hans: 服务器URL
en_US: Base URL
type: text-input
required: true
placeholder:
zh_Hans: Base URL, e.g. https://api.jina.ai/v1
en_US: Base URL, e.g. https://api.jina.ai/v1
default: 'https://api.jina.ai/v1'
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
placeholder:
zh_Hans: 输入上下文大小
en_US: Enter context size
required: false
type: text-input
default: '8192'
25 changes: 24 additions & 1 deletion api/core/model_runtime/model_providers/jina/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional
from urllib.parse import urljoin

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand Down Expand Up @@ -38,9 +41,13 @@ def _invoke(self, model: str, credentials: dict,
if len(docs) == 0:
return RerankResult(model=model, docs=[])

base_url = credentials['base_url']
if not base_url.endswith('/'):
base_url += '/'

try:
response = httpx.post(
"https://api.jina.ai/v1/rerank",
urljoin(base_url, 'rerank'),
json={
"model": model,
"query": query,
Expand Down Expand Up @@ -103,3 +110,19 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
}
)

return entity

0 comments on commit 703967f

Please sign in to comment.