Adding Llama Guard notebooks #400
base: main
Changes from 22 commits
README.md
@@ -2,14 +2,14 @@
<!-- markdown-link-check-disable -->
Llama Guard is a language model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard).

- This folder contains an example file to run Llama Guard inference directly.
+ This folder contains example notebooks for running Llama Guard standalone and for validating Llama Guard performance against a reference dataset. The dataset itself is not provided; only the format it must follow to use the scripts out of the box is described. Additionally, Llama Guard can be used as an optional safety checker when running the regular Llama [inference script](../../inference/local_inference/inference.py).
Review comment: Should we call out that this can be used to convert the ToxicChat dataset using a script, to run validation on Llama Guard?

Reply: This is not intended to convert the ToxicChat dataset yet, but it's the base for that in the future.

## Requirements
1. Access to Llama Guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
2. Llama recipes package and its dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
- 3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
+ 3. A GPU with 8 GB of available memory or more.

- ## Llama Guard inference script
+ ## Llama Guard inference notebook
For testing, you can add User or User/Agent interactions into the prompts list and then run the script to verify the results. When the conversation has one or more Agent responses, it is considered of type Agent.
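For illustration, each entry in the prompts list pairs the conversation turns with the type of check to run, mirroring the `main` function of the inference notebook in this PR; a minimal sketch using the notebook's names:

```
from enum import Enum

class AgentType(Enum):
    # Same enum the inference notebook defines
    AGENT = "Agent"
    USER = "User"

# Prompts structure consumed by llm_eval in the notebook: each entry pairs
# the conversation turns with the type of check to run (User or Agent).
prompts = [
    {"prompt": ["<Sample user prompt>"],
     "agent_type": AgentType.USER},
    {"prompt": ["<Sample user prompt>", "<Sample agent response>"],
     "agent_type": AgentType.AGENT},
]
```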
@@ -30,28 +30,11 @@ For testing, you can add User or User/Agent interactions into the prompts list a
The complete prompt is built with the `build_prompt` function, defined in [prompt_format.py](../../src/llama_recipes/inference/prompt_format.py). The file contains the default Llama Guard categories. These categories can be adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), in section 4.5, "Studying the adaptability of the model".
<!-- markdown-link-check-enable -->
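As a rough sketch of how these pieces fit together, mirroring the calls made in the inference notebook below (the module path `llama_recipes.inference.prompt_format_utils` and the names `build_prompt`, `create_conversation`, and `LLAMA_GUARD_CATEGORY` are the ones the notebook imports):

```
from enum import Enum

from llama_recipes.inference.prompt_format_utils import (
    build_prompt, create_conversation, LLAMA_GUARD_CATEGORY)

class AgentType(Enum):
    # Same enum the inference notebook defines
    AGENT = "Agent"
    USER = "User"

# Build the complete Llama Guard prompt for a user/agent exchange,
# using the default category definitions.
conversation = create_conversation(["<Sample user prompt>", "<Sample agent response>"])
formatted_prompt = build_prompt(AgentType.AGENT, LLAMA_GUARD_CATEGORY, conversation)
print(formatted_prompt)
```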

- To run the samples, with all the dependencies installed, execute this command:
+ To run the samples, execute the [Inference notebook](inference.ipynb).

- `python examples/llama_guard/inference.py`
+ ## Llama Guard validation notebook

- This is the output:
-
- ```
- ['<Sample user prompt>']
- > safe
-
- ==================================
-
- ['<Sample user prompt>', '<Sample agent response>']
- > safe
-
- ==================================
-
- ['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
- > safe
-
- ==================================
- ```
+ For more details on how to evaluate the performance of the model against a custom dataset, refer to the [Validation](validation.ipynb) notebook.

## Inference Safety Checker
When running the regular inference script with prompts, Llama Guard is used as a safety checker on both the user prompt and the model output. If both are safe, the result is shown; otherwise, an error message is shown containing the word unsafe and a comma-separated list of the infringed categories. Llama Guard is always loaded quantized using the Hugging Face Transformers library.
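As an illustration of that result format (the word unsafe followed by a comma-separated list of categories), here is a small parsing sketch; the helper name and the category identifiers are hypothetical and not part of the recipes code:

```
from typing import List, Tuple

def parse_guard_result(result: str) -> Tuple[bool, List[str]]:
    # Split a Llama Guard result into (is_safe, violated_categories).
    lines = [line.strip() for line in result.strip().splitlines() if line.strip()]
    if not lines or lines[0].lower() == "safe":
        return True, []
    categories = lines[1].split(",") if len(lines) > 1 else []
    return False, [c.strip() for c in categories]

print(parse_guard_result("safe"))            # (True, [])
print(parse_guard_result("unsafe\nO3,O6"))   # (False, ['O3', 'O6'])
```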

inference.ipynb
@@ -0,0 +1,261 @@ ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0980720f-4d3c-464c-97c7-7ee9125239bd", | ||
"metadata": {}, | ||
"source": [ | ||
"# Running Llama Guard inference\n", | ||
"\n", | ||
"This notebook is intented to showcase how to run Llama Guard inference on a sample prompt for testing." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c4c946e0-4b84-4d51-80f1-be1a43c75b90", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install --upgrade huggingface_hub" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7b6d0f79-4886-4b86-a7c7-fb42321d8495", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n", | ||
"\n", | ||
"from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY\n", | ||
"from llama_recipes.inference.llama.generation import Llama\n", | ||
"\n", | ||
"from typing import List, Optional, Tuple, Dict\n", | ||
"from enum import Enum\n", | ||
"\n", | ||
"import torch\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "eb0c5970-5e3a-4340-913e-61fd5a5e780a", | ||
"metadata": {}, | ||
"source": [ | ||
"# Defining the main functions\n", | ||
"\n", | ||
"Agent type enum to define what type of inference Llama Guard should be doing, either User or Agent.\n", | ||
"\n", | ||
"The llm_eval function loads the Llama Guard model from Hugging Face. Then iterates over the prompts and generates the results for each token." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "acd9b498-3457-4df4-b829-828e53d5bf5a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class AgentType(Enum):\n", | ||
" AGENT = \"Agent\"\n", | ||
" USER = \"User\"\n", | ||
"\n", | ||
"def llm_eval(prompts, load_in_8bit=True, load_in_4bit = False, logprobs = False) -> Tuple[List[str], Optional[List[List[Tuple[int, float]]]]]:\n", | ||
"\n", | ||
" model_id = \"meta-llama/LlamaGuard-7b\"\n", | ||
" \n", | ||
" tokenizer = AutoTokenizer.from_pretrained(model_id)\n", | ||
"\n", | ||
" torch_dtype = torch.float32\n", | ||
" if load_in_4bit:\n", | ||
" torch_dtype = torch.bfloat16\n", | ||
"\n", | ||
" bnb_config = BitsAndBytesConfig(\n", | ||
" load_in_8bit=load_in_8bit,\n", | ||
" load_in_4bit=load_in_4bit,\n", | ||
" bnb_4bit_use_double_quant=True,\n", | ||
" bnb_4bit_quant_type=\"nf4\",\n", | ||
" bnb_4bit_compute_dtype=torch_dtype\n", | ||
" )\n", | ||
"\n", | ||
" \n", | ||
" model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map=\"auto\")\n", | ||
"\n", | ||
" results: List[str] = []\n", | ||
" if logprobs:\n", | ||
" result_logprobs: List[List[Tuple[int, float]]] = []\n", | ||
"\n", | ||
" for prompt in prompts:\n", | ||
" formatted_prompt = build_prompt(\n", | ||
" prompt[\"agent_type\"], \n", | ||
" LLAMA_GUARD_CATEGORY, \n", | ||
" create_conversation(prompt[\"prompt\"]))\n", | ||
"\n", | ||
"\n", | ||
" input = tokenizer([formatted_prompt], return_tensors=\"pt\").to(\"cuda\")\n", | ||
" prompt_len = input[\"input_ids\"].shape[-1]\n", | ||
" output = model.generate(**input, max_new_tokens=10, pad_token_id=0, return_dict_in_generate=True, output_scores=logprobs)\n", | ||
" \n", | ||
" if logprobs:\n", | ||
" transition_scores = model.compute_transition_scores(\n", | ||
" output.sequences, output.scores, normalize_logits=True)\n", | ||
"\n", | ||
" generated_tokens = output.sequences[:, prompt_len:]\n", | ||
" \n", | ||
" if logprobs:\n", | ||
" temp_logprobs: List[Tuple[int, float]] = []\n", | ||
" for tok, score in zip(generated_tokens[0], transition_scores[0]):\n", | ||
" temp_logprobs.append((tok.cpu().numpy(), score.cpu().numpy()))\n", | ||
" \n", | ||
" result_logprobs.append(temp_logprobs)\n", | ||
" prompt[\"logprobs\"] = temp_logprobs\n", | ||
" \n", | ||
" result = tokenizer.decode(generated_tokens[0], skip_special_tokens=True) \n", | ||
"\n", | ||
" prompt[\"result\"] = result\n", | ||
" results.append(result)\n", | ||
"\n", | ||
" return (results, result_logprobs if logprobs else None) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "26966de0-a477-4a56-830f-5bd8565bede8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def pytorch_llm_eval(prompts: List[Tuple[List[str], AgentType, str, str, str]], ckpt_dir, logprobs: bool = False):\n", | ||
" # defaults\n", | ||
" temperature = 1\n", | ||
" top_p = 1\n", | ||
" max_seq_len = 4096\n", | ||
" max_gen_len = 32\n", | ||
" max_batch_size = 1\n", | ||
"\n", | ||
" generator = Llama.build(\n", | ||
" ckpt_dir=ckpt_dir,\n", | ||
" tokenizer_path=ckpt_dir + \"/tokenizer.model\",\n", | ||
" max_seq_len=max_seq_len,\n", | ||
" max_batch_size=max_batch_size,\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
" results: List[str] = []\n", | ||
" for prompt in prompts:\n", | ||
" formatted_prompt = build_prompt(\n", | ||
" prompt[\"agent_type\"], \n", | ||
" LLAMA_GUARD_CATEGORY, \n", | ||
" create_conversation(prompt[\"prompt\"]))\n", | ||
"\n", | ||
" result = generator.text_completion(\n", | ||
" [formatted_prompt],\n", | ||
" temperature,\n", | ||
" top_p,\n", | ||
" max_gen_len,\n", | ||
" logprobs\n", | ||
" )\n", | ||
" # getting the first value only, as only a single prompt was sent to the function\n", | ||
" generation_result = result[0][\"generation\"]\n", | ||
" prompt[\"result\"] = generation_result\n", | ||
" if logprobs:\n", | ||
" prompt[\"logprobs\"] = result[0][\"logprobs\"]\n", | ||
"\n", | ||
" results.append(generation_result)\n", | ||
"\n", | ||
" return results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bd2ed6f5-e561-4aeb-9427-ae8bb261a772", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Setting variables used by the Llama classes\n", | ||
"import os\n", | ||
"os.environ[\"RANK\"] = \"0\"\n", | ||
"os.environ[\"WORLD_SIZE\"] = \"1\"\n", | ||
"os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n", | ||
"os.environ[\"MASTER_PORT\"] = \"29500\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b54b4e41-2379-4163-adc6-465e9b7fc57a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def main():\n", | ||
" \n", | ||
" prompts: List[Dict[List[str], AgentType]] = [\n", | ||
" {\n", | ||
" \"prompt\": [\"<Sample user prompt>\"],\n", | ||
" \"agent_type\": AgentType.USER\n", | ||
" },\n", | ||
" {\n", | ||
" \"prompt\": [\"<Sample user prompt>\", \"<Sample agent response>\"],\n", | ||
" \"agent_type\": AgentType.AGENT\n", | ||
" },\n", | ||
" {\n", | ||
" \"prompt\": [\"<Sample user prompt>\", \n", | ||
" \"<Sample agent response>\", \n", | ||
" \"<Sample user reply>\", \n", | ||
" \"<Sample agent response>\"],\n", | ||
" \"agent_type\": AgentType.AGENT\n", | ||
" }\n", | ||
" ]\n", | ||
"\n", | ||
" \n", | ||
" results = llm_eval(prompts, load_in_8bit = False, load_in_4bit = True)\n", | ||
" \n", | ||
" for i, prompt in enumerate(prompts):\n", | ||
" print(prompt['prompt'])\n", | ||
" print(f\"> {results[0][i]}\")\n", | ||
" print(\"\\n==================================\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2345c875-d610-4b8a-b73c-464581575e93", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# used to be able to import this script in another notebook and not run the main function\n", | ||
"if __name__ == '__main__' and '__file__' not in globals():\n", | ||
" from huggingface_hub import login\n", | ||
" login()\n", | ||
" main()\n", | ||
" " | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"jupytext": { | ||
"formats": "ipynb,auto:light" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Review comment: Looks like inference.py seems to be deleted? Where should this command be run from? Should we provide a `cd` instruction before the `python` command?

Reply: This is the inference.py script I'm referencing here, not the one in the llama_guard directory.