Skip to content

Commit

Permalink
适配 google gemini 优化为从用户input中提取文件 (#1419)
Browse files Browse the repository at this point in the history
适配 google gemini 优化为从用户input中提取文件
  • Loading branch information
Kilig947 authored and binary-sky committed Dec 31, 2023
1 parent a96f842 commit 5a12f6e
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 95 deletions.
17 changes: 13 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
"gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
"chatglm3", "moss", "claude-2"]
# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
"gemini-pro", "chatglm3", "moss", "claude-2"]
# P.S. 其他可用的模型还包括 [
# "qwen-turbo", "qwen-plus", "qwen-max"
# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
# "qwen-turbo", "qwen-plus", "qwen-max"]
# ]


# 定义界面上“询问多个GPT模型”插件应该使用哪些模型,请从AVAIL_LLM_MODELS中选择,并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4"
Expand Down Expand Up @@ -204,6 +206,10 @@
CUSTOM_API_KEY_PATTERN = ""


# Google Gemini API-Key
GEMINI_API_KEY = ''


# HUGGINGFACE的TOKEN,下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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"

Expand Down Expand Up @@ -292,6 +298,9 @@
├── "qwen-turbo" 等通义千问大模型
│ └── DASHSCOPE_API_KEY
├── "Gemini"
│ └── GEMINI_API_KEY
└── "newbing" Newbing接口不再稳定,不推荐使用
├── NEWBING_STYLE
└── NEWBING_COOKIES
Expand Down
19 changes: 19 additions & 0 deletions request_llms/bridge_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui

from .bridge_google_gemini import predict as genai_ui
from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui

colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']

class LazyloadTiktoken(object):
Expand Down Expand Up @@ -246,6 +249,22 @@ def decode(self, *args, **kwargs):
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro-vision": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
}

# -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=-
Expand Down
101 changes: 101 additions & 0 deletions request_llms/bridge_google_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# encoding: utf-8
# @Time : 2023/12/21
# @Author : Spike
# @Descr :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg

proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'


def predict_no_ui_long_connection(inputs, llm_kwargs, history=None, sys_prompt="", observe_window=None,
                                  console_slience=False):
    """Query Gemini without a UI, accumulating the streamed reply into one string.

    Args:
        inputs: The user's prompt text.
        llm_kwargs: Model configuration dict (must contain 'llm_model').
        history: Flat [question, answer, ...] conversation history.
        sys_prompt: System prompt forwarded to the request builder.
        observe_window: Optional shared list; [0] receives the partial reply,
            [1] (if present) holds a watchdog timestamp — exceeding the
            patience window aborts the request.
        console_slience: Unused here; kept for interface parity with other bridges.

    Returns:
        The full concatenated model reply.

    Raises:
        ValueError: If GEMINI_API_KEY is unset or a streamed chunk cannot be parsed.
        RuntimeError: On an API error message or watchdog timeout.
    """
    # Fail fast when the API key is not configured.
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("请配置 GEMINI_API_KEY。")
    if history is None:
        history = []  # avoid a shared mutable default

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience in seconds
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        # Extract the next "text" fragment; the escape-aware pattern keeps \" intact.
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
        if match:
            try:
                # Route the escaped fragment through json.loads to unescape it.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as e:
                raise ValueError("解析GEMINI消息出错。") from e
            gpt_replying_buffer += paraphrase['text']
            # Guarded: observe_window defaults to None (original crashed here).
            if observe_window is not None:
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("程序终止。")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
    return gpt_replying_buffer


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream=True, additional_fn=None):
    """Stream a Gemini reply into the chatbot UI, yielding UI refreshes.

    Args:
        inputs: The user's prompt text.
        llm_kwargs: Model configuration dict (must contain 'llm_model').
        plugin_kwargs: Plugin options (unused here; kept for interface parity).
        chatbot: UI chat list; the last entry is updated in place as text streams.
        history: Flat [question, answer, ...] conversation history (mutated:
            the new turn is appended, and removed again on failure).
        system_prompt: System prompt forwarded to the request builder.
        stream: Unused here; kept for interface parity with other bridges.
        additional_fn: Unused here; kept for interface parity.

    Raises:
        TimeoutError: After MAX_RETRY failed connection attempts.
        ValueError: If a streamed chunk cannot be parsed.
        RuntimeError: If the API returns an error message.
    """
    if history is None:
        history = []  # avoid mutating a shared default (history.extend below)
    # Fail fast when the API key is not configured.
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # refresh UI
            if retry > MAX_RETRY:
                raise TimeoutError
    gpt_replying_buffer = ""
    gpt_security_policy = ""  # collects every raw chunk, shown if the reply is empty
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # chunks arrive as bytes; decode before matching
        gpt_security_policy += results
        # Extract the next "text" fragment; the escape-aware pattern keeps \" intact.
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
        if match:
            try:
                # Route the escaped fragment through json.loads to unescape it.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as e:
                raise ValueError("解析GEMINI消息出错。") from e
            gpt_replying_buffer += paraphrase['text']
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # drop the failed turn (was history[-2]: a single element, not a slice)
            chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('对话错误')
    if not gpt_replying_buffer:
        history = history[:-2]  # drop the failed turn (was history[-2]: a single element, not a slice)
        chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)



