-
Notifications
You must be signed in to change notification settings - Fork 7.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
适配 google gemini 优化为从用户input中提取文件 (#1419)
适配 google gemini 优化为从用户input中提取文件
- Loading branch information
1 parent
a96f842
commit a7c960d
Showing
5 changed files
with
472 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# encoding: utf-8 | ||
# @Time : 2023/12/21 | ||
# @Author : Spike | ||
# @Descr : | ||
import json | ||
import re | ||
import time | ||
from request_llms.com_google import GoogleChatInit | ||
from toolbox import get_conf, update_ui, update_ui_lastest_msg | ||
|
||
# Proxy settings, per-request timeout and retry budget, read from the project config.
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')

# Shown in the chat window whenever the request to the API could not be sent.
timeout_bot_msg = (
    '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.'
    '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
)
|
||
|
||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=None, sys_prompt="", observe_window=None,
                                  console_slience=False):
    """Query Gemini without touching the UI and return the complete reply text.

    Args:
        inputs: The user message for this turn.
        llm_kwargs: Model options forwarded to ``GoogleChatInit.generate_chat``.
        history: Flat list of earlier turns (question, answer, ...); defaults to
            an empty conversation.
        sys_prompt: System prompt forwarded to ``generate_chat``.
        observe_window: Optional list shared with the caller. Slot 0 receives
            the text accumulated so far; slot 1, when present, is a timestamp
            that the caller is expected to keep refreshing (watchdog).
        console_slience: Unused here; kept for interface compatibility.

    Returns:
        The concatenated text of all streamed chunks.

    Raises:
        ValueError: GEMINI_API_KEY is empty, or a chunk could not be parsed.
        RuntimeError: The watchdog timestamp is stale, or the stream carried
            an error ``"message"`` field.
    """
    # Refuse to run without an API key.
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("请配置 GEMINI_API_KEY。")

    # Avoid a shared mutable default, and tolerate callers that pass no
    # observe_window (the original crashed on len(None)).
    if history is None:
        history = []
    if observe_window is None:
        observe_window = []

    # Compiled once instead of on every streamed line.
    text_pattern = re.compile(r'"text":\s*"((?:[^"\\]|\\.)*)"', flags=re.DOTALL)
    error_pattern = re.compile(r'\"message\":\s*\"(.*?)\"', flags=re.DOTALL)

    genai = GoogleChatInit()
    watch_dog_patience = 5  # seconds the watchdog tolerates without a refresh
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = text_pattern.search(results)
        error_match = error_pattern.search(results)
        if match:
            try:
                # Re-wrap the captured fragment so the JSON parser resolves escapes.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as err:
                raise ValueError("解析GEMINI消息出错。") from err
            gpt_replying_buffer += paraphrase['text']
            if len(observe_window) >= 1:
                observe_window[0] = gpt_replying_buffer
            if len(observe_window) >= 2:
                if (time.time() - observe_window[1]) > watch_dog_patience:
                    raise RuntimeError("程序终止。")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
    return gpt_replying_buffer
