fix generate_qa function
wukaixingxp committed May 7, 2024
1 parent b07cbad commit 6204d5a
Showing 5 changed files with 32 additions and 31 deletions.
Binary file not shown.
@@ -13,6 +13,6 @@ def load_config(config_path: str = "./config.yaml"):
config["api_key"] = os.environ["OCTOAI_API_TOKEN"]
except KeyError:
print("API token did not found, please set the OCTOAI_API_TOKEN environment variable if using OctoAI, otherwise set api_key to default EMPTY")
# local Vllm endpoint did not need API key, so set the API key to "EMPTY" if not found
# local Vllm endpoint did not need API key, so set the API key to "EMPTY" if OCTOAI_API_TOKEN not found
config["api_key"] = "EMPTY"
return config
@@ -4,7 +4,7 @@ question_prompt_template: >
read it and generate question and answer pairs
that are most likely to be asked by a user of Llama who is just getting started,
please make sure you follow these rules,
1. Generate only {total_questions} question answer pairs.
1. Generate only {num_questions} question answer pairs.
2. Generate in {language}.
3. The questions can be answered based *solely* on the given passage.
4. Avoid asking questions with similar meaning.
@@ -27,4 +27,4 @@ data_dir: "./data"

language: "English"

total_questions: 1000
num_questions: 2
@@ -5,7 +5,7 @@
import asyncio
import json
from config import load_config
from generator_utils import generate_question_batches, parse_qa_to_json, get_model_name
from generator_utils import generate_question_batches, parse_qa_to_json
from itertools import chain
import logging
import aiofiles # Ensure aiofiles is installed for async file operations
@@ -21,7 +21,10 @@
rate_limit_threshold = 2000
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
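# the semaphore caps in-flight requests at 75% of the rate-limit threshold defined above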

# Since OctoAI uses different names for the Llama models, this mapping converts an OctoAI model name to the official Hugging Face model name.
MODEL_NAME_MAPPING = {
    "meta-llama-3-70b-instruct": "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama-3-8b-instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf",
    "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
}
class ChatService(ABC):
@abstractmethod
async def execute_chat_request_async(self, api_context: dict, chat_request):
@@ -57,7 +60,7 @@ async def execute_chat_request_async(self, api_context: dict, chat_request):
async with request_limiter:
try:
event_loop = asyncio.get_running_loop()
model_name = get_model_name(api_context['model'])
model_name = MODEL_NAME_MAPPING[api_context['model']]
client = OpenAI(api_key=api_context['api_key'], base_url="http://localhost:"+ str(api_context['endpoint'])+"/v1")
api_chat_call = partial(
client.chat.completions.create,
@@ -68,7 +71,8 @@ async def execute_chat_request_async(self, api_context: dict, chat_request):
response = await event_loop.run_in_executor(None, api_chat_call)
assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
assistant_response_json = parse_qa_to_json(assistant_response)
assert(len(assistant_response_json)!=0)
if len(assistant_response_json)==0:
logging.error("No question/answer pairs generated. Please check the input context or model configuration.")
return assistant_response_json
except Exception as error:
logging.error(f"Error during chat request execution: {error}",exc_info=True)
@@ -103,8 +107,8 @@ def parse_arguments():
parser.add_argument(
"-t", "--total_questions",
type=int,
default=10,
help="Specify the number of question/answer pairs to generate."
default=100,
help="Specify the total number of question/answer pairs to generate."
)
parser.add_argument(
"-m", "--model",
@@ -14,16 +14,7 @@
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Since OctoAI uses different names for the Llama models, get the official Hugging Face model name from the OctoAI name.
def get_model_name(model):
if model == "meta-llama-3-70b-instruct":
return "meta-llama/Meta-Llama-3-70B-Instruct"
elif model == "meta-llama-3-8b-instruct":
return "meta-llama/Meta-Llama-3-8B-Instruct"
elif model == "llama-2-7b-chat":
return "meta-llama/Llama-2-7b-chat-hf"
else:
return "meta-llama/Llama-2-70b-chat-hf"

def read_text_file(file_path):
try:
with open(file_path, 'r') as f:
@@ -88,8 +79,13 @@ def read_file_content(context):
if len(text) == 0:
logging.error(f"Error reading files, text is empty")
return ' '.join(file_strings)


# clean the text by splitting on double quotes and keeping only the segments that contain alphanumeric characters
def clean(s):
result = []
for item in s.split('"'):
if any(c.isalnum() for c in item):
result.append(item)
return " ".join(result)

def parse_qa_to_json(response_string):
split_lines = response_string.split("\n")
@@ -109,21 +105,21 @@ def parse_qa_to_json(response_string):
end = i
# finding the next "Question" means the previous question/answer pair is complete, so add it to qa_set
elif '"Question":' in line:
question = " ".join(" ".join(split_lines[start:end]).split('"Question":')[1].split('"')[1:-1])
answer = " ".join(" ".join(split_lines[end:i]).split('"Answer":')[1].split('"')[1:-1])
question = " ".join(split_lines[start:end]).split('"Question":')[1]
answer = " ".join(split_lines[end:i]).split('"Answer":')[1]
start,end = i,None
qa_set.add((question, answer))
qa_set.add((clean(question), clean(answer)))
# add the last question/answer pair to qa_set
if start and end:
question = " ".join(" ".join(split_lines[start:end]).split('"Question":')[1].split('"')[1:-1])
answer = " ".join(" ".join(split_lines[end:i]).split('"Answer":')[1].split('"')[1:-1])
qa_set.add((question, answer))
if start and end:
question = " ".join(split_lines[start:end]).split('"Question":')[1]
answer = " ".join(split_lines[end:]).split('"Answer":')[1]
qa_set.add((clean(question), clean(answer)))
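# serialize the unique (question, answer) pairs as a JSON array of {"question": ..., "answer": ...} objects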
qa_list = [{"question": q, "answer":a} for q,a in qa_set]
return json.dumps(qa_list, indent=4)


async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, total_questions: int) -> dict:
prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
if not result:
@@ -142,7 +138,8 @@ async def generate_question_batches(chat_service, api_context: dict):

total_questions = api_context["total_questions"]
batches_count = len(document_batches)
base_questions_per_batch = total_questions // batches_count
# each batch should have at least 1 question
base_questions_per_batch = max(total_questions // batches_count,1)
extra_questions = total_questions % batches_count
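# the first `extra_questions` batches each receive one additional question (reflected in the print below)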

print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