if __name__ == '__main__':
    # Ad-hoc smoke test for the streaming bridge; run this file directly.
    import sys  # NOTE(review): unused here — possibly left over from debugging

    llm_kwargs = {'llm_model': 'gemini-pro'}
    # NOTE(review): llm_kwargs is reused in the plugin_kwargs position and []
    # stands in for the chatbot; history/system_prompt use their defaults.
    result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    # predict() is a generator — iterate it to drive the stream and print updates.
    for i in result:
        print(i)
86 changes: 86 additions & 0 deletions request_llms/com_google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# encoding: utf-8
# @Time : 2023/12/25
# @Author : Spike
# @Descr :
import json
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf

proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')

class GoogleChatInit:
    """Build and send streaming chat requests to the Google Gemini API.

    The endpoint template in ``url_gemini`` carries two placeholders —
    ``%m`` (model name) and ``%k`` (API key) — which are substituted in
    ``generate_message_payload`` before any request is sent.
    """

    def __init__(self):
        # %m / %k are resolved per-request in generate_message_payload.
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __conversation_user(self, user_input):
        """Convert one user utterance into a Gemini 'user' message dict.

        For vision endpoints the raw input is split into text plus inline
        base64 image parts; plain-text models pass the input through as-is.
        """
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            # Imported lazily: only vision requests need image extraction.
            from toolbox import input_encode_handler
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        if encode_img:
            for data in encode_img:
                what_i_have_asked['parts'].append(
                    {'inline_data': {
                        "mime_type": f"image/{data['type']}",
                        "data": data['data']
                    }})
        return what_i_have_asked

    def __conversation_history(self, history):
        """Convert flat [q, a, q, a, ...] history into alternating
        user/model Gemini messages; a trailing unpaired entry is dropped."""
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST the streaming chat request; return an iterator of response lines."""
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        """Resolve the endpoint URL and build the (headers, payload) pair.

        Side effect: substitutes %m / %k in self.url_gemini — this must run
        before the URL is used (generate_chat relies on that ordering).
        system_prompt is currently unused; see the commented block below.
        """
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # Gemini rejects an even number of turns, so a system role is not usable yet
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        self.url_gemini = self.url_gemini.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        headers = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # vision endpoints do not accept multi-turn history
            messages.extend(self.__conversation_history(history))
        messages.append(self.__conversation_user(inputs))
        # Bug fix: ''.split(' ') yields [''] — an empty stop sequence the API rejects.
        stop = str(llm_kwargs.get('stop', '')).strip()
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": stop.split(' ') if stop else [],
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return headers, payload


if __name__ == '__main__':
    # Manual scratch area for ad-hoc experiments; the calls below are left
    # commented out. NOTE(review): "gootle" looks like a typo for "google",
    # and input_encode_handle does not match any method defined on this class.
    google = GoogleChatInit()
    # print(gootle.generate_message_payload('你好呀', {},
    # ['123123', '3123123'], ''))
    # gootle.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')

0 comments on commit 5a12f6e

Please sign in to comment.