|
||
|
||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream=True, additional_fn=None):
    """Stream a Gemini reply into the chat UI.

    This is a generator: every ``yield`` comes from ``update_ui`` /
    ``update_ui_lastest_msg`` and refreshes the front end.

    Args:
        inputs: The user message for this turn.
        llm_kwargs: Model options forwarded to ``GoogleChatInit.generate_chat``.
        plugin_kwargs: Unused here; kept for interface compatibility.
        chatbot: List of (question, answer) tuples shown in the UI; mutated in place.
        history: Flat list of earlier turns; the new turn is appended in place
            and removed again if the request fails.
        system_prompt: System prompt forwarded to ``generate_chat``.
        stream: Unused here; kept for interface compatibility.
        additional_fn: Unused here; kept for interface compatibility.

    Raises:
        TimeoutError: The request could not be sent after MAX_RETRY retries.
        ValueError: A streamed chunk could not be parsed.
        RuntimeError: The stream carried an error ``"message"`` field.
    """
    if history is None:  # avoid a shared mutable default
        history = []

    # Refuse to run without an API key.
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)

    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as err:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY:
                raise TimeoutError from err

    # Compiled once instead of on every streamed line.
    text_pattern = re.compile(r'"text":\s*"((?:[^"\\]|\\.)*)"', flags=re.DOTALL)
    error_pattern = re.compile(r'\"message\":\s*\"(.*)\"', flags=re.DOTALL)

    gpt_replying_buffer = ""
    gpt_security_policy = ""  # raw stream, shown to the user when no text came back
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")
        gpt_security_policy += results
        match = text_pattern.search(results)
        error_match = error_pattern.search(results)
        if match:
            try:
                # Re-wrap the captured fragment so the JSON parser resolves escapes.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as err:
                raise ValueError("解析GEMINI消息出错。") from err
            gpt_replying_buffer += paraphrase['text']
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            # Failed turns are not kept in the conversation: drop the question
            # and partial answer appended above. (The original `history[-2]`
            # replaced the whole history list with a single string.)
            del history[-2:]
            chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('对话错误')

    if not gpt_replying_buffer:
        # Nothing was generated: drop the empty turn and show the raw response.
        del history[-2:]
        chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: stream one gemini-pro reply and print every UI update.
    # The same dict doubles as plugin_kwargs, and a plain list stands in for the chatbot.
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# encoding: utf-8 | ||
# @Time : 2023/12/25 | ||
# @Author : Spike | ||
# @Descr : | ||
import json | ||
import os | ||
import re | ||
import requests | ||
from typing import List, Dict, Tuple | ||
from toolbox import get_conf, encode_image | ||
|
||
# Proxy settings and request timeout read from the project configuration;
# both are passed to requests.post in GoogleChatInit.generate_chat.
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')

# The banner below (kept verbatim) indexes the file-handling helpers that follow:
# filtering files by type, extracting files referenced in the user input,
# rendering file lists as HTML/Markdown, and building Markdown tables.
# NOTE(review): it lists `html_view_blank`, which is not defined in the visible
# source, and omits `html_local_img` -- confirm and update the list.
"""
========================================================================
第五部分 一些文件处理方法
files_filter_handler 根据type过滤文件
input_encode_handler 提取input中的文件,并解析
file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
html_view_blank 超链接
html_local_file 本地文件取相对路径
to_markdown_tabs 文件list 转换为 md tab
"""
|
||
|
||
def files_filter_handler(file_list):
    """Keep only the entries that are existing image files.

    Each entry is converted to ``str`` and every ``file=`` occurrence is
    removed before the check. An entry is kept when the path exists on disk
    and the text after the last dot of its base name is a known image
    extension (case-sensitive).

    Returns:
        The cleaned paths of the matching entries, in input order.
    """
    image_exts = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    cleaned_paths = (str(entry).replace('file=', '') for entry in file_list)
    return [
        path for path in cleaned_paths
        if os.path.exists(path) and str(os.path.basename(path)).split('.')[-1] in image_exts
    ]
|
||
|
||
def input_encode_handler(inputs):
    """Pull image attachments out of Markdown links in the user input.

    Every Markdown link or image (``[text](target)`` / ``![text](target)``)
    whose target starts with ``file=`` and passes ``files_filter_handler`` is
    encoded with ``encode_image`` and removed from the text. Links that yield
    no image are left untouched.

    Returns:
        A tuple ``(remaining_text, images)`` where ``images`` is a list of
        ``{"data": <encoded image>, "type": <file extension without dot>}``.
    """
    md_link_pattern = r"(!?\[[^\]]+\]\([^\)]+\))"
    file_target_pattern = r"\((file=.*)\)"
    encoded_images = []
    for md_link in re.findall(md_link_pattern, inputs):
        targets = re.findall(file_target_pattern, md_link)
        image_paths = files_filter_handler(file_list=targets)
        if not image_paths:
            continue
        for image_path in image_paths:
            encoded_images.append({
                "data": encode_image(image_path),
                "type": os.path.splitext(image_path)[1].replace('.', ''),
            })
        # Strip the link only once its image has been attached.
        inputs = inputs.replace(md_link, '')
    return inputs, encoded_images
