-
Notifications
You must be signed in to change notification settings - Fork 40
/
llmcord.py
245 lines (204 loc) · 11.7 KB
/
llmcord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import asyncio
import base64
from datetime import datetime as dt
import logging
from os import environ as env
import requests
import discord
from dotenv import load_dotenv
from litellm import acompletion
# Module configuration: everything below is read once, at import time, from
# environment variables (populated from a .env file by load_dotenv()).
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s",
)

# Model capabilities are inferred purely from substrings/prefixes of the LLM name.
LLM_IS_LOCAL: bool = env["LLM"].startswith("local/")
LLM_SUPPORTS_IMAGES: bool = any(x in env["LLM"] for x in ("claude-3", "gpt-4-turbo", "gpt-4o", "llava", "vision"))
LLM_SUPPORTS_NAMES: bool = any(env["LLM"].startswith(x) for x in ("gpt", "openai/gpt"))

ALLOWED_FILE_TYPES = ("image", "text")
ALLOWED_CHANNEL_TYPES = (discord.ChannelType.text, discord.ChannelType.public_thread, discord.ChannelType.private_thread, discord.ChannelType.private)

# Comma-separated ID lists; an empty result means "no restriction" in on_message.
# Blank or whitespace-only entries (e.g. from a trailing ", ") are skipped.
ALLOWED_CHANNEL_IDS = tuple(int(channel_id) for channel_id in env["ALLOWED_CHANNEL_IDS"].split(",") if channel_id.strip())
ALLOWED_ROLE_IDS = tuple(int(role_id) for role_id in env["ALLOWED_ROLE_IDS"].split(",") if role_id.strip())

MAX_TEXT = int(env["MAX_TEXT"])
MAX_IMAGES = int(env["MAX_IMAGES"]) if LLM_SUPPORTS_IMAGES else 0
MAX_MESSAGES = int(env["MAX_MESSAGES"])

EMBED_COLOR = {"incomplete": discord.Color.orange(), "complete": discord.Color.green()}
# Longest text put into one embed description before on_message starts a new reply
# (presumably Discord's embed description limit - confirm).
EMBED_MAX_LENGTH = 4096
# Minimum number of seconds between two edits of a streaming response message.
EDIT_DELAY_SECONDS = 1.3
# Upper bound on cached MsgNodes; on_message evicts the lowest message IDs beyond it.
MAX_MESSAGE_NODES = 100


def convert(string):
    """Coerce a settings value to int or float when it looks numeric.

    A single leading "+" or "-" is accepted so that negative values such as
    "-0.5" become numbers instead of staying strings. Anything that is not
    plain digits with at most one "." is returned unchanged.
    """
    unsigned = string[1:] if string[:1] in ("+", "-") else string
    if unsigned.isdecimal():
        return int(string)
    if unsigned.replace(".", "", 1).isdecimal():
        return float(string)
    return string


# LLM_SETTINGS is "key=value,key=value,...". Entries whose key contains "#" are
# treated as commented out. Splitting on the first "=" only keeps values that
# themselves contain "=" intact instead of failing to unpack.
llm_settings = {
    k.strip(): convert(v.strip())
    for k, v in (x.split("=", 1) for x in env["LLM_SETTINGS"].split(",") if x.strip())
    if "#" not in k
}

if LLM_IS_LOCAL:
    llm_settings["base_url"] = env["LOCAL_SERVER_URL"]
    # Supply a placeholder key unless the user configured one explicitly.
    llm_settings.setdefault("api_key", "Not used")
    # Drop the "local/" marker so only the real model name is passed on.
    env["LLM"] = env["LLM"].replace("local/", "", 1)

if env["DISCORD_CLIENT_ID"]:
    print(f"\nBOT INVITE URL:\nhttps://discord.com/api/oauth2/authorize?client_id={env['DISCORD_CLIENT_ID']}&permissions=412317273088&scope=bot\n")

