Skip to content

Commit

Permalink
适配 google gemini 优化为从用户input中提取文件 (#1419)
Browse files Browse the repository at this point in the history
适配 google gemini 优化为从用户input中提取文件
  • Loading branch information
Kilig947 authored and binary-sky committed Dec 31, 2023
1 parent a96f842 commit a7c960d
Show file tree
Hide file tree
Showing 5 changed files with 472 additions and 95 deletions.
17 changes: 13 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
"gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
"chatglm3", "moss", "claude-2"]
# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
"gemini-pro", "chatglm3", "moss", "claude-2"]
# P.S. 其他可用的模型还包括 [
# "qwen-turbo", "qwen-plus", "qwen-max",
# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
# "qwen-turbo", "qwen-plus", "qwen-max"]
# ]


# 定义界面上“询问多个GPT模型”插件应该使用哪些模型,请从AVAIL_LLM_MODELS中选择,并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4"
Expand Down Expand Up @@ -204,6 +206,10 @@
CUSTOM_API_KEY_PATTERN = ""


# Google Gemini API-Key
GEMINI_API_KEY = ''


# HUGGINGFACE的TOKEN,下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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"

Expand Down Expand Up @@ -292,6 +298,9 @@
├── "qwen-turbo" 等通义千问大模型
│ └── DASHSCOPE_API_KEY
├── "Gemini"
│ └── GEMINI_API_KEY
└── "newbing" Newbing接口不再稳定,不推荐使用
├── NEWBING_STYLE
└── NEWBING_COOKIES
Expand Down
19 changes: 19 additions & 0 deletions request_llms/bridge_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui

from .bridge_google_gemini import predict as genai_ui
from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui

colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']

class LazyloadTiktoken(object):
Expand Down Expand Up @@ -246,6 +249,22 @@ def decode(self, *args, **kwargs):
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro-vision": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
}

# -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=-
Expand Down
101 changes: 101 additions & 0 deletions request_llms/bridge_google_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# encoding: utf-8
# @Time : 2023/12/21
# @Author : Spike
# @Descr :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg

# Request settings read from the project configuration via `get_conf`.
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
# Message `predict` puts into the chat window when sending the request raises an exception.
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
                  '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'