|
||
|
||
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    """Render each entry of ``file_list`` for display.

    Args:
        file_list: Paths or arbitrary cell values.
        filter_: Extensions treated as images; a falsy value selects the
            built-in image extension list.
        md_type: Passed to ``html_local_img`` as ``md``.

    Returns:
        A new list: entries with an image extension are rendered by
        ``html_local_img``, other existing paths by ``link_mtime_to_md``, and
        everything else is passed through unchanged.
    """
    image_exts = filter_ if filter_ else [
        'png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    rendered = []
    for entry in file_list:
        extension = str(os.path.basename(entry)).split('.')[-1]
        if extension in image_exts:
            item = html_local_img(entry, md=md_type)
        elif os.path.exists(entry):
            item = link_mtime_to_md(entry)
        else:
            item = entry
        rendered.append(item)
    return rendered
|
||
|
||
def link_mtime_to_md(file):
    """Return a Markdown link to ``file`` with its mtime as the query string.

    The link text is the file's base name and the target comes from
    ``html_local_file``; appending the modification time makes the URL change
    whenever the file does, so a stale cached copy is not served.
    """
    target = html_local_file(file)
    label = os.path.basename(file)
    return "[{}]({}?{})".format(label, target, os.path.getmtime(file))
|
||
|
||
def html_local_file(file):
    """Turn an existing local path into a ``file=...`` reference.

    If ``file`` exists on disk, the directory containing this module is
    replaced by ``.`` inside the path and the result is prefixed with
    ``file=``. Otherwise ``file`` is returned unchanged.

    Args:
        file: A path (``str`` or path-like) or any other value.

    Returns:
        The ``file=`` reference string, or the original ``file`` value.
    """
    # abspath guards against an empty dirname (relative __file__), which would
    # make str.replace('', '.') insert a dot between every character.
    base_path = os.path.dirname(os.path.abspath(__file__))  # project directory
    file_str = str(file)
    if os.path.exists(file_str):
        # Work on the string form so path-like inputs do not hit Path.replace.
        return f'file={file_str.replace(base_path, ".")}'
    return file
|
||
|
||
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    """Render an image reference as Markdown or as an HTML ``<div><img>``.

    Args:
        __file: Image path; passed through ``html_local_file`` first.
        layout: Value of the ``align`` attribute of the HTML wrapper.
        max_width: Optional CSS ``max-width`` value (HTML form only).
        max_height: Optional CSS ``max-height`` value (HTML form only).
        md: When true, return the Markdown form ``![src](src)``.

    Returns:
        The Markdown or HTML snippet as a string.
    """
    src = html_local_file(__file)
    if md:
        return f'![{src}]({src})'

    css_rules = []
    if max_width is not None:
        css_rules.append(f"max-width: {max_width};")
    if max_height is not None:
        css_rules.append(f"max-height: {max_height};")
    css = ''.join(css_rules)
    return f'<div align="{layout}"><img src="{src}" style="{css}"></div>'
|
||
|
||
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """Build a Markdown table.

    Args:
        head: Header cells.
        tabs: Table values as a list of columns: ``[[col1], [col2], ...]``.
            Columns may differ in length; missing cells are rendered empty.
        alignment: Separator cell used for every column: ``:---`` left,
            ``:---:`` centered, ``---:`` right.
        column: When true, ``tabs`` is transposed with ``zip`` first, i.e. it
            is read as a list of rows instead of a list of columns.

    Returns:
        The Markdown table as a string: header line, alignment line, then one
        line per row. Each row's cells are passed through
        ``file_manifest_filter_html`` before being written. With no data, only
        the header and alignment lines are returned.
    """
    columns = list(map(list, zip(*tabs))) if column else tabs
    # default=0 keeps an empty table from raising ValueError in max().
    max_len = max((len(col) for col in columns), default=0)

    cell_format = "| %s "
    # Cells are wrapped in a 1-tuple so tuple values are formatted, not unpacked.
    lines = [
        "".join(cell_format % (title,) for title in head) + '|\n',
        "".join(cell_format % (alignment,) for _ in head) + '|\n',
    ]

    for i in range(max_len):
        row_data = [col[i] if i < len(col) else '' for col in columns]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        lines.append("".join(cell_format % (cell,) for cell in row_data) + '|\n')

    return "".join(lines)
|
||
|
||
class GoogleChatInit:
    """Builds and sends streaming chat requests to the Gemini REST endpoint."""

    # Endpoint template: %m is replaced by the model name, %k by the API key.
    _URL_TEMPLATE = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __init__(self):
        # Holds the template until generate_message_payload fills it in.
        self.url_gemini = self._URL_TEMPLATE

    def __conversation_user(self, user_input):
        """Convert one user message into a Gemini ``user`` content entry.

        Images referenced in the message are extracted and attached as
        ``inline_data`` parts only when the current URL contains ``vision``.
        """
        if 'vision' in self.url_gemini:
            text, images = input_encode_handler(user_input)
        else:
            text, images = user_input, []
        parts = [{'text': text}]
        parts.extend(
            {'inline_data': {
                "mime_type": f"image/{image['type']}",
                "data": image['data']
            }}
            for image in images
        )
        return {"role": "user", "parts": parts}

    def __conversation_history(self, history):
        """Convert the flat history list into alternating user/model entries.

        ``history`` is read as consecutive (question, answer) pairs; a trailing
        unpaired element is ignored.
        """
        messages = []
        pair_count = len(history) // 2
        for index in range(0, 2 * pair_count, 2):
            messages.append(self.__conversation_user(history[index]))
            messages.append({
                "role": "model",
                "parts": [{"text": history[index + 1]}]
            })
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST the chat request and return an iterator over raw response lines.

        The HTTP status is not checked here; callers look for a ``"message"``
        field in the streamed lines to detect errors.
        """
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        """Build the request headers and JSON payload, and set ``self.url_gemini``.

        Args:
            inputs: The user message for this turn.
            llm_kwargs: Must contain ``llm_model``; ``stop``, ``temperature``
                and ``top_p`` are optional.
            history: Flat list of earlier turns; skipped for vision models.
            system_prompt: Currently not sent (see the note below).

        Returns:
            ``(headers, payload)``.
        """
        # A system prompt is not sent: Gemini does not accept an even number of
        # conversation turns, so there is no slot for it yet.
        messages = []
        # Always start from the template so a reused instance can switch
        # models; the original overwrote the placeholders permanently.
        self.url_gemini = self._URL_TEMPLATE.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # history is only sent to non-vision models
            messages.extend(self.__conversation_history(history))
        messages.append(self.__conversation_user(inputs))  # current user turn
        # `stop` is a space-separated string. Treat None/'' as "no stop
        # sequences" instead of sending 'None' or an empty string.
        stop = llm_kwargs.get('stop') or ''
        stop_sequences = [seq for seq in str(stop).split(' ') if seq]
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": stop_sequences,
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload
|
||
|
||
if __name__ == '__main__':
    # Manual check: constructing the client only stores the URL template.
    google = GoogleChatInit()
    # Disabled examples. The names are corrected here to ones that exist in this
    # file (the originals used `gootle` and `input_encode_handle`);
    # generate_message_payload also reads llm_kwargs['llm_model'].
    # print(google.generate_message_payload('你好呀', {'llm_model': 'gemini-pro'},
    #                                       ['123123', '3123123'], ''))
    # input_encode_handler('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')
Oops, something went wrong.