Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for the ArcticForCausalLM. #7020

Open
wants to merge 15 commits into
base: master
Choose a base branch
from

Conversation

fairydreaming
Copy link
Contributor

Fixes #6877

Contains the following changes:

  • increases maximum number of experts from 60 to 128
  • adds new tensor type FFN_NORM_EXP (for a normalization block before MoE that runs in parallel to the attention + FFN, see Add support to ArcticForCausalLM #6877 for details)
  • introduces architecture-specific block mappings in gguf-py (details in Add support to ArcticForCausalLM #6877)
  • adds new model type MODEL_10B_128x3_66B
  • adds new ARCTIC architecture and a general support for models based on this architecture

Model files for testing: https://huggingface.co/sszymczyk/snowflake-arctic-instruct-GGUF

Copy link
Contributor

github-actions bot commented May 1, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 555 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8412.86ms p(95)=20062.27ms fails=, finish reason: stop=485 truncated=70
  • Prompt processing (pp): avg=90.32tk/s p(95)=384.11tk/s
  • Token generation (tg): avg=33.28tk/s p(95)=45.5tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=snowflake-arctic-clean commit=9acc3ecf34ca7ce579965876e57ed62201c2b95a

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 555 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1715277769 --> 1715278391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 344.7, 344.7, 344.7, 344.7, 344.7, 397.52, 397.52, 397.52, 397.52, 397.52, 447.53, 447.53, 447.53, 447.53, 447.53, 482.12, 482.12, 482.12, 482.12, 482.12, 546.57, 546.57, 546.57, 546.57, 546.57, 552.2, 552.2, 552.2, 552.2, 552.2, 559.33, 559.33, 559.33, 559.33, 559.33, 594.84, 594.84, 594.84, 594.84, 594.84, 600.17, 600.17, 600.17, 600.17, 600.17, 606.2, 606.2, 606.2, 606.2, 606.2, 615.33, 615.33, 615.33, 615.33, 615.33, 640.53, 640.53, 640.53, 640.53, 640.53, 668.4, 668.4, 668.4, 668.4, 668.4, 695.31, 695.31, 695.31, 695.31, 695.31, 640.08, 640.08, 640.08, 640.08, 640.08, 646.09, 646.09, 646.09, 646.09, 646.09, 648.16, 648.16, 648.16, 648.16, 648.16, 634.58, 634.58, 634.58, 634.58, 634.58, 641.9, 641.9, 641.9, 641.9, 641.9, 651.19, 651.19, 651.19, 651.19, 651.19, 657.29, 657.29, 657.29, 657.29, 657.29, 671.69, 671.69, 671.69, 671.69, 671.69, 680.7, 680.7, 680.7, 680.7, 680.7, 683.89, 683.89, 683.89, 683.89, 683.89, 687.85, 687.85, 687.85, 687.85, 687.85, 706.46, 706.46, 706.46, 706.46, 706.46, 705.43, 705.43, 705.43, 705.43, 705.43, 708.3, 708.3, 708.3, 708.3, 708.3, 710.1, 710.1, 710.1, 710.1, 710.1, 718.0, 718.0, 718.0, 718.0, 718.0, 719.6, 719.6, 719.6, 719.6, 719.6, 725.29, 725.29, 725.29, 725.29, 725.29, 731.51, 731.51, 731.51, 731.51, 731.51, 741.55, 741.55, 741.55, 741.55, 741.55, 757.71, 757.71, 757.71, 757.71, 757.71, 762.9, 762.9, 762.9, 762.9, 762.9, 762.09, 762.09, 762.09, 762.09, 762.09, 762.49, 762.49, 762.49, 762.49, 762.49, 767.54, 767.54, 767.54, 767.54, 767.54, 768.51, 768.51, 768.51, 768.51, 768.51, 767.49, 767.49, 767.49, 767.49, 767.49, 760.21, 760.21, 760.21, 760.21, 760.21, 735.91, 735.91, 735.91, 735.91, 735.91, 735.19, 735.19, 735.19, 735.19, 735.19, 738.5, 738.5, 738.5, 738.5, 738.5, 738.44, 738.44, 738.44, 738.44, 738.44, 739.1, 739.1, 739.1, 739.1, 739.1, 746.62, 746.62, 746.62, 746.62, 746.62, 746.71, 746.71, 746.71, 746.71, 746.71, 751.34, 751.34, 751.34, 751.34, 751.34, 755.89, 755.89, 755.89, 755.89, 755.89, 755.84, 755.84, 755.84, 755.84, 755.84, 759.2, 759.2, 759.2, 759.2, 759.2, 760.12, 760.12, 760.12, 760.12, 760.12, 760.69, 760.69, 760.69, 760.69, 760.69, 762.09, 762.09, 762.09, 762.09, 762.09, 762.78, 762.78, 762.78, 762.78, 762.78, 765.41, 765.41, 765.41, 765.41, 765.41, 768.64, 768.64, 768.64, 768.64, 768.64, 769.66, 769.66, 769.66, 769.66, 769.66, 769.53, 769.53]
                    
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 555 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1715277769 --> 1715278391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 39.46, 39.46, 39.46, 39.46, 39.46, 42.91, 42.91, 42.91, 42.91, 42.91, 35.18, 35.18, 35.18, 35.18, 35.18, 34.8, 34.8, 34.8, 34.8, 34.8, 34.59, 34.59, 34.59, 34.59, 34.59, 34.68, 34.68, 34.68, 34.68, 34.68, 35.67, 35.67, 35.67, 35.67, 35.67, 36.28, 36.28, 36.28, 36.28, 36.28, 36.3, 36.3, 36.3, 36.3, 36.3, 36.28, 36.28, 36.28, 36.28, 36.28, 35.63, 35.63, 35.63, 35.63, 35.63, 35.36, 35.36, 35.36, 35.36, 35.36, 34.94, 34.94, 34.94, 34.94, 34.94, 34.78, 34.78, 34.78, 34.78, 34.78, 33.82, 33.82, 33.82, 33.82, 33.82, 33.58, 33.58, 33.58, 33.58, 33.58, 33.96, 33.96, 33.96, 33.96, 33.96, 33.85, 33.85, 33.85, 33.85, 33.85, 33.75, 33.75, 33.75, 33.75, 33.75, 33.59, 33.59, 33.59, 33.59, 33.59, 33.63, 33.63, 33.63, 33.63, 33.63, 33.71, 33.71, 33.71, 33.71, 33.71, 33.39, 33.39, 33.39, 33.39, 33.39, 33.23, 33.23, 33.23, 33.23, 33.23, 33.31, 33.31, 33.31, 33.31, 33.31, 33.33, 33.33, 33.33, 33.33, 33.33, 32.79, 32.79, 32.79, 32.79, 32.79, 32.7, 32.7, 32.7, 32.7, 32.7, 32.94, 32.94, 32.94, 32.94, 32.94, 33.04, 33.04, 33.04, 33.04, 33.04, 33.18, 33.18, 33.18, 33.18, 33.18, 33.35, 33.35, 33.35, 33.35, 33.35, 33.43, 33.43, 33.43, 33.43, 33.43, 33.33, 33.33, 33.33, 33.33, 33.33, 33.08, 33.08, 33.08, 33.08, 33.08, 32.77, 32.77, 32.77, 32.77, 32.77, 32.7, 32.7, 32.7, 32.7, 32.7, 32.88, 32.88, 32.88, 32.88, 32.88, 32.94, 32.94, 32.94, 32.94, 32.94, 33.06, 33.06, 33.06, 33.06, 33.06, 33.16, 33.16, 33.16, 33.16, 33.16, 32.85, 32.85, 32.85, 32.85, 32.85, 32.46, 32.46, 32.46, 32.46, 32.46, 32.34, 32.34, 32.34, 32.34, 32.34, 31.48, 31.48, 31.48, 31.48, 31.48, 31.34, 31.34, 31.34, 31.34, 31.34, 31.36, 31.36, 31.36, 31.36, 31.36, 31.43, 31.43, 31.43, 31.43, 31.43, 31.5, 31.5, 31.5, 31.5, 31.5, 31.63, 31.63, 31.63, 31.63, 31.63, 31.61, 31.61, 31.61, 31.61, 31.61, 31.38, 31.38, 31.38, 31.38, 31.38, 31.32, 31.32, 31.32, 31.32, 31.32, 31.4, 31.4, 31.4, 31.4, 31.4, 31.53, 31.53, 31.53, 31.53, 31.53, 31.64, 31.64, 31.64, 31.64, 31.64, 31.72, 31.72, 31.72, 31.72, 31.72, 31.78, 31.78, 31.78, 31.78, 31.78, 31.83, 31.83, 31.83, 31.83, 31.83, 31.83, 31.83, 31.83, 31.83, 31.83, 31.84, 31.84]
                    

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 555 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1715277769 --> 1715278391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.12, 0.12, 0.12, 0.12, 0.45, 0.45, 0.45, 0.45, 0.45, 0.21, 0.21, 0.21, 0.21, 0.21, 0.19, 0.19, 0.19, 0.19, 0.19, 0.16, 0.16, 0.16, 0.16, 0.16, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16, 0.16, 0.16, 0.17, 0.17, 0.17, 0.17, 0.17, 0.22, 0.22, 0.22, 0.22, 0.22, 0.12, 0.12, 0.12, 0.12, 0.12, 0.15, 0.15, 0.15, 0.15, 0.15, 0.21, 0.21, 0.21, 0.21, 0.21, 0.34, 0.34, 0.34, 0.34, 0.34, 0.18, 0.18, 0.18, 0.18, 0.18, 0.13, 0.13, 0.13, 0.13, 0.13, 0.1, 0.1, 0.1, 0.1, 0.1, 0.24, 0.24, 0.24, 0.24, 0.24, 0.23, 0.23, 0.23, 0.23, 0.23, 0.21, 0.21, 0.21, 0.21, 0.21, 0.16, 0.16, 0.16, 0.16, 0.16, 0.26, 0.26, 0.26, 0.26, 0.26, 0.16, 0.16, 0.16, 0.16, 0.16, 0.13, 0.13, 0.13, 0.13, 0.13, 0.17, 0.17, 0.17, 0.17, 0.17, 0.34, 0.34, 0.34, 0.34, 0.34, 0.24, 0.24, 0.24, 0.24, 0.24, 0.1, 0.1, 0.1, 0.1, 0.1, 0.14, 0.14, 0.14, 0.14, 0.14, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.15, 0.15, 0.15, 0.15, 0.15, 0.16, 0.16, 0.16, 0.16, 0.16, 0.11, 0.11, 0.11, 0.11, 0.11, 0.22, 0.22, 0.22, 0.22, 0.22, 0.19, 0.19, 0.19, 0.19, 0.19, 0.15, 0.15, 0.15, 0.15, 0.15, 0.07, 0.07, 0.07, 0.07, 0.07, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.31, 0.31, 0.31, 0.31, 0.31, 0.58, 0.58, 0.58, 0.58, 0.58, 0.48, 0.48, 0.48, 0.48, 0.48, 0.44, 0.44, 0.44, 0.44, 0.44, 0.21, 0.21, 0.21, 0.21, 0.21, 0.2, 0.2, 0.2, 0.2, 0.2, 0.13, 0.13, 0.13, 0.13, 0.13, 0.23, 0.23, 0.23, 0.23, 0.23, 0.12, 0.12, 0.12, 0.12, 0.12, 0.19, 0.19, 0.19, 0.19, 0.19, 0.26, 0.26, 0.26, 0.26, 0.26, 0.2, 0.2, 0.2, 0.2, 0.2, 0.22, 0.22, 0.22, 0.22, 0.22, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12, 0.12, 0.12, 0.12, 0.12, 0.11, 0.11, 0.11, 0.11, 0.11, 0.17, 0.17, 0.17, 0.17, 0.17, 0.12, 0.12, 0.12, 0.12, 0.12, 0.19, 0.19, 0.19, 0.19, 0.19, 0.17, 0.17, 0.17, 0.17, 0.17, 0.23, 0.23]
                    
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 555 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1715277769 --> 1715278391
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 2.0, 2.0, 2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 1.0, 1.0]
                    

