-
Notifications
You must be signed in to change notification settings - Fork 7.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
适配 Google Gemini:优化为从用户输入(input)中提取文件 (#1419)
适配 Google Gemini:优化为从用户输入(input)中提取文件
- Loading branch information
1 parent
a96f842
commit 5a12f6e
Showing
5 changed files
with
360 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# encoding: utf-8 | ||
# @Time : 2023/12/21 | ||
# @Author : Spike | ||
# @Descr : | ||
import json | ||
import re | ||
import time | ||
from request_llms.com_google import GoogleChatInit | ||
from toolbox import get_conf, update_ui, update_ui_lastest_msg | ||
|
||
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY') | ||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \ | ||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。' | ||
|
||
|
||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=None, sys_prompt="", observe_window=None,
                                  console_slience=False):
    """Send one request to the Gemini streaming API and return the full reply.

    Args:
        inputs: User prompt text.
        llm_kwargs: Model options dict (must contain 'llm_model').
        history: Alternating user/model messages from previous turns; defaults to empty.
        sys_prompt: System prompt text (forwarded to GoogleChatInit).
        observe_window: Optional list for cross-thread observation:
            observe_window[0] receives the partial reply as it streams,
            observe_window[1] holds the caller's last-seen timestamp (watchdog).
        console_slience: Unused; kept (typo and all) for interface compatibility.

    Returns:
        The concatenated model reply text.

    Raises:
        ValueError: If GEMINI_API_KEY is unset or a stream chunk cannot be parsed.
        RuntimeError: On watchdog timeout or an API-reported error.
    """
    if history is None:  # avoid the mutable-default-argument pitfall
        history = []
    # Fail fast if the API key is not configured.
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("请配置 GEMINI_API_KEY。")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience, seconds
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*?)"', results, flags=re.DOTALL)
        if match:
            try:
                # Re-wrap the escaped fragment in a JSON object so json.loads
                # performs the unescaping for us.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as e:  # was a bare except
                raise ValueError("解析GEMINI消息出错。") from e
            gpt_replying_buffer += paraphrase['text']
            # BUG FIX: observe_window defaults to None; guard before len().
            if observe_window is not None:
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("程序终止。")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
    return gpt_replying_buffer
|
||
|
||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream=True, additional_fn=None):
    """Stream a Gemini reply into the chatbot UI (generator of UI refreshes).

    Args:
        inputs: User prompt text.
        llm_kwargs: Model options dict (must contain 'llm_model').
        plugin_kwargs: Plugin options (unused here; kept for interface parity).
        chatbot: UI chat list; the last entry is updated in place while streaming.
        history: Alternating user/model messages; this turn is appended to it.
        system_prompt: System prompt text (forwarded to GoogleChatInit).
        stream: Unused; kept for interface parity.
        additional_fn: Unused; kept for interface parity.

    Raises:
        TimeoutError: After MAX_RETRY failed connection attempts.
        ValueError: If a stream chunk cannot be parsed.
        RuntimeError: When the API reports an error message.
    """
    if history is None:  # avoid the mutable-default-argument pitfall
        history = []
    # Fail fast if the API key is not configured.
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # 刷新界面
            if retry > MAX_RETRY:
                raise TimeoutError
    gpt_replying_buffer = ""
    gpt_security_policy = ""  # accumulates raw chunks for the safety-policy report
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # 被这个解码给耍了。。
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*)"', results, flags=re.DOTALL)
        if match:
            try:
                # Re-wrap the escaped fragment so json.loads performs unescaping.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as e:  # was a bare except
                raise ValueError("解析GEMINI消息出错。") from e
            gpt_replying_buffer += paraphrase['text']
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            # BUG FIX: was `history = history[-2]`, which selects a single
            # element; slice off the failed turn instead. 错误的不纳入对话
            history = history[:-2]
            chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('对话错误')
    if not gpt_replying_buffer:
        # Empty reply: blocked by Google's safety policy.
        # BUG FIX: same `history[-2]` -> `history[:-2]` slice fix as above.
        history = history[:-2]
        chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: stream a story generation and print every UI-refresh
    # yield. Requires GEMINI_API_KEY to be configured.
    # (Removed an unused `import sys`.)
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# encoding: utf-8 | ||
# @Time : 2023/12/25 | ||
# @Author : Spike | ||
# @Descr : | ||
import json | ||
import requests | ||
from typing import List, Dict, Tuple | ||
from toolbox import get_conf | ||
|
||
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS') | ||
|
||
class GoogleChatInit:
    """Build request payloads for, and open streaming connections to, the
    Google Gemini ``streamGenerateContent`` REST endpoint.

    ``url_gemini`` starts as a template with ``%m`` (model) and ``%k``
    (API key) placeholders; ``generate_message_payload`` fills them in.
    """

    # Pristine URL template; %m/%k are substituted per request.
    _URL_TEMPLATE = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __init__(self):
        self.url_gemini = self._URL_TEMPLATE

    def __conversation_user(self, user_input):
        """Wrap one user turn in Gemini's message format; vision models
        additionally get inline images extracted from the raw input."""
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            # Only vision models need image extraction, so import lazily here
            # (previously imported unconditionally on every call).
            from toolbox import input_encode_handler
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        for data in encode_img:
            what_i_have_asked['parts'].append(
                {'inline_data': {
                    "mime_type": f"image/{data['type']}",
                    "data": data['data']
                }})
        return what_i_have_asked

    def __conversation_history(self, history):
        """Convert alternating [user, model, user, model, ...] history into
        Gemini message dicts; a trailing unpaired entry is dropped."""
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST the chat request and return an iterator over raw response lines."""
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        """Build ``(headers, payload)`` for the request and set ``url_gemini``.

        BUG FIX: the URL is now rebuilt from the pristine template on every
        call. Previously ``self.url_gemini`` was substituted in place, so a
        second call on the same instance found no ``%m``/``%k`` placeholders
        and silently reused the first call's model and key.
        """
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # Gemini rejects an odd turn count, so a system role is unusable for now.
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        self.url_gemini = self._URL_TEMPLATE.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # only non-vision models carry history
            messages.extend(self.__conversation_history(history))
        messages.append(self.__conversation_user(inputs))  # current user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: just construct an instance; payload inspection is
    # left commented out below.
    google = GoogleChatInit()
    # NOTE(review): the commented calls reference `gootle` (typo for `google`)
    # and `input_encode_handle`, which is not a method of this class — fix
    # before un-commenting.
    # print(gootle.generate_message_payload('你好呀', {},
    #                                       ['123123', '3123123'], ''))
    # gootle.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')
Oops, something went wrong.