Adding Llama Guard notebooks #400

Draft
wants to merge 29 commits into main

Changes from 18 of 29 commits

Commits
9b5faf5
first version of the validation script
albertodepaola Feb 9, 2024
56c5fb4
final version with required parameters to execute both user and agent…
albertodepaola Feb 9, 2024
e06c5c4
Adding standard Llama inference methods. Not working fully
albertodepaola Feb 12, 2024
f9efe5b
Adding correct order to standard inference script
albertodepaola Feb 13, 2024
7ff8959
final version of the validation scripts, saving both results and stat…
albertodepaola Feb 13, 2024
710de4b
adding logprobs and average precision calculation
albertodepaola Feb 20, 2024
779579b
Fixing the lack of exp on the log probs. Adding support for plain Lla…
albertodepaola Feb 21, 2024
f8f5426
Modifying the logprobs to fetch only probabilities for unsafe tokens.…
albertodepaola Feb 26, 2024
2e46c97
Adding comment in code
albertodepaola Feb 26, 2024
e31de6e
Fixing inference error with base python. Removing unnecessary comments
albertodepaola Mar 13, 2024
2a5ac81
Breaking changes in generation.py to import it from the new notebooks.
albertodepaola Mar 14, 2024
0e19a1d
Merge refactored folder structure.
albertodepaola Mar 14, 2024
9b4cef0
Moving files around to conform to the new structure.
albertodepaola Mar 14, 2024
d96ed93
Last version of pytorch inference
albertodepaola Mar 15, 2024
4fb5f7e
Fixing errors in notebooks due to directory changes. Adding descripti…
albertodepaola Mar 15, 2024
aee60fa
removing test file as script is no longer available. Removing inferen…
albertodepaola Mar 15, 2024
bf44153
Removing temporary code to load a subset of lines
albertodepaola Mar 15, 2024
dd5ebcc
Fixing typos
albertodepaola Mar 15, 2024
79785e8
renaming notebooks to lower case
albertodepaola Mar 19, 2024
baab5fa
Adding jupytext synced scripts and jupyter and jupytext dependencies
albertodepaola Mar 19, 2024
caf865d
fixing links in readme
albertodepaola Mar 19, 2024
eed26c9
fixing generic statement
albertodepaola Mar 19, 2024
e83bffd
Merged commit
albertodepaola Apr 29, 2024
96748ea
Adding modifications for llama 3 support initial commit
albertodepaola Apr 29, 2024
a59da93
removing examples from the package
albertodepaola Apr 29, 2024
7a2748a
Merging launch inference script into new llama guard responsible stru…
albertodepaola May 14, 2024
982b9a7
Removing unnecessary files. Adding missed requirements.txt
albertodepaola May 14, 2024
ab5680f
Adding support for Llama Guard 2 on HF. Pytorch is for Llama Guard 1…
albertodepaola May 14, 2024
777c06d
Using HF by default
albertodepaola May 15, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ exclude = [
]

[tool.hatch.build.targets.wheel]
packages = ["src/llama_recipes"]
packages = ["src/llama_recipes", "examples"]

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]
8 changes: 7 additions & 1 deletion recipes/inference/local_inference/README.md
@@ -6,11 +6,17 @@ In the case of a parameter efficient method like lora the base model has to be g
Additionally, a prompt for the model in the form of a text file has to be provided. The prompt file can either be piped through standard input or given as --prompt_file parameter.

**Content Safety**
The inference script also supports safety checks for both user prompt and model outputs. In particular, we use two packages, [AuditNLG](https://github.com/salesforce/AuditNLG/tree/main) and [Azure content safety](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/).
The inference script also supports safety checks for both user prompts and model outputs. In particular, we use [AuditNLG](https://github.com/salesforce/AuditNLG/tree/main) and [Azure content safety](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/). Additionally, we offer Llama Guard with the default categories as an optional content safety checker.

**Note**
If using Azure content Safety, please make sure to get the endpoint and API key as described [here](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/) and add them as the following environment variables,`CONTENT_SAFETY_ENDPOINT` and `CONTENT_SAFETY_KEY`.

**Note on Llama Guard**
Use this command for testing with a quantized Llama model, modifying the values accordingly:

`python inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization --enable_llamaguard_content_safety`
Contributor: It looks like inference.py has been deleted? Where should this command be run from? Should we provide a cd instruction before the python command?

Contributor (PR author): This is the inference.py script I'm referencing here, not the one in the llama_guard directory.



Examples:

```bash
…
```
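
To make the optional Llama Guard check described above more concrete, here is a minimal standalone sketch (illustrative only, not part of this PR's diff) of classifying a single user prompt with the LlamaGuard-7b checkpoint on Hugging Face, using the same prompt-formatting helpers the notebook below relies on. The actual code path behind `--enable_llamaguard_content_safety` may differ.

```python
# Illustrative sketch, not repository code: classify one user prompt with Llama Guard.
# Assumes access to meta-llama/LlamaGuard-7b and that
# llama_recipes.inference.prompt_format_utils exposes build_prompt,
# create_conversation and LLAMA_GUARD_CATEGORY as used in the notebook below.
from enum import Enum

from transformers import AutoTokenizer, AutoModelForCausalLM
from llama_recipes.inference.prompt_format_utils import (
    build_prompt,
    create_conversation,
    LLAMA_GUARD_CATEGORY,
)

class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"

model_id = "meta-llama/LlamaGuard-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# Format a single-turn conversation with the default Llama Guard categories.
formatted_prompt = build_prompt(
    AgentType.USER,
    LLAMA_GUARD_CATEGORY,
    create_conversation(["<Sample user prompt>"]),
)

inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
prompt_len = inputs["input_ids"].shape[-1]
output = model.generate(**inputs, max_new_tokens=10, pad_token_id=0)

# Llama Guard replies with "safe" or "unsafe" plus the violated category codes.
print(tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True))
```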
344 changes: 344 additions & 0 deletions recipes/responsible_ai/llama_guard/Inference.ipynb
Contributor: What is the new naming convention for the files? Is it lowercase or capital case?

@@ -0,0 +1,344 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0980720f-4d3c-464c-97c7-7ee9125239bd",
"metadata": {},
"source": [
"# Running Llama Guard inference\n",
"\n",
"This notebook is intented to showcase how to run Llama Guard inference on a sample prompt for testing."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c4c946e0-4b84-4d51-80f1-be1a43c75b90",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: huggingface_hub in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (0.21.4)\n",
"Requirement already satisfied: filelock in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (3.13.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (2023.10.0)\n",
"Requirement already satisfied: requests in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (4.62.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (4.10.0)\n",
"Requirement already satisfied: packaging>=20.9 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (23.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (2023.11.17)\n"
]
}
],
"source": [
"!pip install --upgrade huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7b6d0f79-4886-4b86-a7c7-fb42321d8495",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY\n",
"from llama_recipes.inference.llama.generation import Llama\n",
"\n",
"from typing import List, Optional, Tuple, Dict\n",
"from enum import Enum\n"
]
},
{
"cell_type": "markdown",
"id": "eb0c5970-5e3a-4340-913e-61fd5a5e780a",
"metadata": {},
"source": [
"# Defining the main functions\n",
"\n",
"Agent type enum to define what type of inference Llama Guard should be doing, either User or Agent.\n",
"\n",
"The llm_eval function loads the Llama Guard model from Hugging Face. Then iterates over the prompts and generates the results for each token."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "acd9b498-3457-4df4-b829-828e53d5bf5a",
"metadata": {},
"outputs": [],
"source": [
"class AgentType(Enum):\n",
" AGENT = \"Agent\"\n",
" USER = \"User\"\n",
"\n",
"def llm_eval(prompts, load_in_8bit=True, logprobs = False) -> Tuple[List[str], Optional[List[List[Tuple[int, float]]]]]:\n",
"\n",
" model_id = \"meta-llama/LlamaGuard-7b\"\n",
" \n",
" tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
" model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=load_in_8bit, device_map=\"auto\")\n",
"\n",
" results: List[str] = []\n",
" if logprobs:\n",
" result_logprobs: List[List[Tuple[int, float]]] = []\n",
"\n",
" for prompt in prompts:\n",
" formatted_prompt = build_prompt(\n",
" prompt[\"agent_type\"], \n",
" LLAMA_GUARD_CATEGORY, \n",
" create_conversation(prompt[\"prompt\"]))\n",
"\n",
"\n",
" input = tokenizer([formatted_prompt], return_tensors=\"pt\").to(\"cuda\")\n",
" prompt_len = input[\"input_ids\"].shape[-1]\n",
" output = model.generate(**input, max_new_tokens=10, pad_token_id=0, return_dict_in_generate=True, output_scores=logprobs)\n",
" \n",
" if logprobs:\n",
" transition_scores = model.compute_transition_scores(\n",
" output.sequences, output.scores, normalize_logits=True)\n",
"\n",
" generated_tokens = output.sequences[:, prompt_len:]\n",
" \n",
" if logprobs:\n",
" temp_logprobs: List[Tuple[int, float]] = []\n",
" for tok, score in zip(generated_tokens[0], transition_scores[0]):\n",
" temp_logprobs.append((tok.cpu().numpy(), score.cpu().numpy()))\n",
" \n",
" result_logprobs.append(temp_logprobs)\n",
" prompt[\"logprobs\"] = temp_logprobs\n",
" \n",
" result = tokenizer.decode(generated_tokens[0], skip_special_tokens=True) \n",
"\n",
" prompt[\"result\"] = result\n",
" results.append(result)\n",
"\n",
" return (results, result_logprobs if logprobs else None) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "26966de0-a477-4a56-830f-5bd8565bede8",
"metadata": {},
"outputs": [],
"source": [
"def pytorch_llm_eval(prompts: List[Tuple[List[str], AgentType, str, str, str]], ckpt_dir, logprobs: bool = False):\n",
" # defaults\n",
" temperature = 1\n",
" top_p = 1\n",
" max_seq_len = 4096\n",
" max_gen_len = 32\n",
" max_batch_size = 1\n",
"\n",
" generator = Llama.build(\n",
" ckpt_dir=ckpt_dir,\n",
" tokenizer_path=ckpt_dir + \"/tokenizer.model\",\n",
" max_seq_len=max_seq_len,\n",
" max_batch_size=max_batch_size,\n",
" )\n",
"\n",
"\n",
" results: List[str] = []\n",
" for prompt in prompts:\n",
" formatted_prompt = build_prompt(\n",
" prompt[\"agent_type\"], \n",
" LLAMA_GUARD_CATEGORY, \n",
" create_conversation(prompt[\"prompt\"]))\n",
"\n",
" result = generator.text_completion(\n",
" [formatted_prompt],\n",
" temperature,\n",
" top_p,\n",
" max_gen_len,\n",
" logprobs\n",
" )\n",
" # getting the first value only, as only a single prompt was sent to the function\n",
" generation_result = result[0][\"generation\"]\n",
" prompt[\"result\"] = generation_result\n",
" if logprobs:\n",
" prompt[\"logprobs\"] = result[0][\"logprobs\"]\n",
"\n",
" results.append(generation_result)\n",
"\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bd2ed6f5-e561-4aeb-9427-ae8bb261a772",
"metadata": {},
"outputs": [],
"source": [
"# Setting variables used by the Llama classes\n",
"import os\n",
"os.environ[\"RANK\"] = \"0\"\n",
"os.environ[\"WORLD_SIZE\"] = \"1\"\n",
"os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n",
"os.environ[\"MASTER_PORT\"] = \"29500\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b54b4e41-2379-4163-adc6-465e9b7fc57a",
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
" \n",
" prompts: List[Dict[List[str], AgentType]] = [\n",
" {\n",
" \"prompt\": [\"<Sample user prompt>\"],\n",
" \"agent_type\": AgentType.USER\n",
" },\n",
" {\n",
" \"prompt\": [\"<Sample user prompt>\", \"<Sample agent response>\"],\n",
" \"agent_type\": AgentType.AGENT\n",
" },\n",
" {\n",
" \"prompt\": [\"<Sample user prompt>\", \n",
" \"<Sample agent response>\", \n",
" \"<Sample user reply>\", \n",
" \"<Sample agent response>\"],\n",
" \"agent_type\": AgentType.AGENT\n",
" }\n",
" ]\n",
"\n",
" \n",
" results = llm_eval(prompts)\n",
" \n",
" for i, prompt in enumerate(prompts):\n",
" print(prompt['prompt'])\n",
" print(f\"> {results[0][i]}\")\n",
" print(\"\\n==================================\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2345c875-d610-4b8a-b73c-464581575e93",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "922f8a8be5bf4ce99b114e83861b08a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.004128694534301758,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": 26,
"postfix": null,
"prefix": "Loading checkpoint shards",
"rate": null,
"total": 3,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "415fbebfe34f4ad8b3cc30f2a8c97e29",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/llama-recipes/lib/python3.10/site-packages/transformers/generation/utils.py:1518: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<Sample user prompt>']\n",
"> safe\n",
"\n",
"==================================\n",
"\n",
"['<Sample user prompt>', '<Sample agent response>']\n",
"> safe\n",
"\n",
"==================================\n",
"\n",
"['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']\n",
"> safe\n",
"\n",
"==================================\n",
"\n"
]
}
],
"source": [
"# used to be able to import this script in another notebook and not run the main function\n",
"if __name__ == '__main__' and '__file__' not in globals():\n",
" from huggingface_hub import login\n",
" login()\n",
" main()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a071f81-7e2c-4c87-86f8-abf7dad6ac89",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
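
For reference, a hypothetical usage sketch of the llm_eval function defined in the notebook above (this is not a committed cell): calling it with logprobs=True returns both the decoded classifications and the per-token (token_id, log-probability) pairs, which can be exponentiated into plain probabilities.

```python
# Hypothetical usage of the notebook's llm_eval, shown for illustration only:
# classify a user prompt and an agent response, requesting per-token scores.
import math

prompts = [
    {"prompt": ["<Sample user prompt>"], "agent_type": AgentType.USER},
    {"prompt": ["<Sample user prompt>", "<Sample agent response>"], "agent_type": AgentType.AGENT},
]

results, logprobs = llm_eval(prompts, load_in_8bit=True, logprobs=True)

for prompt, result, token_scores in zip(prompts, results, logprobs):
    # token_scores holds (token_id, log-probability) pairs for the generated
    # tokens; exponentiating the first score gives the model's confidence in
    # the leading "safe"/"unsafe" token.
    first_token_id, first_logprob = token_scores[0]
    print(prompt["prompt"], "->", result.strip(), f"(p={math.exp(first_logprob):.3f})")
```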