-
Notifications
You must be signed in to change notification settings - Fork 25
/
interview-hfinference.py
executable file
·85 lines (73 loc) · 2.65 KB
/
interview-hfinference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
import json
import requests
import os
import time
from jinja2 import Template
import argparse
from pathlib import Path
from prepare import save_interview
# ---- CLI arguments and API configuration --------------------------------
parser = argparse.ArgumentParser(description='Interview executor for HuggingFace Inference API')
parser.add_argument('--input', type=str, required=True, help='path to prepare*.ndjson from prepare stage')
parser.add_argument('--model', type=str, default='bigcode/starcoder', help='model to use')
parser.add_argument('--params', type=str, required=True, help='parameter file to use')
parser.add_argument('--templateout', type=str, required=True, help='output template file')
args = parser.parse_args()

# Fail fast if the token is missing: otherwise the header reads
# "Bearer None" and every request fails with an opaque 401.
HF_API_KEY = os.getenv('HF_API_KEY')
if not HF_API_KEY:
    raise SystemExit('HF_API_KEY environment variable is not set')
headers = {"Authorization": f"Bearer {HF_API_KEY}"}
API_URL = "https://api-inference.huggingface.co/models/" + args.model
def query(payload):
    """POST *payload* to the HF Inference API, retrying transient failures.

    The API returns a JSON list on success. A non-JSON body (model still
    loading) or an error dict is treated as transient and retried up to
    5 times with a 1-second pause between attempts.

    Returns the decoded JSON list.
    Raises RuntimeError when all retries are exhausted. (The original
    code fell off the end and returned None, which later crashed the
    caller with an unrelated TypeError on ``data[0]``.)
    """
    for _attempt in range(5):
        response = requests.post(API_URL, headers=headers, json=payload)
        try:
            res = response.json()
        except ValueError:
            # Body is not JSON yet — typically the model is still loading.
            print('JSON decoder failed:', response.content.decode('utf-8', errors='replace'))
            time.sleep(1)
            continue
        if isinstance(res, list):
            return res
        # Error payloads arrive as a dict like {"error": "..."}; anything
        # else is unexpected but still worth retrying.
        if isinstance(res, dict) and 'error' in res:
            print('Generation error:', res['error'])
        else:
            print('Something weird went wrong', res)
        time.sleep(1)
    raise RuntimeError(f'HF Inference API request failed after 5 attempts for model {args.model}')
# Load sampling params and adapt to the model's expected format.
# see https://huggingface.github.io/text-generation-inference/ GenerateParameters struct
with open(args.params) as params_file:
    params = json.load(params_file)
model_params = {
    "temperature": params['temperature'],
    "top_k": params['top_k'],
    "top_p": params['top_p'],
    "max_new_tokens": params['max_new_tokens'],
    "repetition_penalty": params['repetition_penalty']
}

# Output template, rendered once per challenge with the model's answer.
output_template = Template(Path(args.templateout).read_text())

# Run the interview: one API call per prepared challenge.
with open(args.input) as input_file:
    interview = [json.loads(line) for line in input_file]
results = []
for challenge in interview:
    data = query(
        {
            "inputs": challenge['prompt'],
            "parameters": model_params,
        }
    )
    # The API echoes the prompt back; strip it and the EOS marker so only
    # the generated continuation remains.
    answer = data[0]['generated_text']
    answer = answer.replace(challenge['prompt'], '').replace('<|endoftext|>', '')
    output = output_template.render(**challenge, Answer=answer)
    print()
    print(output)
    print()
    result = challenge.copy()
    result['answer'] = output
    result['params'] = model_params
    result['model'] = args.model
    result['runtime'] = 'api-inference'
    results.append(result)

# Save results
save_interview(args.input, args.templateout, args.params, args.model, results)