@sorasoras
Copy link

It's possible to only offload dense part of the model onto GPU

model_arch = gguf.MODEL_ARCH.ARCTIC

def set_vocab(self):
self._set_vocab_llama_hf()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re: #6877 (comment), this should be:

Suggested change
self._set_vocab_llama_hf()
try:
self. _set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_llama_hf()

The assertion exists because LlamaHfVocab was primarily written to convert HF "fast" tokenizers with a tokenizer.json. Since before it existed, "slow" sentencepiece tokenizers with a tokenizer.model have (almost?) always been converted using SentencePieceProcessor, which doesn't depend on HF transformers and directly preserves the token types and scores.

If you want to start converting slow tokenizers using HfVocab as well, I won't stop you, but in order to be consistent you'd have to remove all references to SentencePieceProcessor in the convert scripts, and make HF transformers a hard requirement for converting models with a Llama vocab. Otherwise, we'd be making an exception for this model for no clear reason.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reason is that the official tokenizer.model file for snowflake-arctic-instruct contains wrong BOS and EOS tokens as confirmed in: https://huggingface.co/Snowflake/snowflake-arctic-instruct/discussions/12
That's why I used llama_hf vocab that reads tokens from json files instead. If there is a better solution for this I'm fully open to any suggestions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cebtenzzre What if I implement ArcticModel::set_vocab() myself like XverseForCausalLM did, is that acceptable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cebtenzzre I now load vocabulary with SentencePieceProcessor as you suggested and apply necessary token modifications based on added_tokens_decoder field from tokenizer_config.json.

