Skip to content

Commit

Permalink
Add instructions to convert Hugging Face models to PyTorch (#3523)
Browse files Browse the repository at this point in the history
Summary:
As titled. It's pretty common that users download the LLM models in safetensor format. Add instructions and example script to convert them to PyTorch format so that export_llama script can accept. It leverages the utils from TorchTune.

Thanks l3utterfly and kartikayk for the discussions and suggestions!

More context in #3303

Pull Request resolved: #3523

Reviewed By: mergennachin

Differential Revision: D57026658

Pulled By: iseeyuan

fbshipit-source-id: 11badf709920ff945cdfdd2b244c52c750943412
  • Loading branch information
iseeyuan authored and facebook-github-bot committed May 8, 2024
1 parent 818b178 commit 2c1e283
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions examples/models/llama2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,33 @@ You can export and run the original Llama3 8B model.

Due to the larger vocabulary size of Llama3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` to further reduce the model size.

### Option D: Download models from Hugging Face and convert from safetensor format to state dict

You can also download above models from [Hugging Face](https://huggingface.co/). Since ExecuTorch starts from a PyTorch model, a script like below can be used to convert the Hugging Face safetensors format to PyTorch's state dict. It leverages the utils provided by [TorchTune](https://github.com/pytorch/torchtune).

```Python
from torchtune.utils import FullModelHFCheckpointer
from torchtune.models import convert_weights
import torch

# Convert from safetensors to TorchTune. Suppose the model has been downloaded from Hugging Face
checkpointer = FullModelHFCheckpointer(
checkpoint_dir='/home/.cache/huggingface/hub/models/snapshots/hash-number',
checkpoint_files=['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors'],
output_dir='/the/destination/dir' ,
model_type='LLAMA3' # or other types that TorchTune supports
)

print("loading checkpoint")
sd = checkpointer.load_checkpoint()

# Convert from TorchTune to Meta (PyTorch native)
sd = convert_weights.tune_to_meta(sd['model'])

print("saving checkpoint")
torch.save(sd, "/the/destination/dir/checkpoint.pth")
```

## (Optional) Finetuning

If you want to finetune your model based on a specific dataset, PyTorch provides [TorchTune](https://github.com/pytorch/torchtune) - a native-Pytorch library for easily authoring, fine-tuning and experimenting with LLMs.
Expand Down

0 comments on commit 2c1e283

Please sign in to comment.