def predict_no_ui_long_connection(inputs, llm_kwargs, history=None, sys_prompt="", observe_window=None,
                                  console_slience=False):
    """Query Gemini without touching the UI and return the complete reply.

    Args:
        inputs: The user query for this turn.
        llm_kwargs: Model settings, forwarded to ``GoogleChatInit.generate_chat``.
        history: Previous turns, forwarded to ``generate_chat``. ``None``
            (the default) means an empty history.
        sys_prompt: System prompt, forwarded to ``generate_chat``.
        observe_window: Optional list shared with the caller. Slot 0 (if
            present) receives the text accumulated so far; slot 1 (if present)
            is read as a timestamp, and the call is aborted once that
            timestamp is more than 5 seconds old.
        console_slience: Accepted for interface compatibility; not used here.

    Returns:
        The concatenation of every text chunk found in the stream.

    Raises:
        ValueError: GEMINI_API_KEY is empty, or a text chunk cannot be decoded.
        RuntimeError: The watchdog timestamp expired, or a streamed line
            contains a ``"message"`` field (treated as an API error).
    """
    # Check the API key before doing any network work.
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("请配置 GEMINI_API_KEY。")

    if history is None:
        history = []
    # The parameter defaults to None; treat that as "no observer" so the
    # len() checks below do not raise TypeError.
    if observe_window is None:
        observe_window = []

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience, in seconds
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
        if match:
            try:
                # Re-wrap the captured fragment so json can resolve its escape sequences.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except ValueError as err:  # json.JSONDecodeError is a ValueError
                raise ValueError("解析GEMINI消息出错。") from err
            gpt_replying_buffer += paraphrase['text']
            if len(observe_window) >= 1:
                observe_window[0] = gpt_replying_buffer
            if len(observe_window) >= 2:
                if (time.time() - observe_window[1]) > watch_dog_patience:
                    raise RuntimeError("程序终止。")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
    return gpt_replying_buffer


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream=True, additional_fn=None):
    """Stream a Gemini reply into the chat UI.

    This is a generator: it yields whatever ``update_ui`` /
    ``update_ui_lastest_msg`` yield, once per UI refresh.

    Args:
        inputs: The user query for this turn.
        llm_kwargs: Model settings, forwarded to ``GoogleChatInit.generate_chat``.
        plugin_kwargs: Accepted for interface compatibility; not used here.
        chatbot: Chat display; a ``(inputs, reply)`` pair is appended and then
            updated in place as text arrives.
        history: Conversation history list. On success ``inputs`` and the
            reply are appended to it; on failure the turn is removed again.
            ``None`` (the default) means a fresh empty history.
        system_prompt: System prompt, forwarded to ``generate_chat``.
        stream: Accepted for interface compatibility; not used here.
        additional_fn: Accepted for interface compatibility; not used here.

    Raises:
        TimeoutError: Sending the request failed more than MAX_RETRY times.
        ValueError: A text chunk of the stream cannot be decoded.
        RuntimeError: A streamed line contains a ``"message"`` field
            (treated as an API error).
    """
    # A fresh list per call: the history is extended below, so a shared
    # mutable default would leak turns between unrelated calls.
    if history is None:
        history = []

    # Check the API key before doing any network work.
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY:
                raise TimeoutError from e

    gpt_replying_buffer = ""
    gpt_security_policy = ""  # raw stream text, shown to the user when no reply text arrives
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
        if match:
            try:
                # Re-wrap the captured fragment so json can resolve its escape sequences.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except ValueError as err:  # json.JSONDecodeError is a ValueError
                raise ValueError("解析GEMINI消息出错。") from err
            gpt_replying_buffer += paraphrase['text']
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            # Failed turns are not kept in the conversation: drop the
            # (inputs, reply) pair appended above. The list stays a list.
            del history[-2:]
            chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('对话错误')
    if not gpt_replying_buffer:
        # No text at all came back: drop the turn and show the raw response instead.
        del history[-2:]
        chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)



if __name__ == '__main__':
    # Manual smoke test: run `predict` against gemini-pro and print every
    # item the generator yields.
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)
198 changes: 198 additions & 0 deletions request_llms/com_google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# encoding: utf-8
# @Time : 2023/12/25
# @Author : Spike
# @Descr :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image

# Proxy and timeout settings read from the project configuration; both are
# passed to `requests.post` in `GoogleChatInit.generate_chat`.
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')

"""
========================================================================
第五部分 一些文件处理方法
files_filter_handler 根据type过滤文件
input_encode_handler 提取input中的文件,并解析
file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
html_view_blank 超链接
html_local_file 本地文件取相对路径
to_markdown_tabs 文件list 转换为 md tab
"""


def files_filter_handler(file_list):
    """Return the entries of *file_list* that are existing image files.

    Every ``file=`` marker is removed from an entry before it is checked.
    An entry is kept when the resulting path exists and its extension is one
    of the known image extensions.

    Args:
        file_list: Iterable of paths (anything ``str`` accepts), optionally
            carrying a ``file=`` marker.

    Returns:
        The cleaned path strings that passed both checks, in input order.
    """
    image_extensions = {'png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps'}
    new_list = []
    for file in file_list:
        path = str(file).replace('file=', '')
        # splitext yields '' for names without a dot, so an extension-less
        # file that happens to be called "png" is not mistaken for an image.
        extension = os.path.splitext(path)[1][1:]
        if extension in image_extensions and os.path.exists(path):
            new_list.append(path)
    return new_list


