Adding Llama Guard notebooks #400

Draft · wants to merge 29 commits into main (changes shown from 22 commits)

Commits
9b5faf5
first version of the validation script
albertodepaola Feb 9, 2024
56c5fb4
final version with required parameters to execute both user and agent…
albertodepaola Feb 9, 2024
e06c5c4
Adding standard Llama inference methods. Not working fully
albertodepaola Feb 12, 2024
f9efe5b
Adding correct order to standard inference script
albertodepaola Feb 13, 2024
7ff8959
final version of the validation scripts, saving both results and stat…
albertodepaola Feb 13, 2024
710de4b
adding logprobs and average precision calculation
albertodepaola Feb 20, 2024
779579b
Fixing the lack of exp on the log probs. Adding support for plain Lla…
albertodepaola Feb 21, 2024
f8f5426
Modifying the logprobs to fetch only probabilities for unsafe tokens.…
albertodepaola Feb 26, 2024
2e46c97
Adding comment in code
albertodepaola Feb 26, 2024
e31de6e
Fixing inference error with base python. Removing unnecessary comments
albertodepaola Mar 13, 2024
2a5ac81
Breaking changes in generation.py to import it from the new notebooks.
albertodepaola Mar 14, 2024
0e19a1d
Merge refactored folder structure.
albertodepaola Mar 14, 2024
9b4cef0
Moving files around to conform to the new structure.
albertodepaola Mar 14, 2024
d96ed93
Last version of pytorch inference
albertodepaola Mar 15, 2024
4fb5f7e
Fixing errors in notebooks due to directory changes. Adding descripti…
albertodepaola Mar 15, 2024
aee60fa
removing test file as script is no longer available. Removing inferen…
albertodepaola Mar 15, 2024
bf44153
Removing temporary code to load a subset of lines
albertodepaola Mar 15, 2024
dd5ebcc
Fixing typos
albertodepaola Mar 15, 2024
79785e8
renaming notebooks to lower case
albertodepaola Mar 19, 2024
baab5fa
Adding jupytext synced scripts and jupyter and jupytext dependencies
albertodepaola Mar 19, 2024
caf865d
fixing links in readme
albertodepaola Mar 19, 2024
eed26c9
fixing generic statement
albertodepaola Mar 19, 2024
e83bffd
Merged commit
albertodepaola Apr 29, 2024
96748ea
Adding modifications for llama 3 support initial commit
albertodepaola Apr 29, 2024
a59da93
removing examples from the package
albertodepaola Apr 29, 2024
7a2748a
Merging launch inference script into new llama guard responsible stru…
albertodepaola May 14, 2024
982b9a7
Removing unnecessary files. Adding missed requirements.txt
albertodepaola May 14, 2024
ab5680f
Adding support for Llama Guard 2 on HF. Pytorch is for Llama Guard 1…
albertodepaola May 14, 2024
777c06d
Using HF by default
albertodepaola May 15, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ exclude = [
]

[tool.hatch.build.targets.wheel]
-packages = ["src/llama_recipes"]
+packages = ["src/llama_recipes", "examples"]

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]
8 changes: 7 additions & 1 deletion recipes/inference/local_inference/README.md
@@ -6,11 +6,17 @@ In the case of a parameter efficient method like lora the base model has to be g
Additionally, a prompt for the model in the form of a text file has to be provided. The prompt file can either be piped through standard input or given via the --prompt_file parameter.

**Content Safety**
-The inference script also supports safety checks for both user prompt and model outputs. In particular, we use two packages, [AuditNLG](https://github.com/salesforce/AuditNLG/tree/main) and [Azure content safety](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/).
+The inference script also supports safety checks for both user prompts and model outputs. In particular, we use [AuditNLG](https://github.com/salesforce/AuditNLG/tree/main) and [Azure content safety](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/). Additionally, Llama Guard with the default categories is offered as an optional content safety checker.

**Note**
If using Azure Content Safety, please make sure to get the endpoint and API key as described [here](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/) and add them as the following environment variables: `CONTENT_SAFETY_ENDPOINT` and `CONTENT_SAFETY_KEY`.
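
For reference, a minimal sketch of setting both variables from Python before the script reads them (the endpoint URL and key below are placeholders for the values from your own Azure resource; exporting them in your shell works just as well):

```python
import os

# Placeholders: substitute the endpoint URL and key from your own
# Azure Content Safety resource before running inference.
os.environ["CONTENT_SAFETY_ENDPOINT"] = "https://<your-resource>.cognitiveservices.azure.com/"
os.environ["CONTENT_SAFETY_KEY"] = "<your-key>"
```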

**Note on Llama Guard**
Use this command for testing with a quantized Llama model, modifying the values accordingly:

`python inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization --enable_llamaguard_content_safety`
**Review comment (Contributor):** Looks like inference.py seems to be deleted? Where should this command be run from? Should we provide a cd instruction before python command?

**Reply (Author):** This is the inference.py script I'm referencing here, not the one in the llama_guard directory.



Examples:

(example commands collapsed in this diff view)
29 changes: 6 additions & 23 deletions recipes/responsible_ai/llama_guard/README.md
@@ -2,14 +2,14 @@
<!-- markdown-link-check-disable -->
Llama Guard is a language model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard).

-This folder contains an example file to run Llama Guard inference directly.
+This folder contains example notebooks for running Llama Guard standalone and for validating its performance against a reference dataset. The dataset itself is not provided, only the format it must follow so the scripts work out of the box. Additionally, Llama Guard is used as an optional safety checker when running the regular Llama [inference script](../../inference/local_inference/inference.py).
**Review comment (Contributor):** Should we call out that this can be used to convert the ToxicChat dataset using a script to run validation on Llama Guard?

**Reply (Author):** This is not intended to convert the ToxicChat dataset yet, but it's the base for that in the future.


## Requirements
1. Access to Llama Guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
2. Llama recipes package and its dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
-3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
+3. A GPU with 8 GB of available memory or more.

-## Llama Guard inference script
+## Llama Guard inference notebook
For testing, you can add User or User/Agent interactions to the prompts list and then run the script to verify the results. When the conversation contains one or more Agent responses, it is considered an agent-type conversation.


@@ -30,28 +30,11 @@ For testing, you can add User or User/Agent interactions into the prompts list a
The complete prompt is built with the `build_prompt` function, defined in [prompt_format.py](../../src/llama_recipes/inference/prompt_format.py). The file contains the default Llama Guard categories. These categories can be adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), in section 4.5, Studying the adaptability of the model. A short usage sketch follows below.
<!-- markdown-link-check-enable -->
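
For illustration, a minimal sketch of building a single-turn prompt with the default categories, mirroring the inference notebook below (the sample string is a placeholder, and `AgentType` is the small helper enum the notebook defines):

```python
from enum import Enum

from llama_recipes.inference.prompt_format_utils import (
    build_prompt, create_conversation, LLAMA_GUARD_CATEGORY,
)

class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"

# Format a single-turn conversation for a User-side check.
formatted_prompt = build_prompt(
    AgentType.USER,
    LLAMA_GUARD_CATEGORY,
    create_conversation(["<Sample user prompt>"]))
print(formatted_prompt)
```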

-To run the samples, with all the dependencies installed, execute this command:
+To run the samples, execute the [Inference notebook](inference.ipynb).

-`python examples/llama_guard/inference.py`
+## Llama Guard validation notebook

-This is the output:
-
-```
-['<Sample user prompt>']
-> safe
-
-==================================
-
-['<Sample user prompt>', '<Sample agent response>']
-> safe
-
-==================================
-
-['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
-> safe
-
-==================================
-```
+For more details on how to evaluate the performance of the model against a custom dataset, refer to the [Validation](validation.ipynb) notebook.

## Inference Safety Checker
When running the regular inference script with prompts, Llama Guard will be used as a safety checker on both the user prompt and the model output. If both are safe, the result is shown; otherwise a message is shown containing the word unsafe and a comma-separated list of the categories infringed. Llama Guard is always loaded quantized using the Hugging Face Transformers library.
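
To picture the flow, here is a hypothetical sketch of that gating logic (`generate` and `is_safe` are illustrative stand-ins, not the script's actual API):

```python
# Hypothetical sketch: screen the user prompt, generate a response,
# then screen the model output before showing it.
def guarded_generate(user_prompt, generate, is_safe):
    safe, report = is_safe(user_prompt)      # report e.g. "unsafe, O3"
    if not safe:
        return f"User prompt was flagged: {report}"
    output = generate(user_prompt)
    safe, report = is_safe(output)
    if not safe:
        return f"Model output was flagged: {report}"
    return output
```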
261 changes: 261 additions & 0 deletions recipes/responsible_ai/llama_guard/inference.ipynb
@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0980720f-4d3c-464c-97c7-7ee9125239bd",
"metadata": {},
"source": [
"# Running Llama Guard inference\n",
"\n",
"This notebook is intended to showcase how to run Llama Guard inference on a sample prompt for testing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4c946e0-4b84-4d51-80f1-be1a43c75b90",
"metadata": {},
"outputs": [],
"source": [
"!pip install --upgrade huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b6d0f79-4886-4b86-a7c7-fb42321d8495",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
"\n",
"from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY\n",
"from llama_recipes.inference.llama.generation import Llama\n",
"\n",
"from typing import Any, Dict, List, Optional, Tuple\n",
"from enum import Enum\n",
"\n",
"import torch\n"
]
},
{
"cell_type": "markdown",
"id": "eb0c5970-5e3a-4340-913e-61fd5a5e780a",
"metadata": {},
"source": [
"# Defining the main functions\n",
"\n",
"The AgentType enum defines which type of inference Llama Guard should run: evaluating a User message or an Agent response.\n",
"\n",
"The llm_eval function loads the Llama Guard model from Hugging Face, then iterates over the prompts, generating a result for each one and, optionally, the per-token log probabilities."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acd9b498-3457-4df4-b829-828e53d5bf5a",
"metadata": {},
"outputs": [],
"source": [
"class AgentType(Enum):\n",
"    AGENT = \"Agent\"\n",
"    USER = \"User\"\n",
"\n",
"def llm_eval(prompts, load_in_8bit=True, load_in_4bit=False, logprobs=False) -> Tuple[List[str], Optional[List[List[Tuple[int, float]]]]]:\n",
"\n",
"    model_id = \"meta-llama/LlamaGuard-7b\"\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"    # 4-bit NF4 quantization computes in bfloat16; 8-bit keeps the float32 default\n",
"    torch_dtype = torch.float32\n",
"    if load_in_4bit:\n",
"        torch_dtype = torch.bfloat16\n",
"\n",
"    bnb_config = BitsAndBytesConfig(\n",
"        load_in_8bit=load_in_8bit,\n",
"        load_in_4bit=load_in_4bit,\n",
"        bnb_4bit_use_double_quant=True,\n",
"        bnb_4bit_quant_type=\"nf4\",\n",
"        bnb_4bit_compute_dtype=torch_dtype\n",
"    )\n",
"\n",
"    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map=\"auto\")\n",
"\n",
"    results: List[str] = []\n",
"    if logprobs:\n",
"        result_logprobs: List[List[Tuple[int, float]]] = []\n",
"\n",
"    for prompt in prompts:\n",
"        formatted_prompt = build_prompt(\n",
"            prompt[\"agent_type\"],\n",
"            LLAMA_GUARD_CATEGORY,\n",
"            create_conversation(prompt[\"prompt\"]))\n",
"\n",
"        inputs = tokenizer([formatted_prompt], return_tensors=\"pt\").to(\"cuda\")\n",
"        prompt_len = inputs[\"input_ids\"].shape[-1]\n",
"        output = model.generate(**inputs, max_new_tokens=10, pad_token_id=0, return_dict_in_generate=True, output_scores=logprobs)\n",
"\n",
"        if logprobs:\n",
"            # normalize_logits=True yields log-softmax scores per generated token\n",
"            transition_scores = model.compute_transition_scores(\n",
"                output.sequences, output.scores, normalize_logits=True)\n",
"\n",
"        generated_tokens = output.sequences[:, prompt_len:]\n",
"\n",
"        if logprobs:\n",
"            temp_logprobs: List[Tuple[int, float]] = []\n",
"            for tok, score in zip(generated_tokens[0], transition_scores[0]):\n",
"                temp_logprobs.append((tok.cpu().numpy(), score.cpu().numpy()))\n",
"\n",
"            result_logprobs.append(temp_logprobs)\n",
"            prompt[\"logprobs\"] = temp_logprobs\n",
"\n",
"        result = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)\n",
"\n",
"        prompt[\"result\"] = result\n",
"        results.append(result)\n",
"\n",
"    return (results, result_logprobs if logprobs else None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26966de0-a477-4a56-830f-5bd8565bede8",
"metadata": {},
"outputs": [],
"source": [
"def pytorch_llm_eval(prompts: List[Dict[str, Any]], ckpt_dir, logprobs: bool = False):\n",
"    # defaults\n",
"    temperature = 1\n",
"    top_p = 1\n",
"    max_seq_len = 4096\n",
"    max_gen_len = 32\n",
"    max_batch_size = 1\n",
"\n",
"    generator = Llama.build(\n",
"        ckpt_dir=ckpt_dir,\n",
"        tokenizer_path=ckpt_dir + \"/tokenizer.model\",\n",
"        max_seq_len=max_seq_len,\n",
"        max_batch_size=max_batch_size,\n",
"    )\n",
"\n",
"    results: List[str] = []\n",
"    for prompt in prompts:\n",
"        formatted_prompt = build_prompt(\n",
"            prompt[\"agent_type\"],\n",
"            LLAMA_GUARD_CATEGORY,\n",
"            create_conversation(prompt[\"prompt\"]))\n",
"\n",
"        result = generator.text_completion(\n",
"            [formatted_prompt],\n",
"            temperature,\n",
"            top_p,\n",
"            max_gen_len,\n",
"            logprobs\n",
"        )\n",
"        # getting the first value only, as only a single prompt was sent to the function\n",
"        generation_result = result[0][\"generation\"]\n",
"        prompt[\"result\"] = generation_result\n",
"        if logprobs:\n",
"            prompt[\"logprobs\"] = result[0][\"logprobs\"]\n",
"\n",
"        results.append(generation_result)\n",
"\n",
"    return results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd2ed6f5-e561-4aeb-9427-ae8bb261a772",
"metadata": {},
"outputs": [],
"source": [
"# Setting variables used by the Llama classes\n",
"import os\n",
"os.environ[\"RANK\"] = \"0\"\n",
"os.environ[\"WORLD_SIZE\"] = \"1\"\n",
"os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n",
"os.environ[\"MASTER_PORT\"] = \"29500\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b54b4e41-2379-4163-adc6-465e9b7fc57a",
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
"\n",
"    prompts: List[Dict[str, Any]] = [\n",
"        {\n",
"            \"prompt\": [\"<Sample user prompt>\"],\n",
"            \"agent_type\": AgentType.USER\n",
"        },\n",
"        {\n",
"            \"prompt\": [\"<Sample user prompt>\", \"<Sample agent response>\"],\n",
"            \"agent_type\": AgentType.AGENT\n",
"        },\n",
"        {\n",
"            \"prompt\": [\"<Sample user prompt>\",\n",
"                       \"<Sample agent response>\",\n",
"                       \"<Sample user reply>\",\n",
"                       \"<Sample agent response>\"],\n",
"            \"agent_type\": AgentType.AGENT\n",
"        }\n",
"    ]\n",
"\n",
"    # llm_eval returns (results, logprobs); logprobs is None unless requested\n",
"    results, _ = llm_eval(prompts, load_in_8bit=False, load_in_4bit=True)\n",
"\n",
"    for i, prompt in enumerate(prompts):\n",
"        print(prompt['prompt'])\n",
"        print(f\"> {results[i]}\")\n",
"        print(\"\\n==================================\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2345c875-d610-4b8a-b73c-464581575e93",
"metadata": {},
"outputs": [],
"source": [
"# used so this notebook can be imported from another one without running the main function\n",
"if __name__ == '__main__' and '__file__' not in globals():\n",
"    from huggingface_hub import login\n",
"    login()\n",
"    main()"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,auto:light"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}