
export llama failing with errors for runtime errors #2907

Closed
chauhang opened this issue Apr 7, 2024 · 11 comments
Labels
bug (Something isn't working) · good first issue (Good for newcomers) · high priority · triage review (Items require a triage review)

Comments

chauhang commented Apr 7, 2024

export_llama is failing with runtime errors for both the llama and stories models.

Error for the llama model: Could not import fairseq2 modules. ... RuntimeError: Trying to create tensor with negative dimension -1: [-1, 4096]
Error for the stories model: Could not import fairseq2 modules. ... RuntimeError: mmap can only be used with files saved with torch.save(./stories/stories110M.pt, _use_new_zipfile_serialization=True), please torch.save your checkpoint with this option in order to use mmap.

Steps to run for the Llama model

1. Follow the steps from the LLM manual.
2. Download the Meta versions of the llama weights.
3. Run the export_llama script:

python -m examples.models.llama2.export_llama --checkpoint $MODEL_PATH/consolidated.00.pth --params $MODEL_PATH/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32

Error details for llama2 model export

Could not import fairseq2 modules.
INFO:root:Loading model with checkpoint=/Users/gchauhan/dev/llama-fast/checkpoints/meta-llama/Llama-2-7b/consolidated.00.pth, params=/Users/gchauhan/dev/llama-fast/checkpoints/meta-llama/Llama-2-7b/params.json, use_kv_cache=True, weight_type=WeightType.LLAMA
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama.py", line 30, in <module>
    main()  # pragma: no cover
    ^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama.py", line 26, in main
    export_llama(modelname, args)
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama_lib.py", line 408, in export_llama
    return _export_llama(modelname, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama_lib.py", line 529, in _export_llama
    builder_exported_to_edge = _prepare_for_llama_export(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama_lib.py", line 486, in _prepare_for_llama_export
    load_llama_model(
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/builder.py", line 83, in load_llama_model
    model, example_inputs, _ = EagerModelFactory.create_model(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/model_factory.py", line 44, in create_model
    model = model_class(**kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/model.py", line 139, in __init__
    self.model_ = Transformer(model_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/et/lib/python3.11/site-packages/executorch/examples/models/llama2/llama_transformer.py", line 418, in __init__
    self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/et/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 143, in __init__
    self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/et/lib/python3.11/site-packages/torch/utils/_device.py", line 78, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Trying to create tensor with negative dimension -1: [-1, 4096]
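
The -1 in the tensor shape is params.vocab_size, read from params.json and passed unmodified to nn.Embedding (see the last frames of the traceback). A minimal repro, independent of the export script:

import torch.nn as nn

# vocab_size=-1 read from params.json flows straight into the embedding constructor
nn.Embedding(-1, 4096)
# RuntimeError: Trying to create tensor with negative dimension -1: [-1, 4096]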

Steps for the Stories model

1. Download the model from the links specified.
2. Run:
python -m examples.models.llama2.export_llama -c ./stories/stories110M.pt -p ./stories/params.json

Error details for Stories model export

Could not import fairseq2 modules.
INFO:root:Loading model with checkpoint=./stories/stories110M.pt, params=./stories/params.json, use_kv_cache=False, weight_type=WeightType.LLAMA
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama.py", line 30, in <module>
    main()  # pragma: no cover
    ^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama.py", line 26, in main
    export_llama(modelname, args)
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama_lib.py", line 408, in export_llama
    return _export_llama(modelname, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama_lib.py", line 529, in _export_llama
    builder_exported_to_edge = _prepare_for_llama_export(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/export_llama_lib.py", line 486, in _prepare_for_llama_export
    load_llama_model(
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/builder.py", line 83, in load_llama_model
    model, example_inputs, _ = EagerModelFactory.create_model(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/model_factory.py", line 44, in create_model
    model = model_class(**kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gchauhan/dev/executorch/examples/models/llama2/model.py", line 75, in __init__
    checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/et/lib/python3.11/site-packages/torch/serialization.py", line 1032, in load
    raise RuntimeError("mmap can only be used with files saved with "
RuntimeError: mmap can only be used with files saved with `torch.save(./stories/stories110M.pt, _use_new_zipfile_serialization=True)`, please torch.save your checkpoint with this option in order to use mmap.
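
A one-time workaround (a sketch, using the paths from the command above) is to re-save the checkpoint in the zipfile format the error message asks for, after which torch.load(..., mmap=True) in model.py can map it:

import torch

# Load the legacy-format checkpoint once without mmap...
ckpt = torch.load("./stories/stories110M.pt", map_location="cpu")
# ...and re-save it with the new zipfile serialization so mmap=True works.
torch.save(ckpt, "./stories/stories110M.pt", _use_new_zipfile_serialization=True)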

Environment

python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.4.0.dev20240324
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.0
Libc version: N/A

Python version: 3.11.8 (main, Feb 26 2024, 15:36:12) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] executorch==0.1.0
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240324
[pip3] torchao==0.1
[pip3] torchaudio==2.2.0.dev20240324
[pip3] torchsr==1.0.4
[pip3] torchvision==0.19.0.dev20240324
[conda] executorch                0.1.0                    pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] torch                     2.4.0.dev20240324          pypi_0    pypi
[conda] torchao                   0.1                      pypi_0    pypi
[conda] torchaudio                2.2.0.dev20240324          pypi_0    pypi
[conda] torchsr                   1.0.4                    pypi_0    pypi
[conda] torchvision               0.19.0.dev20240324          pypi_0    pypi
chauhang changed the title from "export llama failing with errors for runtime error creating tensor with -ve dimension" to "export llama failing with errors for runtime errors" on Apr 7, 2024
iseeyuan added the bug (Something isn't working) and good first issue (Good for newcomers) labels on Apr 7, 2024
iseeyuan (Contributor) commented Apr 7, 2024

Thanks @chauhang for reporting this issue! Could you confirm the vocab_size in the llama2 7B model's params.json?

chauhang (Author) commented Apr 7, 2024

@iseeyuan For the meta-llama/Llama-2-7b model the params.json on HF is:

{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}

Also checked the 13b/70b base models and the chat models; all of them have vocab_size=-1 in their params.json.
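
Until the loader handles this, a quick local patch (a sketch; the path is assumed from the export command above, and 32000 is the Llama 2 tokenizer vocabulary size) is to rewrite the field before exporting:

import json

params_path = "params.json"  # e.g. $MODEL_PATH/params.json
with open(params_path) as f:
    params = json.load(f)
params["vocab_size"] = 32000  # Llama 2 tokenizer vocab size
with open(params_path, "w") as f:
    json.dump(params, f)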

iseeyuan (Contributor) commented Apr 7, 2024

@chauhang, it's a bug in our code. We should provide an option so that export_llama works out of the box given a downloaded folder, whether from the official Llama website or from Hugging Face.
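
For example (a minimal sketch of one possible fix, not the actual patch), the loader could fall back to the embedding shape stored in the checkpoint whenever params.json carries -1; tok_embeddings.weight is the key used in Meta's consolidated checkpoints:

# In the model-loading path, before constructing Transformer(model_args):
vocab_size = params.get("vocab_size", -1)
if vocab_size == -1:
    # Infer the vocabulary size from the token embedding table.
    vocab_size = checkpoint["tok_embeddings.weight"].shape[0]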

pytorch-bot added the triage review (Items require a triage review) label on Apr 7, 2024
chauhang (Author) commented Apr 7, 2024

Also tested llama2-7b after updating vocab_size to 32000; now getting the error AttributeError: '_OpNamespace' 'llama' object has no attribute 'sdpa_with_kv_cache'

Full error logs here

chauhang (Author) commented Apr 7, 2024

After removing the sdpa param I was able to proceed up to running the model on the computer. On running the model I get the error The tokenizer vocab size 84545034 is larger than the model vocab size 32000. .... In function generate(), assert failed (num_prompt_tokens >= 1): Expected at least 1 prompt token

Full logs here
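
One way to narrow this down (a sketch; assumes the original SentencePiece tokenizer.model is at hand) is to check what vocab size the source tokenizer reports, since 84545034 suggests tokenizer.bin is being built from, or parsed as, the wrong file:

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
print(sp.vocab_size())  # expected: 32000 for Llama 2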

chauhang (Author) commented Apr 7, 2024

Got the llama2-7b model working on macOS and Android.

  • Local model runtime on macOS: Model load time: 10.39s, Time to first generated token: 0.739s, Generated token rate: 0.3089 toks/sec
  • Android Samsung Galaxy S22 runtime: Model load time: 12.05s, Time to first generated token: 8.448s, Generated token rate: 0.0777 toks/sec

Updated list of issues:

LLama2 model

  • vocab_size in params.json from the HF downloads is -1; it must be manually changed to 32000 to proceed. Update the script/readme steps.
  • Export with SDPA failed with errors for AttributeError: '_OpNamespace' 'llama' object has no attribute 'sdpa_with_kv_cache'
  • Update readme to add steps for generating tokenizer.bin for llama2 model
  • Optimize local model runtime on macOS (Model load time: 10.39s, Time to first generated token: 0.739s, Generated token rate: 0.3089 toks/sec)
  • Android Emulator -- pte file transfer hangs / crashes the emulator for the 4GB model file
  • Add steps for running on iOS

Stories Model

  • Fix error RuntimeError: mmap can only be used with files saved with torch.save(./stories/stories110M.pt, _use_new_zipfile_serialization=True)

iseeyuan (Contributor) commented Apr 8, 2024

@chauhang, the second issue, Export with SDPA failed with errors (https://gist.github.com/chauhang/ca75857c6a152df65b79302fefa1fe2c?permalink_comment_id=5015390#gistcomment-5015390) for AttributeError: '_OpNamespace' 'llama' object has no attribute 'sdpa_with_kv_cache', should have been fixed in the main branch over the weekend. Could you pull the updated version and give it another try?
The performance afterwards may also be affected by using sdpa_with_kv_cache.

kimishpatel (Contributor) commented:

> Also tested llama2-7b after updating vocab_size to 32000; now getting the error AttributeError: '_OpNamespace' 'llama' object has no attribute 'sdpa_with_kv_cache'

Might be related to @larryliu0820's diff that got reverted recently

kimishpatel (Contributor) commented:

> updated

we should just cherry-pick that, right?

mergennachin added two commits to mergennachin/executorch-1 that referenced this issue Apr 8, 2024

Summary: Fixing issues we've seen in pytorch#2907 and pytorch#2805

Differential Revision: D55893925
mergennachin (Contributor) commented:

Thanks @chauhang. Some fixes: #2926

facebook-github-bot pushed a commit that referenced this issue Apr 8, 2024
Summary:
Pull Request resolved: #2926

Fixing issues we've seen in #2907 and #2805

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: iseeyuan, cccclai

Differential Revision: D55893925

fbshipit-source-id: c6e0264d868cb487faf02f95ff1bd223cbcc97ac
pytorchbot pushed the same commit on Apr 9, 2024, and mergennachin added it again on Apr 9, 2024 (cherry picked from commit 6db9d72), both with the same summary as above.
mergennachin (Contributor) commented:

Things are fixed now.