def input_encode_handler(inputs):
    """Extract local image attachments from a markdown string.

    Every markdown link or image (``[text](target)`` / ``![text](target)``)
    whose target starts with ``file=`` is passed to ``files_filter_handler``.
    For each link that yields at least one image, the image is encoded with
    ``encode_image`` and the link text is removed from the input.

    Args:
        inputs: The raw user input.

    Returns:
        A tuple ``(remaining_text, images)``. ``images`` is a list of dicts
        with keys ``"data"`` (the ``encode_image`` result) and ``"type"``
        (the file extension without the dot).
    """
    encoded_images = []
    markdown_link_pattern = r"(!?\[[^\]]+\]\([^\)]+\))"
    file_target_pattern = r"\((file=.*)\)"
    for md_link in re.findall(markdown_link_pattern, inputs):
        candidates = re.findall(file_target_pattern, md_link)
        image_paths = files_filter_handler(file_list=candidates)
        if not image_paths:
            continue
        for image_path in image_paths:
            encoded_images.append({
                "data": encode_image(image_path),
                "type": os.path.splitext(image_path)[1].replace('.', '')
            })
        # The image now travels as inline data, so drop its link from the text.
        inputs = inputs.replace(md_link, '')
    return inputs, encoded_images


def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    """Turn a list of paths into display snippets.

    Args:
        file_list: Paths (or arbitrary strings) to convert.
        filter_: Extensions (without dot) to render as images. When empty or
            ``None`` a built-in list of image extensions is used.
        md_type: Passed to ``html_local_img`` as ``md``.

    Returns:
        A list with one item per input: the ``html_local_img`` result for
        entries whose extension is in the filter, the ``link_mtime_to_md``
        result for other existing paths, and the entry itself otherwise.
    """
    image_exts = filter_ or ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    rendered = []
    for entry in file_list:
        suffix = str(os.path.basename(entry)).split('.')[-1]
        if suffix in image_exts:
            item = html_local_img(entry, md=md_type)
        elif os.path.exists(entry):
            item = link_mtime_to_md(entry)
        else:
            item = entry
        rendered.append(item)
    return rendered


def link_mtime_to_md(file):
    """Build a markdown link to *file* with its modification time as query string.

    The link text is the file's base name; the target is the
    ``html_local_file`` form of the path followed by ``?<mtime>``. A changed
    file therefore gets a different URL (intended to avoid stale cached
    downloads).
    """
    target = html_local_file(file)
    label = os.path.basename(file)
    modified = os.path.getmtime(file)
    return f"[{label}]({target}?{modified})"


def html_local_file(file):
    """Convert an existing local path into a ``file=...`` reference.

    Args:
        file: A path (string or path-like object).

    Returns:
        ``'file=<path>'`` when the path exists, with every occurrence of this
        module's directory replaced by ``.``; otherwise *file* unchanged.
    """
    if not os.path.exists(str(file)):
        return file
    # Work on the string form: the existence check above already accepts
    # path-like objects, and e.g. pathlib.Path.replace() is a rename, not a
    # text substitution.
    path = str(file)
    base_path = os.path.dirname(__file__)  # directory containing this module
    # An empty base (relative __file__) would make str.replace insert '.'
    # between every character, so only substitute a non-empty prefix.
    if base_path:
        path = path.replace(base_path, ".")
    return f'file={path}'


def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    """Render an image reference for a local file.

    Args:
        __file: Path of the image; converted with ``html_local_file``.
        layout: Value of the ``align`` attribute of the wrapping ``<div>``.
        max_width: Optional CSS ``max-width`` value for the ``<img>``.
        max_height: Optional CSS ``max-height`` value for the ``<img>``.
        md: When true, return markdown image syntax instead of HTML.

    Returns:
        ``![src](src)`` when *md* is true, otherwise an HTML ``<div>``
        wrapping an ``<img>`` tag.
    """
    src = html_local_file(__file)
    if md:
        # Markdown output ignores layout and size constraints.
        return f'![{src}]({src})'
    style_parts = []
    if max_width is not None:
        style_parts.append(f"max-width: {max_width};")
    if max_height is not None:
        style_parts.append(f"max-height: {max_height};")
    style = ''.join(style_parts)
    return f'<div align="{layout}"><img src="{src}" style="{style}"></div>'


