-
Notifications
You must be signed in to change notification settings - Fork 4
/
ask-llm.go
122 lines (105 loc) · 2.59 KB
/
ask-llm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"net/http"
"os"
"time"
)
// Runtime configuration, read from the environment at program start:
//   - LLM_API_BASE_URL: base URL of an OpenAI-compatible API
//   - LLM_API_KEY: optional; when set it is sent as a Bearer token
//   - LLM_CHAT_MODEL: model name included in each request body
//   - LLM_DEBUG: when non-empty, per-request latency is printed
var (
	LLMAPIBaseURL = os.Getenv("LLM_API_BASE_URL")
	LLMAPIKey     = os.Getenv("LLM_API_KEY")
	LLMChatModel  = os.Getenv("LLM_CHAT_MODEL")
	LLMDebug      = os.Getenv("LLM_DEBUG")
)
// Message is one turn of the conversation in the OpenAI chat-completions
// wire format; Role is "system", "user", or "assistant" in this program.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// ChatRequest is the JSON body POSTed to the /chat/completions endpoint.
type ChatRequest struct {
	Messages    []Message `json:"messages"`    // full conversation history, system prompt first
	Model       string    `json:"model"`       // model identifier (from LLM_CHAT_MODEL)
	Stop        []string  `json:"stop"`        // stop sequences that terminate generation
	MaxTokens   int       `json:"max_tokens"`  // upper bound on the reply length
	Temperature float64   `json:"temperature"` // sampling temperature (0 in this program)
}
// Choice is one candidate completion in the API response; only the
// message content is decoded — all other response fields are ignored.
type Choice struct {
	Message struct {
		Content string `json:"content"`
	} `json:"message"`
}
// chat sends the accumulated conversation to the OpenAI-compatible
// chat-completions endpoint at LLMAPIBaseURL and returns the assistant's
// reply text.
//
// The request uses deterministic sampling (temperature 0), caps the reply
// at 200 tokens, and sends a Bearer Authorization header only when
// LLM_API_KEY is set. A non-nil error is returned for request-building,
// transport, non-200, decoding, or empty-response failures.
func chat(messages []Message) (string, error) {
	url := fmt.Sprintf("%s/chat/completions", LLMAPIBaseURL)

	requestBody := ChatRequest{
		Messages:    messages,
		Model:       LLMChatModel,
		Stop:        []string{"<|im_end|>", "<|end|>", "<|eot_id|>"},
		MaxTokens:   200,
		Temperature: 0,
	}
	jsonBody, err := json.Marshal(requestBody)
	if err != nil {
		return "", fmt.Errorf("encoding request: %w", err)
	}

	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("building request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if LLMAPIKey != "" {
		req.Header.Set("Authorization", "Bearer "+LLMAPIKey)
	}

	// A timeout keeps the CLI from hanging forever on an unresponsive
	// server (the original zero-value client had no timeout at all).
	client := &http.Client{Timeout: 2 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// resp.Status already includes the numeric code (e.g. "404 Not
		// Found"), so it is not printed twice as the original did.
		return "", fmt.Errorf("HTTP error: %s", resp.Status)
	}

	var data struct {
		Choices []Choice `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return "", fmt.Errorf("decoding response: %w", err)
	}
	// Guard against an empty choices array: indexing [0] unconditionally
	// would panic on a malformed or filtered response.
	if len(data.Choices) == 0 {
		return "", fmt.Errorf("response contains no choices")
	}
	return data.Choices[0].Message.Content, nil
}
const SystemPrompt = "Answer the question politely and concisely."
// main runs an interactive REPL: it reads questions from stdin, sends the
// growing conversation to chat, and prints each answer. The loop ends on
// EOF (Ctrl+D), an empty input line, or a request error.
func main() {
	fmt.Printf("Using LLM at %s.\n", LLMAPIBaseURL)
	fmt.Println("Press Ctrl+D to exit.")
	fmt.Println()

	// The system prompt is the first message; every user/assistant turn
	// is appended so the model always sees the full history.
	messages := []Message{{Role: "system", Content: SystemPrompt}}
	scanner := bufio.NewScanner(os.Stdin)
	for {
		fmt.Print(">> ")
		// Scan returns false on EOF or a read error; the original ignored
		// this return value and only exited via the empty-string check.
		if !scanner.Scan() {
			break
		}
		question := scanner.Text()
		if question == "" {
			break
		}
		messages = append(messages, Message{Role: "user", Content: question})

		start := time.Now()
		answer, err := chat(messages)
		if err != nil {
			fmt.Println("Error:", err)
			break
		}
		messages = append(messages, Message{Role: "assistant", Content: answer})
		fmt.Println(answer)

		if LLMDebug != "" {
			fmt.Printf("[%d ms]\n", time.Since(start).Milliseconds())
		}
		fmt.Println()
	}
	// Distinguish a genuine read error from plain EOF (Err returns nil
	// on EOF), which the original silently swallowed.
	if err := scanner.Err(); err != nil {
		fmt.Fprintln(os.Stderr, "Error reading input:", err)
	}
}