@mofosyne mofosyne added enhancement New feature or request review complexity : medium Generally require more time to grok but manageable by beginner to medium expertise level labels May 9, 2024
convert-hf-to-gguf.py Outdated Show resolved Hide resolved
convert-hf-to-gguf.py Outdated Show resolved Hide resolved
Comment on lines 454 to 455
if arch in self.arch_block_mappings_cfg:
block_mappings = self.arch_block_mappings_cfg[arch]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means architecture-specific block mappings can't partially override the common mappings (they have to totally re-define everything)?

Maybe this is fixable by adding the common mappings first to self.mapping, then the architecture-specific mappings?

So maybe using the union operator for dicts would be appropriate here

if arch in self.arch_block_mappings_cfg:
    block_mappings = self.block_mappings_cfg | self.arch_block_mappings_cfg[arch]

But that's only supported since Python 3.9, and gguf-py targets python = ">=3.8"

In this case using {**x, **y} instead of x | y would be more compatible for older-than-3.9 versions of Python, and would allow making a new dict with the content of x augmented/overridden by y. But the new syntax is clearer in my opinion.

After that, the architecture-specific mapping of MODEL_ARCH.ARCTIC should be simpler (since they won't need to include duplicates of the common mappings).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the idea is to keep only "conflicting" block mappings in architecture-specific mappings and "non-conflicting" mappings in general mappings? I think using dict.update() is a better idea then. Mappings for ARCTIC arch would be shortened to:

    # architecture-specific block mappings
    arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
        MODEL_ARCH.ARCTIC: {
            MODEL_TENSOR.FFN_NORM: (
                "model.layers.{bid}.residual_layernorm",
            ),
            MODEL_TENSOR.FFN_NORM_EXP: (
                "model.layers.{bid}.post_attention_layernorm",
            ),
        },
    }