def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """Render a markdown table.

    Args:
        head: Header cells.
        tabs: Table values, one inner list per column:
            ``[[column 1], [column 2], ...]``. Shorter columns are padded
            with empty cells.
        alignment: ``:---`` left, ``:---:`` centered, ``---:`` right.
        column: When true, *tabs* is transposed (``zip(*tabs)``) before
            rendering.

    Returns:
        The markdown table as a string. Each body row is passed through
        ``file_manifest_filter_html``, so image and file paths become image
        references or links. With no data, only the header and the alignment
        row are returned.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Length of the longest column; default=0 keeps an empty table from
    # raising ValueError in max().
    max_len = max((len(col) for col in transposed_tabs), default=0)

    tab_format = "| %s "
    lines = [
        "".join(tab_format % cell for cell in head) + '|\n',
        "".join(tab_format % alignment for _ in head) + '|\n',
    ]

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        lines.append("".join(tab_format % cell for cell in row_data) + '|\n')

    return "".join(lines)


class GoogleChatInit:
    """Builds and sends streaming chat requests to the Google Gemini REST endpoint."""

    def __init__(self):
        # Endpoint template: %m is replaced by the model name, %k by the API key.
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
        # Untouched copy of the template, so the URL can be rebuilt for every
        # request instead of being resolved once and then frozen.
        self._url_template = self.url_gemini

    def _resolve_url(self, model, api_key):
        """Fill the endpoint template for *model* / *api_key*; store and return the URL."""
        if '%m' in self.url_gemini or '%k' in self.url_gemini:
            # url_gemini still is (or was reset to) a template: remember it.
            self._url_template = self.url_gemini
        template = getattr(self, '_url_template', self.url_gemini)
        self.url_gemini = template.replace('%m', model).replace('%k', api_key)
        return self.url_gemini

    def __conversation_user(self, user_input):
        """Convert one user message into a Gemini ``contents`` entry.

        When the current URL contains ``vision``, images referenced in the
        input are extracted with ``input_encode_handler`` and attached as
        ``inline_data`` parts; otherwise the text is sent unchanged.
        """
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        for data in encode_img:
            what_i_have_asked['parts'].append(
                {'inline_data': {
                    "mime_type": f"image/{data['type']}",
                    "data": data['data']
                }})
        return what_i_have_asked

    def __conversation_history(self, history):
        """Convert a flat ``[question, answer, ...]`` history into ``contents`` entries.

        Only complete (question, answer) pairs are used; a trailing unpaired
        item is ignored.
        """
        messages = []
        conversation_cnt = len(history) // 2
        for index in range(0, 2 * conversation_cnt, 2):
            what_i_have_asked = self.__conversation_user(history[index])
            what_gpt_answer = {
                "role": "model",
                "parts": [{"text": history[index + 1]}]
            }
            messages.append(what_i_have_asked)
            messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST the chat request and return an iterator over the response lines (bytes).

        Exceptions raised by ``requests.post`` are not handled here and
        propagate to the caller.
        """
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        """Build the HTTP headers and JSON payload for one request.

        Also resolves ``self.url_gemini`` for ``llm_kwargs['llm_model']`` and
        the configured GEMINI_API_KEY.

        Args:
            inputs: The user message for this turn.
            llm_kwargs: Must contain ``'llm_model'``; ``'stop'``,
                ``'temperature'`` and ``'top_p'`` are optional.
            history: Flat ``[question, answer, ...]`` list; ignored for
                vision models.
            system_prompt: Currently not sent (see the note below).

        Returns:
            A ``(headers, payload)`` tuple.
        """
        # The system prompt is not sent: per the original author's note,
        # Gemini does not allow an even number of conversation turns, so a
        # leading system/user/model triple cannot be prepended.
        messages = []
        self._resolve_url(llm_kwargs['llm_model'], get_conf('GEMINI_API_KEY'))
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # history is only sent for non-vision models
            messages.extend(self.__conversation_history(history))
        messages.append(self.__conversation_user(inputs))
        payload = {
            "contents": messages,
            "generationConfig": {
                # NOTE(review): with no 'stop' configured this is [''] - confirm the API accepts it.
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload


if __name__ == '__main__':
    google = GoogleChatInit()
    # Manual checks, left commented out:
    # print(google.generate_message_payload('你好呀', {},
    #                                       ['123123', '3123123'], ''))
    # input_encode_handler('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')

0 comments on commit a7c960d

Please sign in to comment.