Adding Llama Guard notebooks #400 (Draft)
albertodepaola wants to merge 29 commits into meta-llama:main from albertodepaola:feat/lama_guard_validation
Commits (29 total; this view shows changes from 18 of them)
All 29 commits by albertodepaola:

9b5faf5  first version of the validation script
56c5fb4  final version with required parameters to execute both user and agent…
e06c5c4  Adding standard Llama inference methods. Not working fully
f9efe5b  Adding correct order to standard inference script
7ff8959  final version of the validation scripts, saving both results and stat…
710de4b  adding logprobs and average precision calculation
779579b  Fixing the lack of exp on the log probs. Adding support for plain Lla…
f8f5426  Modifying the logprobs to fetch only probabilities for unsafe tokens.…
2e46c97  Adding comment in code
e31de6e  Fixing inference error with base python. Removing unnecessary comments
2a5ac81  Breaking changes in generation.py to import it from the new notebooks.
0e19a1d  Merge refactored folder structure.
9b4cef0  Moving files around to conform to the new structure.
d96ed93  Last version of pytorch inference
4fb5f7e  Fixing errors in notebooks due to directory changes. Adding descripti…
aee60fa  removing test file as script is no longer available. Removing inferen…
bf44153  Removing temporary code to load a subset of lines
dd5ebcc  Fixing typos
79785e8  renaming notebooks to lower case
baab5fa  Adding jupytext synced scripts and jupyter and jupytext dependencies
caf865d  fixing links in readme
eed26c9  fixing generic statement
e83bffd  Merged commit
96748ea  Adding modifications for llama 3 support initial commit
a59da93  removing examples from the package
7a2748a  Merging launch inference script into new llama guard responsible stru…
982b9a7  Removing unnecessary files. Adding missed requirements.txt
ab5680f  Adding support for Llama Guard 2 on HF. Pytorch is for Llama Guard 1…
777c06d  Using HF by default
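Several of these commits (710de4b, 779579b, f8f5426) deal with turning generation log-probabilities into a probability score for the "unsafe" verdict. As a rough illustration of the idea — a minimal sketch, not code from this PR; the helper name and the assumption that the verdict is carried by the first generated token are mine:

```python
import math

def unsafe_probability(token_logprobs, tokenizer):
    """Hypothetical helper: convert the first generated token's log-probability
    into a probability that the input is unsafe.

    Assumes the first generated token decodes to "safe" or "unsafe", and that
    the scores come from compute_transition_scores(..., normalize_logits=True),
    i.e. they are log-probabilities that need exp() to become probabilities
    (the missing exp fixed in commit 779579b).
    """
    token_id, logprob = token_logprobs[0]
    decoded = tokenizer.decode([int(token_id)]).strip()
    prob = math.exp(float(logprob))  # log-probability -> probability
    # If the verdict token is "safe" with probability p, score unsafe as 1 - p.
    return prob if decoded == "unsafe" else 1.0 - prob
```

Scores of this form are what make the average precision calculation mentioned in 710de4b possible against a labeled dataset.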
Review comment: What is the new naming convention for the files? Is it lowercase or capitalcase?
@@ -0,0 +1,344 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0980720f-4d3c-464c-97c7-7ee9125239bd",
"metadata": {},
"source": [
"# Running Llama Guard inference\n",
"\n",
"This notebook is intented to showcase how to run Llama Guard inference on a sample prompt for testing." | ||
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c4c946e0-4b84-4d51-80f1-be1a43c75b90",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: huggingface_hub in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (0.21.4)\n",
"Requirement already satisfied: filelock in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (3.13.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (2023.10.0)\n",
"Requirement already satisfied: requests in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (4.62.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (4.10.0)\n",
"Requirement already satisfied: packaging>=20.9 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from huggingface_hub) (23.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/llama-recipes/lib/python3.10/site-packages (from requests->huggingface_hub) (2023.11.17)\n"
]
}
],
"source": [
"!pip install --upgrade huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7b6d0f79-4886-4b86-a7c7-fb42321d8495",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY\n",
"from llama_recipes.inference.llama.generation import Llama\n",
"\n",
"from typing import List, Optional, Tuple, Dict\n",
"from enum import Enum\n"
]
},
{
"cell_type": "markdown",
"id": "eb0c5970-5e3a-4340-913e-61fd5a5e780a",
"metadata": {},
"source": [
"# Defining the main functions\n", | ||
"\n", | ||
"Agent type enum to define what type of inference Llama Guard should be doing, either User or Agent.\n", | ||
"\n", | ||
"The llm_eval function loads the Llama Guard model from Hugging Face. Then iterates over the prompts and generates the results for each token." | ||
] | ||
},
{
"cell_type": "code",
"execution_count": 3,
"id": "acd9b498-3457-4df4-b829-828e53d5bf5a",
"metadata": {},
"outputs": [],
"source": [
"class AgentType(Enum):\n",
"    AGENT = \"Agent\"\n",
"    USER = \"User\"\n",
"\n",
"def llm_eval(prompts, load_in_8bit=True, logprobs=False) -> Tuple[List[str], Optional[List[List[Tuple[int, float]]]]]:\n",
"\n",
"    model_id = \"meta-llama/LlamaGuard-7b\"\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=load_in_8bit, device_map=\"auto\")\n",
"\n",
"    results: List[str] = []\n",
"    if logprobs:\n",
"        result_logprobs: List[List[Tuple[int, float]]] = []\n",
"\n",
"    for prompt in prompts:\n",
"        formatted_prompt = build_prompt(\n",
"            prompt[\"agent_type\"],\n",
"            LLAMA_GUARD_CATEGORY,\n",
"            create_conversation(prompt[\"prompt\"]))\n",
"\n",
"\n",
"        inputs = tokenizer([formatted_prompt], return_tensors=\"pt\").to(\"cuda\")\n",
"        prompt_len = inputs[\"input_ids\"].shape[-1]\n",
"        output = model.generate(**inputs, max_new_tokens=10, pad_token_id=0, return_dict_in_generate=True, output_scores=logprobs)\n",
"\n",
"        if logprobs:\n",
"            transition_scores = model.compute_transition_scores(\n",
"                output.sequences, output.scores, normalize_logits=True)\n",
"\n",
"        generated_tokens = output.sequences[:, prompt_len:]\n",
"\n",
"        if logprobs:\n",
"            temp_logprobs: List[Tuple[int, float]] = []\n",
"            for tok, score in zip(generated_tokens[0], transition_scores[0]):\n",
"                temp_logprobs.append((tok.cpu().numpy(), score.cpu().numpy()))\n",
"\n",
"            result_logprobs.append(temp_logprobs)\n",
"            prompt[\"logprobs\"] = temp_logprobs\n",
"\n",
"        result = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)\n",
"\n",
"        prompt[\"result\"] = result\n",
"        results.append(result)\n",
"\n",
"    return (results, result_logprobs if logprobs else None)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "26966de0-a477-4a56-830f-5bd8565bede8",
"metadata": {},
"outputs": [],
"source": [
"def pytorch_llm_eval(prompts: List[Tuple[List[str], AgentType, str, str, str]], ckpt_dir, logprobs: bool = False):\n", | ||
" # defaults\n", | ||
" temperature = 1\n", | ||
" top_p = 1\n", | ||
" max_seq_len = 4096\n", | ||
" max_gen_len = 32\n", | ||
" max_batch_size = 1\n", | ||
"\n", | ||
" generator = Llama.build(\n", | ||
" ckpt_dir=ckpt_dir,\n", | ||
" tokenizer_path=ckpt_dir + \"/tokenizer.model\",\n", | ||
" max_seq_len=max_seq_len,\n", | ||
" max_batch_size=max_batch_size,\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
" results: List[str] = []\n", | ||
" for prompt in prompts:\n", | ||
" formatted_prompt = build_prompt(\n", | ||
" prompt[\"agent_type\"], \n", | ||
" LLAMA_GUARD_CATEGORY, \n", | ||
" create_conversation(prompt[\"prompt\"]))\n", | ||
"\n", | ||
" result = generator.text_completion(\n", | ||
" [formatted_prompt],\n", | ||
" temperature,\n", | ||
" top_p,\n", | ||
" max_gen_len,\n", | ||
" logprobs\n", | ||
" )\n", | ||
" # getting the first value only, as only a single prompt was sent to the function\n", | ||
" generation_result = result[0][\"generation\"]\n", | ||
" prompt[\"result\"] = generation_result\n", | ||
" if logprobs:\n", | ||
" prompt[\"logprobs\"] = result[0][\"logprobs\"]\n", | ||
"\n", | ||
" results.append(generation_result)\n", | ||
"\n", | ||
" return results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "bd2ed6f5-e561-4aeb-9427-ae8bb261a772", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Setting variables used by the Llama classes\n", | ||
"import os\n", | ||
"os.environ[\"RANK\"] = \"0\"\n", | ||
"os.environ[\"WORLD_SIZE\"] = \"1\"\n", | ||
"os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n", | ||
"os.environ[\"MASTER_PORT\"] = \"29500\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "b54b4e41-2379-4163-adc6-465e9b7fc57a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def main():\n", | ||
" \n", | ||
" prompts: List[Dict[List[str], AgentType]] = [\n", | ||
" {\n", | ||
" \"prompt\": [\"<Sample user prompt>\"],\n", | ||
" \"agent_type\": AgentType.USER\n", | ||
" },\n", | ||
" {\n", | ||
" \"prompt\": [\"<Sample user prompt>\", \"<Sample agent response>\"],\n", | ||
" \"agent_type\": AgentType.AGENT\n", | ||
" },\n", | ||
" {\n", | ||
" \"prompt\": [\"<Sample user prompt>\", \n", | ||
" \"<Sample agent response>\", \n", | ||
" \"<Sample user reply>\", \n", | ||
" \"<Sample agent response>\"],\n", | ||
" \"agent_type\": AgentType.AGENT\n", | ||
" }\n", | ||
" ]\n", | ||
"\n", | ||
" \n", | ||
" results = llm_eval(prompts)\n", | ||
" \n", | ||
" for i, prompt in enumerate(prompts):\n", | ||
" print(prompt['prompt'])\n", | ||
" print(f\"> {results[0][i]}\")\n", | ||
" print(\"\\n==================================\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "2345c875-d610-4b8a-b73c-464581575e93", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "922f8a8be5bf4ce99b114e83861b08a4", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/json": { | ||
"ascii": false, | ||
"bar_format": null, | ||
"colour": null, | ||
"elapsed": 0.004128694534301758, | ||
"initial": 0, | ||
"n": 0, | ||
"ncols": null, | ||
"nrows": 26, | ||
"postfix": null, | ||
"prefix": "Loading checkpoint shards", | ||
"rate": null, | ||
"total": 3, | ||
"unit": "it", | ||
"unit_divisor": 1000, | ||
"unit_scale": false | ||
}, | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "415fbebfe34f4ad8b3cc30f2a8c97e29", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/conda/envs/llama-recipes/lib/python3.10/site-packages/transformers/generation/utils.py:1518: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\n", | ||
" warnings.warn(\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['<Sample user prompt>']\n", | ||
"> safe\n", | ||
"\n", | ||
"==================================\n", | ||
"\n", | ||
"['<Sample user prompt>', '<Sample agent response>']\n", | ||
"> safe\n", | ||
"\n", | ||
"==================================\n", | ||
"\n", | ||
"['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']\n", | ||
"> safe\n", | ||
"\n", | ||
"==================================\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# used to be able to import this script in another notebook and not run the main function\n", | ||
"if __name__ == '__main__' and '__file__' not in globals():\n", | ||
" from huggingface_hub import login\n", | ||
" login()\n", | ||
" main()\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2a071f81-7e2c-4c87-86f8-abf7dad6ac89", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
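For reference, a minimal sketch of driving the notebook's llm_eval with custom conversations. It assumes the cells above (defining AgentType and llm_eval) have been executed; the placeholder strings are to be replaced with real content:

```python
# Minimal usage sketch; assumes AgentType and llm_eval from the notebook cells above.
prompts = [
    {"prompt": ["<your user prompt>"], "agent_type": AgentType.USER},
    {"prompt": ["<your user prompt>", "<the agent's response>"],
     "agent_type": AgentType.AGENT},
]

# llm_eval returns (results, logprobs); logprobs is None unless logprobs=True
results, logprobs = llm_eval(prompts, load_in_8bit=True, logprobs=True)

for prompt, verdict in zip(prompts, results):
    # Llama Guard answers "safe", or "unsafe" followed by the violated category
    print(prompt["prompt"], "->", verdict.strip())
```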
Review comment: Looks like inference.py has been deleted? Where should this command be run from? Should we provide a `cd` instruction before the `python` command?

Reply: This is the inference.py script I'm referencing here, not the one in the llama_guard directory.