while in the TensorNameMap init we would only have to add:

        if arch in self.arch_block_mappings_cfg:
            self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])

What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the idea is to keep only "conflicting" block mappings in architecture-specific mappings and "non-conflicting" mappings in general mappings?

Yes, exactly.

What do you think?

I think using dict.update() would be good. My proposed approach would have made a copy of the dict, but you're right, updating in-place would work too and would be better, since the original block_mappings_cfg isn't used later on (I think?).

I agree with using dict.update() for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, done

Copy link
Collaborator

@compilade compilade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not test this (the model is quite big), but the code looks good to me. Nice work @fairydreaming!

@@ -181,6 +182,7 @@ class MODEL_TENSOR(IntEnum):
SSM_A = auto()
SSM_D = auto()
SSM_OUT = auto()
FFN_NORM_EXP = auto()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the actual numbers associated to the enum values of MODEL_TENSOR don't really matter (their names (from TENSOR_NAMES) are used instead in GGUF), maybe FFN_NORM_EXP could be placed right before FFN_GATE_EXP, a bit like FFN_NORM is right before FFN_GATE, for consistency.

If this is changed, it should also be placed similarly in TENSOR_NAMES and MODEL_TENSORS[MODEL.ARCTIC] in gguf-py/gguf/constants.py as well as in the llm_tensor enum, the LLM_TENSOR_NAMES mapping, and the llama_layer struct (and maybe the LLM_ARCH_ARCTIC case in llm_load_tensors?) in llama.cpp.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the order as requested, but in llama_layer struct the order is different, so I didn't touch it. In llm_load_tensors I think it was already in the requested order.

@fairydreaming
Copy link
Contributor Author

I noticed that the arctic model doesn't use bias tensors, so I removed usage of bias tensors in the LLM_ARCH_ARCTIC-related code (they were all nulls anyway).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request review complexity : medium Generally require more time to grok but manageable by beginner to medium expertise level
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support to ArcticForCausalLM
7 participants