Add support to ArcticForCausalLM #6877

Open
maziyarpanahi opened this issue Apr 24, 2024 · 35 comments · May be fixed by #7020
Labels
enhancement New feature or request

Comments

@maziyarpanahi

maziyarpanahi commented Apr 24, 2024

First open LLM from @SnowflakeDB! Arctic is a 480B Dense-MoE with a 10B dense transformer model and a 128x3.66B MoE MLP, designed specifically for enterprise AI. 🤔

TL;DR:
🧠 480B parameters with 17B active during generation
πŸ‘¨β€πŸ« 128 experts with 2 active in generation
2️⃣ Instruct & Base versions released
πŸ™οΈ Focused on Enterprise task (Code, SQL, Reasoning, Following)
πŸ”“ Released under Apache 2.0
πŸ—» in fp16 ~900GB Memory & in int4 ~240GB
πŸ€— Available on @huggingface

πŸ‹πŸ» Trained with DeepSpeed-MoE

Blog: https://snowflake.com/blog/arctic-open-efficient-foundation-language-models-snowflake/

Models: https://huggingface.co/Snowflake/snowflake-arctic-instruct
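Rough arithmetic behind those numbers, using only the figures above (a quick sketch; everything is approximate):

    # Back-of-the-envelope check, in billions of parameters
    dense = 10.0                      # dense transformer
    expert = 3.66                     # one MoE MLP expert
    n_experts, n_active = 128, 2
    total = dense + n_experts * expert     # ~478B  -> "480B parameters"
    active = dense + n_active * expert     # ~17.3B -> "17B active during generation"
    print(f"total ~{total:.0f}B, active ~{active:.1f}B")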

@maziyarpanahi maziyarpanahi added the enhancement New feature or request label Apr 24, 2024
@fairydreaming
Contributor

I looked briefly into possible support of this model in convert-hf-to-gguf.py and found that there are some tensors causing "Can not map tensor" errors: model.layers.0.residual_layernorm.weight and model.layers.0.residual_mlp.w1.weight. I checked the arctic branch of the Snowflake transformers GitHub repository and found that they can likely be ignored, since they are created only when:

        if self.parallel_attn_mlp_res:
            self.residual_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.residual_mlp =  ArcticMLP(config,

and the default value of parallel_attn_mlp_res is False (see src/transformers/models/arctic/configuration_arctic.py file).

It's possible that only the following changes are needed to make convert-hf-to-gguf.py work:

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 5763b666..d02f7aea 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1303,7 +1303,7 @@ class StableLMModel(Model):
             self.gguf_writer.add_tensor(new_name, data)
 
 
-@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "ArcticForCausalLM")
 class LlamaModel(Model):
     model_arch = gguf.MODEL_ARCH.LLAMA
 
@@ -1345,7 +1345,9 @@ class LlamaModel(Model):
         experts = dict()
         for name, data_torch in self.get_tensors():
             # we don't need these
-            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".residual_layernorm.weight")):
+                continue
+            if re.match(r".*\.residual_mlp\.w\d+\.weight", name):
                 continue
 
             old_dtype = data_torch.dtype

So a brave soul may try to apply this patch and make the quants, but someone with more expertise should verify this. I ran it on an Amazon EC2 instance and it worked up to writing the GGUF file. Good luck!

@christianazinn
Contributor

I'm currently downloading the model weights (Instruct) and will see if conversion works on my local machine (I don't have access to cloud compute), but it'll take a while. If someone with cloud computing capabilities could give it a go, that would be way faster, but I'll try in the meantime.

For what it's worth, it uses the ChatML prompt template and LLaMA 2 tokenizer. I don't know anything else about the architecture yet.

@fairydreaming
Contributor

Based on the image in this article it's a hybrid combining a dense transformer with a residual MoE component. I guess a new LLM architecture will be necessary to support this in llama.cpp.

@maziyarpanahi
Author

The team behind it (very capable people!) is willing to help make this model more accessible. I have seen their PRs in DeepSpeed to introduce new FP [6,8,12]:

If needed, we can reach out to them for some insight into how to go about this better?

@sorasoras

The team behind it (very capable people!) is willing to help make this model more accessible. I have seen their PRs in DeepSpeed to introduce new FP [6,8,12]:

If needed, we can reach out to them for some insight into how to go about this better?

I was thinking of loading only the dense part onto the GPU and leaving the MoE part in CPU RAM. That way a lot of people could run this at a decent speed.

@maziyarpanahi
Author

That's actually not a bad idea! Would it be possible to make this more dynamic for all MoE models, allowing users to select whether the experts are offloaded to the GPU or kept on the CPU?

@christianazinn
Contributor

Not an expert (haha) but I don't know if there's any way to determine which layers correspond to which experts.

In the first place, the way MoE inference works shouldn't really allow this to be possible - the gate network sends different tokens to different experts, so you'd need all those experts to be on GPU for any increase in inference speed. I don't know how it works with the additional dense part in this particular architecture, but I assume it would get bottlenecked by the experts; I would be happy to be wrong, though.
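To illustrate why this is hard to exploit, here is a generic top-k MoE routing sketch (not Arctic's actual code; the softmax-over-the-selected-experts detail is just the usual Mixtral-style assumption):

    import numpy as np

    def route_tokens(router_logits: np.ndarray, k: int = 2):
        """router_logits: (n_tokens, n_experts) gate scores.
        Returns per-token indices of the k selected experts and their mixing weights."""
        top_k = np.argsort(router_logits, axis=-1)[:, -k:]          # expert ids chosen per token
        scores = np.take_along_axis(router_logits, top_k, axis=-1)
        weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        return top_k, weights

    # A batch of tokens generally scatters across many of the 128 experts, so keeping
    # only a subset of experts on the GPU rarely avoids touching CPU-resident weights.
    experts, weights = route_tokens(np.random.randn(8, 128), k=2)
    print(experts)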

@cpumaxx
Contributor

cpumaxx commented Apr 25, 2024

I was able to convert and quant this model to Q8 with convert.py --skip-unknown, but needed to increase LLAMA_MAX_EXPERTS from 60 to at least 128 in llama.cpp to get it to load the model.
The output, however, was only partially coherent.
I have attached the output from loading with current up-to-date llama.cpp for reference
sfa-q8-test-1

@christianazinn
Contributor

I was able to convert and quant this model to Q8 with convert.py --skip-unknown, but needed to increase LLAMA_MAX_EXPERTS from 60 to at least 128 in llama.cpp to get it to load the model. The output, however, was only partially coherent. I have attached the output from loading with current up-to-date llama.cpp for reference sfa-q8-test-1

Can you post some examples of what you mean by "partially coherent," and might conversion with LLAMA_MAX_EXPERTS=60 have affected it? And how large is the Q8_0 file?

@cpumaxx
Contributor

cpumaxx commented Apr 25, 2024

might conversion with LLAMA_MAX_EXPERTS=60 have affected it?
The conversion script didn't need any modifications to run, although whether it worked as intended is unknown.
how large is the Q8_0 file?
Q8 is 472G.
Can you post some examples of what you mean by "partially coherent,"
Response to a simple -p "Hello" is below. "Partially coherent" might have been giving it too much credit:
sfa-q8-response-1

@christianazinn
Contributor

might conversion with LLAMA_MAX_EXPERTS=60 have affected it?
The conversion script didn't need any modifications to run, although whether it worked as intended is unknown.
how large is the Q8_0 file?
Q8 is 472G.
Can you post some examples of what you mean by "partially coherent,"
Response to a simple -p "Hello" is below. "Partially coherent" might have been giving it too much credit:
sfa-q8-response-1

Thanks. --skip-unknown probably did the same as @fairydreaming's modification, and hopefully it defaulted to assuming a Llama-architecture model. It seems like it at least generates tokens properly, which is a plus! :P Conversion to f16 with convert-hf-to-gguf.py and the suggested edits is underway on my end, and I'll see if I get the same results. It's likely an issue with the hybrid architecture.

@fairydreaming
Contributor

I was able to convert and quant this model to Q8 with convert.py --skip-unknown, but needed to increase LLAMA_MAX_EXPERTS from 60 to at least 128 in llama.cpp to get it to load the model. The output, however, was only partially coherent. I have attached the output from loading with current up-to-date llama.cpp for reference sfa-q8-test-1

I wonder why it displays the BOS and EOS tokens as weird characters instead of <|im_start|> and <|im_end|>.
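One quick way to see what the bundled tokenizer.model actually stores for those tokens (a sketch assuming the sentencepiece package is installed and the script is run from the model directory):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
    for name, tid in (("BOS", sp.bos_id()), ("EOS", sp.eos_id())):
        piece = sp.id_to_piece(tid) if tid >= 0 else "<undefined>"
        # if these aren't <|im_start|>/<|im_end|>, the file disagrees with the chat template
        print(name, tid, piece)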

@fairydreaming
Contributor

I think ignoring these tensors likely won't work after all. I just found that in https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/config.json they have parallel_attn_mlp_res set to true, so residual_layernorm and residual_mlp are created and used. Sorry for the confusion.
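For anyone who wants to double-check, the flag is easy to confirm from a local copy of the model (a trivial sketch, assuming config.json is in the current directory):

    import json

    with open("config.json") as f:
        cfg = json.load(f)
    # True in the released snowflake-arctic-instruct config, so the residual
    # layernorm/MLP tensors cannot simply be skipped during conversion.
    print(cfg.get("parallel_attn_mlp_res", False))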

@BarfingLemurs
Contributor

BarfingLemurs commented Apr 26, 2024

@fairydreaming

I got your Q4_K_M and stitched it into a single file. Does it output tokens on your end? I didn't have the required RAM, but was hoping the model would run from disk.

GGML_ASSERT: /home/user/llama.cpp/llama.cpp:3728: hparams.n_expert <= LLAMA_MAX_EXPERTS
Could not attach to process. If your uid matches the uid of the target process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try again as the root user. For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.

@phymbert
Collaborator

phymbert commented Apr 26, 2024

stitched it into a singular file

You do not need to merge the sharded model.

GGML_ASSERT: /home/user/llama.cpp/llama.cpp:3728: hparams.n_expert <= LLAMA_MAX_EXPERTS

Probably we need to remove this limit, or at least increase it to 128.

@fairydreaming
Contributor

fairydreaming commented Apr 26, 2024

@fairydreaming

I got your Q4_K_M and stitched it into a single file. Does it output tokens on your end? I didn't have the required RAM, but was hoping the model would run from disk.

Yes, it loaded and generated some tokens (there is no need to join the files), but the output was nonsensical, like in @cpumaxx's screenshot. I think there's no chance of it working correctly without the layers that we ignored.

If anyone is running the model with the HF transformers library, it would be interesting to see how it behaves with parallel_attn_mlp_res set to false in config.json.

@fairydreaming
Contributor

It's alive!
[screenshot: arctic]

@BarfingLemurs
Contributor

πŸ‘πŸŒšπŸ‘
What did you do?!

@fairydreaming
Contributor

You can try my branch: https://github.com/fairydreaming/llama.cpp/tree/snowflake-arctic
Of course you will have to convert the model to GGUF (again); the ones I posted before won't work.
It's a work in progress, so things will break. Tensor mapping is a mess; I'm not sure yet how to map post_attention_layernorm and residual_layernorm without breaking other models. Also, I think the BOS and EOS tokens are wrong in the tokenizer.model file, which is why we see these strange kanji characters.

@BarfingLemurs
Contributor

You can try my branch: https://github.com/fairydreaming/llama.cpp/tree/snowflake-arctic
Of course you will have to convert the model to GGUF (again); the ones I posted before won't work.

I might. Keep us updated!

@fairydreaming
Contributor

fairydreaming commented Apr 27, 2024

@ggerganov can you offer any advice about how to map the snowflake-arctic-instruct tensors in convert-hf-to-gguf.py? This model has attention + FFN and MoE running in parallel, as shown in this image.

Since there are both FFN and MoE blocks in the model, I decided to map the FFN blocks to the usual FFN_GATE, FFN_DOWN and FFN_UP blocks, and leave the MoE mapping the same as in Mixtral. The problem is mainly in the mapping of the normalization blocks. There are three of them:

  1. Pre-attention normalization, named input_layernorm in the TF model. There is no problem with this one; it's mapped to ATTN_NORM.
  2. Post-attention/pre-FFN normalization, named residual_layernorm in the TF model. Since I mapped the FFN network blocks to FFN_GATE, FFN_DOWN and FFN_UP, following this convention I mapped residual_layernorm to FFN_NORM. No problem so far.
  3. Pre-MoE normalization, named post_attention_layernorm in the TF model. This name is very confusing since it normalizes the layer input, not the attention output; I guess it's a copy/paste leftover in the arctic code from the Mixtral model. There is a problem with this one: post_attention_layernorm is mapped to FFN_NORM by default.

Currently in my branch I temporarily removed post_attention_layernorm from the FFN_NORM mappings and moved it to a newly created MODEL_TENSOR.FFN_NORM_EXP mapping (following the FFN_GATE_EXP, FFN_UP_EXP, FFN_DOWN_EXP convention). This of course breaks other architectures.

Some ideas I had about fixing this situation:

  • Make the mappings architecture-dependent, for example by creating the ArcticTensorNameMap class for this particular architecture. I guess it's a good idea to start doing this anyway, since the global mapping is becoming quite messy.
  • Leave the post_attention_layernorm mapped to FFN_NORM and map residual_layernorm to something else. But I don't see any good candidates for this. Also I don't want to make the mapping confusing.

Would be grateful for any ideas.

@compilade
Collaborator

compilade commented Apr 27, 2024

Would be grateful for any ideas.

@fairydreaming I have an idea for a (temporary?) workaround which would not break other models:

Keep post_attention_layernorm as it was before in the tensor mapping, and in the write_tensors method of class ArcticModel, rename tensor names for which name.endswith("post_attention_layernorm.weight") to something else before calling tensor_map.get_name(name, try_suffixes=(".weight", ".bias")), so that it doesn't get mapped to FFN_NORM.

This way, you could use your newly-added MODEL_TENSOR.FFN_NORM_EXP for it, without breaking existing architectures.
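A minimal sketch of that rename step (assuming it would live in ArcticModel's write_tensors in convert-hf-to-gguf.py; the target string "ffn_norm_exp" is a hypothetical placeholder for whatever name a new MODEL_TENSOR.FFN_NORM_EXP mapping entry would match):

    def remap_arctic_tensor_name(name: str) -> str:
        """Rename post_attention_layernorm before tensor_map.get_name() sees it,
        so it is not caught by the generic FFN_NORM mapping."""
        if name.endswith("post_attention_layernorm.weight"):
            return name.replace("post_attention_layernorm", "ffn_norm_exp")
        return name

    # "model.layers.0.post_attention_layernorm.weight" -> "model.layers.0.ffn_norm_exp.weight"
    print(remap_arctic_tensor_name("model.layers.0.post_attention_layernorm.weight"))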

Make the mappings architecture-dependent, for example by creating the ArcticTensorNameMap class for this particular architecture. I guess it's a good idea to start doing this anyway, since the global mapping is becoming quite messy.

Yes, architecture-dependent mappings seem like a good idea to me, and worthwhile to explore in the future. To make it cleaner than creating new classes for each architecture, it could possibly be a field in the base Model class, so that common mappings would still be possible in gguf-py/gguf/tensor_mapping.py while architecture-dependent mappings would be more obvious than the ad-hoc renaming in my suggestion above.

@cpumaxx
Contributor

cpumaxx commented Apr 27, 2024

You can try my branch: https://github.com/fairydreaming/llama.cpp/tree/snowflake-arctic
Of yourse you will have to convert the model to GGUF (again), the ones I posted before won't work.

I'm going to try this once a couple of other experiments I have going finish. Thanks for getting something sensible out of this. I'm really excited to see performance of this model on some of my workloads!

@cpumaxx
Contributor

cpumaxx commented Apr 28, 2024

@fairydreaming Are you still using ./convert.py with --skip-unknown, or is there a different command to convert to gguf?
edit: never mind, I'm currently converting with convert-hf-to-gguf.py

@fairydreaming
Contributor

@fairydreaming Are you still using ./convert.py with --skip-unknown, or is there a different command to convert to gguf?

@cpumaxx no, in my branch I added support for these tensors (they are vital to the model), so --skip-unknown is not needed. I used the convert-hf-to-gguf.py script. However, the official tokenizer.model file is broken (it has wrong BOS/EOS tokens), so I kind of forced the conversion script to use the tokenizer.json instead by removing tokenizer.model from the model directory (some additional minor tweaks were needed for this). Then I set tokenizer.ggml.add_bos_token in the resulting GGUF to False with gguf-set-metadata.py. So unfortunately the model conversion process is still messy and long (like hours per iteration). :(

Now I see that you can select the --vocab-type in convert.py, so maybe all this is not needed. I will try it the next time I convert the model. I'd try it now, but I have no disk space left. I thought that a 4TB SSD would be enough to play with LLMs, but after recent model releases I'm running out of space on all my disks.

@fairydreaming
Contributor

@compilade

Make the mappings architecture-dependent, for example by creating the ArcticTensorNameMap class for this particular architecture. I guess it's a good idea to start doing this anyway, since the global mapping is becoming quite messy.

Yes, architecture-dependent mappings seem like a good idea to me. It seems worthwhile to explore in the future. To make it cleaner than making new classes for each architecture, it could possibly be a field in the base Model class so that common mappings would still be possible in gguf-py/gguf/tensor_mapping.py while architecture-dependent mappings would be more obvious than the ad-hoc renaming of my above suggestion.

For now I settled on yet another solution. I added this to TensorNameMap:

    arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
        MODEL_ARCH.ARCTIC: {
            MODEL_TENSOR.TOKEN_EMBD: (
                "model.embed_tokens",
            ),
            MODEL_TENSOR.OUTPUT_NORM: (
                "model.norm",
            ),
            ...

which is the same as block_mappings_cfg, but architecture-specific. Then in the TensorNameMap __init__ method I do:

        if arch in self.arch_block_mappings_cfg:
            block_mappings = self.arch_block_mappings_cfg[arch]
        else:
            block_mappings = self.block_mappings_cfg

and use block_mappings later. This seemed like the least intrusive and the cleanest solution.

@cpumaxx
Contributor

cpumaxx commented Apr 29, 2024

@fairydreaming Before I dedicate time/space/compute to trying, does this patch mean that I can convert to a usable gguf in some form?
What is the confirmed command and parameters to get there?
I'm itching to be able to really test this sucker out

@fairydreaming
Contributor

@cpumaxx I uploaded corrected quants, you can try them with my snowflake-arctic branch.
https://huggingface.co/sszymczyk/snowflake-arctic-instruct-GGUF

@cpumaxx
Contributor

cpumaxx commented Apr 29, 2024

I usually like to do all my own conversions and quants from the official safetensors, but I'll make an exception to test this one out : )
Thanks for your work and bandwidth on this!

@fairydreaming
Contributor

@cpumaxx If you want to do your own quants, then convert-hf-to-gguf.py should now work correctly. The only remaining problem is that add_bos_token is unnecessarily (I think) set to true, but you can change that after conversion with:
python3 gguf-py/scripts/gguf-set-metadata.py models/snowflake-arctic-instruct.gguf tokenizer.ggml.add_bos_token ""
convert.py also works (you have to add --vocab-type hfft to use the correct vocab), but needs a few changes:

diff --git a/convert.py b/convert.py
index 1c700cf..e42a67d 100755
--- a/convert.py
+++ b/convert.py
@@ -40,7 +40,7 @@ if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
 
 NDArray: TypeAlias = 'np.ndarray[Any, Any]'
 
-ARCH = gguf.MODEL_ARCH.LLAMA
+ARCH = gguf.MODEL_ARCH.ARCTIC
 
 DEFAULT_CONCURRENCY = 8
 
@@ -553,7 +553,6 @@ class LlamaHfVocab(Vocab):
             cache_dir=base_path,
             local_files_only=True,
         )
-        assert self.tokenizer.is_fast  # assume tokenizer.json is used
 
         # Initialize lists and dictionaries for added tokens
         self.added_tokens_list = []

I guess it was originally intended only for MODEL_ARCH.LLAMA, so I'm not going to commit these changes.

@fairydreaming fairydreaming linked a pull request May 1, 2024 that will close this issue
@BarfingLemurs
Contributor

@fairydreaming I couldn't convert the downloaded instruct model. Transformers 4.40.1.

I was running `python convert-hf-to-gguf.py /media/user/6/snowflake_instruct/Snowflake_snowflake-arctic-instruct --outtype f16`

/media/user/6/snowflake_instruct/Snowflake_snowflake-arctic-instruct contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co//media/user/6/snowflake_instruct/Snowflake_snowflake-arctic-instruct.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
Traceback (most recent call last):
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 3114, in <module>
    main()
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 3101, in main
    model_instance.set_vocab()
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 1524, in set_vocab
    self._set_vocab_llama_hf()
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 462, in _set_vocab_llama_hf
    vocab = LlamaHfVocab(self.dir_model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/snowflake/llama.cpp/convert.py", line 556, in __init__
    assert self.tokenizer.is_fast  # assume tokenizer.json is used
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

@fairydreaming
Contributor

You can simply comment out the assertion.

@fairydreaming
Contributor

@cebtenzzre can you tell us the story behind this assert? It looks like the ArcticTokenizer, which has LlamaTokenizer as its base class (and there is no ArcticTokenizerFast), fails the assertion. In the convert.py file, a few lines above the assertion, there is:
# Allow the tokenizer to default to slow or fast versions.
But your assert forces the tokenizer to be "fast"; why is that?

@fairydreaming
Contributor

@fairydreaming I couldn't convert the downloaded instruct model. Transformers 4.40.1.

I was running `python convert-hf-to-gguf.py /media/user/6/snowflake_instruct/Snowflake_snowflake-arctic-instruct --outtype f16`

/media/user/6/snowflake_instruct/Snowflake_snowflake-arctic-instruct contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co//media/user/6/snowflake_instruct/Snowflake_snowflake-arctic-instruct.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
Traceback (most recent call last):
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 3114, in <module>
    main()
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 3101, in main
    model_instance.set_vocab()
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 1524, in set_vocab
    self._set_vocab_llama_hf()
  File "/home/user/snowflake/llama.cpp/convert-hf-to-gguf.py", line 462, in _set_vocab_llama_hf
    vocab = LlamaHfVocab(self.dir_model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/snowflake/llama.cpp/convert.py", line 556, in __init__
    assert self.tokenizer.is_fast  # assume tokenizer.json is used
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

@BarfingLemurs I committed the assertion removal to my snowflake-arctic branch. Thanks for reporting the problem. I had this line removed locally (but not committed), so I didn't notice that it affected convert-hf-to-gguf.py as well.