intents = discord.Intents.default()
intents.message_content = True
activity = discord.CustomActivity(name=env["DISCORD_STATUS_MESSAGE"][:128] or "github.com/jakobdylanc/discord-llm-chatbot")
bot = discord.Client(intents=intents, activity=activity)

# Shared state used by on_message, keyed by Discord message ID.
msg_nodes = {}  # message ID -> MsgNode
msg_locks = {}  # message ID -> asyncio.Lock guarding that node
last_task_time = None  # timestamp of the most recent response send/edit
class MsgNode:
    """One cached message in a conversation chain.

    Holds the LLM-ready payload for a message, a link to the message it
    follows, and flags recording what was truncated, dropped, or could not
    be fetched while the node was built.
    """

    def __init__(
        self,
        data,
        replied_to_msg=None,
        too_much_text=False,
        too_many_images=False,
        has_bad_attachments=False,
        fetch_next_failed=False,
    ):
        # Payload and link to the preceding message (None when there is none).
        self.data, self.replied_to_msg = data, replied_to_msg

        # Flags that are later turned into user-facing warnings.
        self.too_much_text, self.too_many_images = too_much_text, too_many_images
        self.has_bad_attachments, self.fetch_next_failed = has_bad_attachments, fetch_next_failed
def get_system_prompt():
    """Build the system message list that is prepended to every completion request.

    The content is the configured LLM_SYSTEM_PROMPT followed by today's date
    and, when the model supports per-message names, a hint on how users are
    identified. Returns a list containing a single "system" message dict.
    """
    prompt_lines = [
        env["LLM_SYSTEM_PROMPT"],
        f"Today's date: {dt.now().strftime('%B %d %Y')}",
    ]
    if LLM_SUPPORTS_NAMES:
        prompt_lines.append("User's names are their Discord IDs and should be typed as '<@ID>'.")

    system_message = {"role": "system", "content": "\n".join(prompt_lines)}
    return [system_message]
