Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update model_provider jina to support custom url and model #4110

Merged
merged 2 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions api/core/model_runtime/model_providers/jina/jina.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ supported_model_types:
- rerank
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
Expand All @@ -29,3 +30,40 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: base_url
label:
zh_Hans: 服务器 URL
en_US: Base URL
type: text-input
required: true
placeholder:
zh_Hans: Base URL, e.g. https://api.jina.ai/v1
en_US: Base URL, e.g. https://api.jina.ai/v1
default: 'https://api.jina.ai/v1'
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
placeholder:
zh_Hans: 输入上下文大小
en_US: Enter context size
required: false
type: text-input
default: '8192'
24 changes: 23 additions & 1 deletion api/core/model_runtime/model_providers/jina/rerank/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand Down Expand Up @@ -38,9 +40,13 @@ def _invoke(self, model: str, credentials: dict,
if len(docs) == 0:
return RerankResult(model=model, docs=[])

base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
if base_url.endswith('/'):
base_url = base_url[:-1]

try:
response = httpx.post(
"https://api.jina.ai/v1/rerank",
base_url + '/rerank',
json={
"model": model,
"query": query,
Expand Down Expand Up @@ -103,3 +109,19 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
}
)

return entity
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from requests import post

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand All @@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Jina text embedding model.
"""
api_base: str = 'https://api.jina.ai/v1/embeddings'
models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de']
api_base: str = 'https://api.jina.ai/v1'

def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
Expand All @@ -39,11 +39,14 @@ def _invoke(self, model: str, credentials: dict,
:return: embeddings result
"""
api_key = credentials['api_key']
if model not in self.models:
raise InvokeBadRequestError('Invalid model name')
if not api_key:
raise CredentialsValidateFailedError('api_key is required')
url = self.api_base

base_url = credentials.get('base_url', self.api_base)
if base_url.endswith('/'):
base_url = base_url[:-1]

url = base_url + '/embeddings'
headers = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
Expand All @@ -70,7 +73,7 @@ def _invoke(self, model: str, credentials: dict,
elif response.status_code == 500:
raise InvokeServerUnavailableError(msg)
else:
raise InvokeError(msg)
raise InvokeBadRequestError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")

Expand Down Expand Up @@ -118,8 +121,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError:
raise CredentialsValidateFailedError('Invalid api key')
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
Expand All @@ -137,7 +140,8 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
KeyError,
InvokeBadRequestError
]
}

Expand Down Expand Up @@ -170,3 +174,19 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em
)

return usage

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
}
)

return entity