@bot.event
async def on_message(new_msg):
    """Answer a Discord message with a streamed LLM completion.

    Messages that fail the channel-type, mention, channel-ID, role or
    bot-author filters are ignored. Otherwise the handler walks back through
    the conversation (at most MAX_MESSAGES messages), caching one MsgNode per
    message in msg_nodes, and streams the completion into one or more embed
    replies, starting a new reply whenever EMBED_MAX_LENGTH would be exceeded.

    Args:
        new_msg: The message object delivered by the Discord client event.
    """
    global msg_nodes, msg_locks, last_task_time

    # Filter out unwanted messages
    if (
        new_msg.channel.type not in ALLOWED_CHANNEL_TYPES
        or (new_msg.channel.type != discord.ChannelType.private and bot.user not in new_msg.mentions)
        or (
            ALLOWED_CHANNEL_IDS
            and not any(
                channel_id in ALLOWED_CHANNEL_IDS
                for channel_id in (new_msg.channel.id, getattr(new_msg.channel, "parent_id", None))
            )
        )
        or (
            ALLOWED_ROLE_IDS
            and (
                new_msg.channel.type == discord.ChannelType.private
                or not any(role.id in ALLOWED_ROLE_IDS for role in new_msg.author.roles)
            )
        )
        or new_msg.author.bot
    ):
        return

    # Build message reply chain and set user warnings
    reply_chain = []
    user_warnings = set()
    curr_msg = new_msg
    while curr_msg and len(reply_chain) < MAX_MESSAGES:
        async with msg_locks.setdefault(curr_msg.id, asyncio.Lock()):
            if curr_msg.id not in msg_nodes:
                # Group attachments by allowed kind via substring match on the content type.
                good_attachments = {
                    file_type: [att for att in curr_msg.attachments if att.content_type and file_type in att.content_type]
                    for file_type in ALLOWED_FILE_TYPES
                }

                # requests is synchronous; run downloads in a worker thread so the
                # event loop keeps running while attachments are fetched.
                text_downloads = [await asyncio.to_thread(requests.get, att.url) for att in good_attachments["text"]]
                text = "\n".join(
                    ([curr_msg.content] if curr_msg.content else [])
                    + [msg_embed.description for msg_embed in curr_msg.embeds if msg_embed.description]
                    + [download.text for download in text_downloads]
                )
                if curr_msg.content.startswith(bot.user.mention):
                    text = text.replace(bot.user.mention, "", 1).lstrip()

                images = good_attachments["image"][:MAX_IMAGES]
                if LLM_SUPPORTS_IMAGES and images:
                    content = [{"type": "text", "text": text[:MAX_TEXT]}] if text[:MAX_TEXT] else []
                    for att in images:
                        image_bytes = (await asyncio.to_thread(requests.get, att.url)).content
                        content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:{att.content_type};base64,{base64.b64encode(image_bytes).decode('utf-8')}"},
                            }
                        )
                else:
                    # "." stands in for otherwise empty content.
                    content = text[:MAX_TEXT] or "."

                data = {
                    "content": content,
                    "role": "assistant" if curr_msg.author == bot.user else "user",
                }
                if LLM_SUPPORTS_NAMES:
                    data["name"] = str(curr_msg.author.id)

                msg_nodes[curr_msg.id] = MsgNode(
                    data=data,
                    too_much_text=len(text) > MAX_TEXT,
                    too_many_images=len(good_attachments["image"]) > MAX_IMAGES,
                    has_bad_attachments=len(curr_msg.attachments) > sum(len(att_list) for att_list in good_attachments.values()),
                )

                try:
                    if (
                        not curr_msg.reference
                        and curr_msg.channel.type != discord.ChannelType.private
                        and bot.user.mention not in curr_msg.content
                        and (prev_msg_in_channel := ([m async for m in curr_msg.channel.history(before=curr_msg, limit=1)] or [None])[0])
                        and prev_msg_in_channel.type in (discord.MessageType.default, discord.MessageType.reply)
                        and prev_msg_in_channel.author == curr_msg.author
                    ):
                        # No explicit reply and no mention: chain to the same author's
                        # immediately preceding message in the channel.
                        msg_nodes[curr_msg.id].replied_to_msg = prev_msg_in_channel
                    else:
                        next_is_thread_parent: bool = not curr_msg.reference and curr_msg.channel.type == discord.ChannelType.public_thread
                        if next_msg_id := curr_msg.channel.id if next_is_thread_parent else getattr(curr_msg.reference, "message_id", None):
                            # Response messages keep their lock held while streaming, so
                            # wait (yielding to the loop) until that lock is free.
                            while msg_locks.setdefault(next_msg_id, asyncio.Lock()).locked():
                                await asyncio.sleep(0)
                            msg_nodes[curr_msg.id].replied_to_msg = (
                                (curr_msg.channel.starter_message or await curr_msg.channel.parent.fetch_message(next_msg_id))
                                if next_is_thread_parent
                                else (r if isinstance(r := curr_msg.reference.resolved, discord.Message) else await curr_msg.channel.fetch_message(next_msg_id))
                            )
                except (discord.NotFound, discord.HTTPException, AttributeError):
                    logging.exception("Error fetching next message in the chain")
                    msg_nodes[curr_msg.id].fetch_next_failed = True

            curr_node = msg_nodes[curr_msg.id]
            reply_chain += [curr_node.data]

            if curr_node.too_much_text:
                user_warnings.add(f"⚠️ Max {MAX_TEXT:,} characters per message")
            if curr_node.too_many_images:
                user_warnings.add(f"⚠️ Max {MAX_IMAGES} image{'' if MAX_IMAGES == 1 else 's'} per message" if MAX_IMAGES > 0 else "⚠️ Can't see images")
            if curr_node.has_bad_attachments:
                user_warnings.add("⚠️ Unsupported attachments")
            if curr_node.fetch_next_failed or (curr_node.replied_to_msg and len(reply_chain) == MAX_MESSAGES):
                user_warnings.add(f"⚠️ Only using last{'' if (count := len(reply_chain)) == 1 else f' {count}'} message{'' if count == 1 else 's'}")

            curr_msg = curr_node.replied_to_msg

    logging.info(f"Message received (user ID: {new_msg.author.id}, attachments: {len(new_msg.attachments)}, reply chain length: {len(reply_chain)}):\n{new_msg.content}")

    # Generate and send response message(s) (can be multiple if response is long)
    response_msgs = []
    response_contents = []
    prev_chunk = None
    edit_task = None
    # reply_chain was collected newest-first; the model receives it oldest-first.
    kwargs = dict(model=env["LLM"], messages=(get_system_prompt() + reply_chain[::-1]), stream=True) | llm_settings
    try:
        async with new_msg.channel.typing():
            async for curr_chunk in await acompletion(**kwargs):
                # Each iteration emits the previous chunk, so the current chunk can be
                # inspected to tell whether this is the last edit of the message.
                if prev_chunk:
                    prev_content = prev_chunk.choices[0].delta.content or ""
                    curr_content = curr_chunk.choices[0].delta.content or ""

                    if not response_msgs or len(response_contents[-1] + prev_content) > EMBED_MAX_LENGTH:
                        reply_to_msg = new_msg if not response_msgs else response_msgs[-1]
                        embed = discord.Embed(description="⏳", color=EMBED_COLOR["incomplete"])
                        for warning in sorted(user_warnings):
                            embed.add_field(name=warning, value="", inline=False)
                        response_msgs += [
                            await reply_to_msg.reply(
                                embed=embed,
                                silent=True,
                            )
                        ]
                        # Hold the new message's lock until its MsgNode is registered below.
                        await msg_locks.setdefault(response_msgs[-1].id, asyncio.Lock()).acquire()
                        last_task_time = dt.now().timestamp()
                        response_contents += [""]

                    response_contents[-1] += prev_content

                    is_final_edit: bool = curr_chunk.choices[0].finish_reason is not None or len(response_contents[-1] + curr_content) > EMBED_MAX_LENGTH
                    edit_slot_free = (not edit_task or edit_task.done()) and dt.now().timestamp() - last_task_time >= EDIT_DELAY_SECONDS
                    if is_final_edit or edit_slot_free:
                        # A final edit must not overtake an edit that is still in flight.
                        while edit_task and not edit_task.done():
                            await asyncio.sleep(0)
                        if response_contents[-1].strip():
                            embed.description = response_contents[-1]
                        embed.color = EMBED_COLOR["complete"] if is_final_edit else EMBED_COLOR["incomplete"]
                        edit_task = asyncio.create_task(response_msgs[-1].edit(embed=embed))
                        last_task_time = dt.now().timestamp()

                prev_chunk = curr_chunk
    except Exception:
        # Best effort: a failed stream is logged and whatever was sent is kept.
        # Cancellation and interpreter-exit signals are deliberately not swallowed.
        logging.exception("Error while streaming response")
    finally:
        # Create MsgNodes for response messages. This runs in `finally` so the locks
        # acquired above are released even when the handler is cancelled.
        for msg in response_msgs:
            data = {
                "content": "".join(response_contents) or ".",
                "role": "assistant",
            }
            if LLM_SUPPORTS_NAMES:
                data["name"] = str(bot.user.id)
            msg_nodes[msg.id] = MsgNode(data=data, replied_to_msg=new_msg)
            msg_locks[msg.id].release()

    # Delete MsgNodes for oldest messages (lowest IDs)
    if (num_nodes := len(msg_nodes)) > MAX_MESSAGE_NODES:
        for msg_id in sorted(msg_nodes)[: num_nodes - MAX_MESSAGE_NODES]:
            async with msg_locks.setdefault(msg_id, asyncio.Lock()):
                msg_nodes.pop(msg_id, None)
                msg_locks.pop(msg_id, None)
async def main():
    """Start the Discord client with the token read from DISCORD_BOT_TOKEN."""
    token = env["DISCORD_BOT_TOKEN"]
    await bot.start(token)


# Script entry point: run the bot on a fresh event loop.
asyncio.run(main())