Support for Jamba JambaForCausalLM #6372

Open
4 tasks done
maziyarpanahi opened this issue Mar 28, 2024 · 18 comments · May be fixed by #7531
Labels
enhancement New feature or request

Comments

@maziyarpanahi

maziyarpanahi commented Mar 28, 2024

Prerequisites

Please answer the following questions for yourself before submitting an issue.

  • I am running the latest code. Development is very rapid so there are no tagged versions as of now.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new bug or useful enhancement to share.

Feature Description

Please provide a detailed written description of what you were trying to do, and what you expected llama.cpp to do as an enhancement.

A new MoE model was released today: JambaForCausalLM https://huggingface.co/ai21labs/Jamba-v0.1

Motivation

Please provide a detailed written description of reasons why this feature is necessary and how it is useful to llama.cpp users.

Another very good and open LLM

Possible Implementation

If you have an idea as to how it can be implemented, please write a detailed description. Feel free to give links to external sources or share visuals that might be helpful to understand the details better.

I can test any PR candidate

@maziyarpanahi maziyarpanahi added the enhancement New feature or request label Mar 28, 2024
@nonetrix

Have smaller Mamba based LLMs already been added in the past?

@Green-Sky
Collaborator

@compilade added mamba support. But Jamba seems to be a derivative and needs code modifications.

@compilade
Collaborator

I'd like it very much if they released a smaller version of their model. I don't have enough RAM to run Mixtral (only have 8GB), and Jamba seems to be around the same size as Mixtral. A model with less than 1B total parameters (or even less than 200M) would be ideal for quickly figuring out implementation problems (and would waste much less disk space when debugging or modifying model conversion).

My free time is too scarce at the moment to work on this (until May). The KV cache of this model will be some complicated beast: it's both recurrent and attention-based, but never in the same "layer". This will require rethinking how the KV cache is allocated and how Mamba's state is stored, but I think it should still be possible to support eventually, given enough effort.

Similarly to llm_build_ffn, I think there will need to be some kind of llm_build_mamba to more easily share the code building the graph of a Mamba block between Mamba and Jamba.

Anyone wanting to work on this should start by building a strong mental model of how Mamba's state is managed in llama.cpp, as well as how the KV cache works (at least what goes where, not necessarily why). This is necessary because modifications of both of these will likely be needed to make this work.

Mamba in llama.cpp uses 1 KV cell per sequence. We'll probably need to introduce tensor lists other than k_l and v_l in llama_kv_cache to avoid conflicting with attention's one KV cell per token; a different set of cells will be required (and yet another session file format revision). Sequences are selected with inp_s_seq in ggml_compute_forward_ssm_conv_f32 and ggml_compute_forward_ssm_scan_f32. Each token from a batch has one input state/sequence, but the resulting state is copied to all the sequences assigned to that token.
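
To make the difference between the two kinds of cache concrete, here is a toy Python sketch. It is not llama.cpp code: `ToyHybridCache`, `process_token`, `toy_update` and the integer "state" are all made up for illustration. It only shows the bookkeeping described above, assuming attention stores one cell per token while a recurrent layer keeps one state per sequence and copies the updated state to every sequence assigned to the token.

```python
from dataclasses import dataclass, field


def toy_update(state: int, token: int) -> int:
    # Stand-in for a real recurrent update (conv + SSM scan in Mamba).
    return (state * 31 + token) % 1_000_003


@dataclass
class ToyHybridCache:
    # Attention side: one cell per token, tagged with the sequences using it.
    kv_cells: list = field(default_factory=list)
    # Recurrent side: one fixed-size state per sequence, whatever its length.
    rec_states: dict = field(default_factory=dict)

    def process_token(self, token, pos, src_seq, dst_seqs):
        self.kv_cells.append((pos, frozenset(dst_seqs)))
        # Each token reads exactly one input state...
        new_state = toy_update(self.rec_states.get(src_seq, 0), token)
        # ...but the result is copied to every sequence assigned to the token.
        for seq_id in dst_seqs:
            self.rec_states[seq_id] = new_state


cache = ToyHybridCache()
for pos, tok in enumerate([5, 7, 9]):
    cache.process_token(tok, pos, src_seq=0, dst_seqs=[0])
# The 4th token belongs to sequences 0 and 1 and reads its state from sequence 0.
cache.process_token(11, 3, src_seq=0, dst_seqs=[0, 1])
print(len(cache.kv_cells), len(cache.rec_states))  # prints: 4 2
```

The attention side grows with the number of tokens, while the recurrent side only grows with the number of sequences, which is why the two probably cannot share the same set of cells.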

Simplifying how recurrent state operations are implemented is on my TODO list, and implementing both Jamba and RWKV should help with refactoring, but Jamba support in llama.cpp feels like a multi-week project, and I'll only have this kind of free time in May.

If anyone's too impatient, feel free to experiment and figure out a way to make Jamba work with llama.cpp. Even incomplete proofs of concept of how to manage the Jamba blocks should be useful.

@maziyarpanahi
Author

@compilade added mamba support. But Jamba seems to be a derivative and needs code modifications.

for reference: #5328

@sorasoras

Have smaller Mamba based LLMs already been added in the past?

It's not purely Mamba-based any more. It's a mix of transformer and Mamba layers, so it's going to be different.

@trap20

trap20 commented Apr 1, 2024

There is a Mini-Jamba on Huggingface now: https://huggingface.co/TechxGenus/Mini-Jamba-v2

Might be helpful for testing - if it actually is a working Mini-Jamba model, haven't checked that yet.

@severian42

Just checking to see if anyone has come close to getting Jamba working here. I've been working on figuring out fine-tuning and training on some new general chat Jamba models in prep for when they can be more standardized for everyone. Once we can get Jamba as a GGUF, I think it'll do some awesome stuff for all of us

https://huggingface.co/Severian/Jamba-Hercules
https://huggingface.co/Severian/Jamba-Nexus-IKM

@Any-Winter-4079

Any update on Jamba support?

@compilade
Collaborator

Any update on Jamba support?

I've worked on refactoring the KV cache in the past weeks to allow managing both recurrent states and Attention's KV cache at once. (See master...compilade/refactor-kv-cache) It's still a work in progress: state checkpoints (necessary to avoid re-processing the whole prompt when removing the last few tokens) are implemented, but not yet handled in the server. I'll open a PR when it's ready. I still need more time to think through the implementation (currently very busy with other things).

After that, work on specific hybrid model architectures like Jamba will be possible.
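
A minimal sketch of the checkpoint idea mentioned above, in Python. It is a toy and not the code from the refactor-kv-cache branch: `ToyRecurrentSeq`, its `checkpoint_every` parameter and the integer state are hypothetical. The point is that a recurrent state cannot be "truncated" the way a KV cache can, so removing the last few tokens means restoring an earlier saved state and replaying only the tokens after it.

```python
class ToyRecurrentSeq:
    def __init__(self, checkpoint_every=4):
        self.tokens = []
        self.state = 0
        self.checkpoints = {0: 0}  # number of tokens seen -> saved state
        self.checkpoint_every = checkpoint_every
        self.steps = 0  # counts state updates, to show how much work was done

    def append(self, token):
        # Stand-in for the real recurrent update.
        self.state = (self.state * 31 + token) % 1_000_003
        self.steps += 1
        self.tokens.append(token)
        if len(self.tokens) % self.checkpoint_every == 0:
            self.checkpoints[len(self.tokens)] = self.state

    def truncate(self, n_keep):
        if not 0 <= n_keep <= len(self.tokens):
            raise ValueError("n_keep out of range")
        # Checkpoints taken after the cut are no longer valid.
        self.checkpoints = {n: s for n, s in self.checkpoints.items() if n <= n_keep}
        base = max(self.checkpoints)
        replay = self.tokens[base:n_keep]
        # Restore the nearest earlier state, then replay only the gap.
        self.tokens = self.tokens[:base]
        self.state = self.checkpoints[base]
        for token in replay:
            self.append(token)


seq = ToyRecurrentSeq(checkpoint_every=4)
for t in range(1, 11):
    seq.append(t)           # 10 tokens, checkpoints at 0, 4 and 8
steps_before = seq.steps
seq.truncate(9)             # drop the last token
print(seq.steps - steps_before)  # prints: 1
```

Without the checkpoint at 8 tokens, the same truncation would have replayed all 9 remaining tokens from the initial state.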

@severian42

@compilade Thank you so much for taking this on. I have been trying on my own but failing miserably to get Jamba quantized with llama.cpp.

I have been prepping by training as many Jamba models as possible since that is more my wheelhouse

For your endeavors, could I 'Buy You a Coffee' to help support? I know this extra work isn't easy by any means

@erlebach

erlebach commented May 3, 2024

Could somebody write about why quantizing Jamba and providing a gguf is difficult? Thanks. Gordon.

@compilade
Collaborator

compilade commented May 4, 2024

For your endeavors, could I 'Buy You a Coffee' to help support?

@severian42 I appreciate the offer (it means a lot!), but I can't accept for now. Receiving international donations seems a bit complicated accounting-wise and I don't want to have to think about this (yet). Still nice to know you want this to succeed!

I know this extra work isn't easy by any means

Well, I don't see it as "work", more like exploring ideas. I like to be able to deeply think through some hard problems, and llama.cpp has plenty of that. :)

Could somebody write about why quantizing Jamba and providing a gguf is difficult? Thanks. Gordon.

@erlebach The main difficulty is how the state is managed; some layers (the Attention layers) will use the KV cache while others (the Mamba layers) will use recurrent states. This is what is taking the most effort/time to implement, since the API around copying, removing and allocating KV cells needs to be re-thought to support both types of cache at the same time.

I have more free time these days, so my work-in-progress of the above at master...compilade/refactor-kv-cache should advance a bit quicker than in the past weeks/month, though I'm currently working on simplifying convert-hf-to-gguf.py (#7031) to use lazy operations (#7075) to avoid having all the weights of a model in RAM during conversion. This should make testing of the conversion for big models (like Jamba, with its 100GB of bfloat16 weights) much easier and far less memory-hungry (and/or less disk-hungry if the --use-temp-file option was used).
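
As a rough illustration of why lazy operations keep memory use low during conversion, here is a small Python sketch. It is not the implementation from #7075: `LazyTensor`, `build_plan`, `write_all` and `fake_load` are hypothetical names, and plain lists stand in for real tensors. The conversion is first described as a plan of deferred operations, and each tensor is only read and transformed when it is about to be written.

```python
class LazyTensor:
    def __init__(self, thunk):
        self._thunk = thunk  # zero-argument function producing the data

    def map(self, fn):
        # Record the transformation without running it.
        return LazyTensor(lambda: fn(self._thunk()))

    def materialize(self):
        return self._thunk()


def build_plan(names, load):
    # Nothing is read from disk here; only closures are created.
    return [
        (name, LazyTensor(lambda name=name: load(name)).map(
            lambda w: [x * 0.5 for x in w]))
        for name in names
    ]


def write_all(plan):
    for name, lazy in plan:
        # Only one tensor's data is alive at a time.
        yield name, lazy.materialize()


load_log = []


def fake_load(name):
    load_log.append(name)
    return [1.0, 2.0]


plan = build_plan(["blk.0.weight", "blk.1.weight"], fake_load)
loads_after_planning = len(load_log)  # still 0: nothing has been read yet
written = list(write_all(plan))
```

In a real converter the consumer of `write_all` would stream each tensor to the output file instead of collecting everything into a list.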

Quantization will likely not be a problem, since it seemed to work well enough for bigger Mamba models. I don't know why people keep repeating it can't be quantized. The internal Mamba-specific stuff can't, but even in pure Mamba models it's less than ~5% of the weights, while the rest of the space is taken up by linear projections, which can be quantized.
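
A back-of-the-envelope check of that proportion for a single Mamba block, as a Python sketch. The dimensions are assumptions chosen to resemble a Mamba block with a 2560-wide embedding (d_inner = 2 × d_model, dt_rank = d_model / 16). The split between "large linear projections" and "small Mamba-specific tensors" is also an assumption made for this illustration, not something stated in this thread.

```python
def mamba_block_params(d_model, d_inner, d_state, d_conv, dt_rank):
    # Large linear projections (assumed quantizable).
    large = {
        "in_proj": d_model * 2 * d_inner,
        "out_proj": d_inner * d_model,
    }
    # Small Mamba-specific tensors (assumed to stay in full precision).
    small = {
        "conv1d": d_inner * d_conv + d_inner,         # weight + bias
        "x_proj": d_inner * (dt_rank + 2 * d_state),
        "dt_proj": dt_rank * d_inner + d_inner,       # weight + bias
        "A": d_inner * d_state,
        "D": d_inner,
        "norm": d_model,
    }
    return sum(large.values()), sum(small.values())


large, small = mamba_block_params(
    d_model=2560, d_inner=5120, d_state=16, d_conv=4, dt_rank=160)
fraction = small / (large + small)
print(f"{fraction:.1%}")  # prints 4.7% for these assumed dimensions
```

Token embeddings and the output projection are not counted here; including them would make the unquantized share smaller still.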

Feel free to contribute code if you are though, you could help out @compilade which seems to be one piece of the puzzle

@nonetrix Thanks for reminding others that they too can help. (EDIT: hey, your comment was useful, you didn't need to delete it)

For examples of how to help:

  • train Jamba finetunes
    • this makes the implementation in llama.cpp more worth it
  • answer "are we there yet?" messages sent here constructively
    • this gives me time for a more thoughtful message later, like this one
  • share ideas/code of how to manage recurrent states at the same time as the KV cache
    • this is more complicated than it sounds because it's necessary to manage sequences and some operations like copies and (sometimes partial) removal.
    • a simple data structure for this would be nice
      • In the branch I've linked before, I've implemented a tree of sequences, to allow for shared checkpoints, but simpler ways probably exist
    • a way to manage the allocation of that with different buffer sizes per layer
  • feedback on the code
    • this will be easier once I open a pull request for this.
  • test the code once it seems ready
    • my hardware is limited, so I will need help with testing. I'll make sure to announce it here once I get to this point.

@nonetrix

nonetrix commented May 5, 2024

Feel free to contribute code if you are though, you could help out @compilade which seems to be one piece of the puzzle

@nonetrix Thanks for reminding others that they too can help. (EDIT: hey, your comment was useful, you didn't need to delete it)

No, it was somewhat mean-spirited. I shouldn't have said what I said; I apologize.

@erlebach

Thank you for the response. I was simply curious since it was the first time I noticed a quantization effort take so much time. Truly, I appreciate all the hard work you guys put into this. Good luck!

@pszemraj

@compilade @trap20 It's not perfect/SoTA, but I pretrained a small Jamba arch (900M params, 8 experts) on about 20B tokens using the stock HF modeling code if this helps any testing: https://hf.co/pszemraj/jamba-900M-v0.13-KIx2

there's a notebook on there with an inference example (interestingly, it uses only a few GB of VRAM even if you generate 10k tokens!)

@compilade
Collaborator

compilade commented May 25, 2024

Okay, turns out I only had to put like, 2 to 3 more days of work on this and BAM it works.

As of today, in branch refactor-kv-cache, using the model from #6372 (comment), conversion works, loading works, and inference works (well, it seems to be working, at least, with coherent sentences (thank you @pszemraj for training it!)). I did not test quantization yet (except for Q8_0).

Example output from jamba-900M-v0.13-KIx2:
$  ./bin/main -m /srv/LLMstash/tmp/jamba-900M.bf16.gguf --temp 0 -e -p "I believe the meaning of life is" --repeat-penalty 1.2 --repeat-last-n 256 -c 16384 -n 256
Log start
main: build = 3003 (0fd13e94)
main: built with gcc (GCC) 13.2.0 for x86_64-unknown-linux-gnu
main: seed  = 1716594011
llama_model_loader: loaded meta data with 26 key-value pairs and 189 tensors from /srv/LLMstash/tmp/jamba-900M.bf16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = jamba-900M-v0.13-KIx2
llama_model_loader: - kv   2:                          jamba.block_count u32              = 12
llama_model_loader: - kv   3:                       jamba.context_length u32              = 16384
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 1024
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 4096
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,12]      = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 2048
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 8
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = gpt-2
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65024]   = ["<EOT>", "<META>", "<META_START>", "...
llama_model_loader: - kv  19:                  tokenizer.ggml.token_type arr[i32,65024]   = [3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  20:                      tokenizer.ggml.merges arr[str,64739]   = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "ĠĠ �...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  121 tensors
llama_model_loader: - type bf16:   68 tensors
llm_load_vocab: special tokens definition check successful ( 29/65024 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 65024
llm_load_print_meta: n_merges         = 64739
llm_load_print_meta: n_ctx_train      = 16384
llm_load_print_meta: n_embd           = 1024
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 12
llm_load_print_meta: n_rot            = 32
llm_load_print_meta: n_embd_head_k    = 32
llm_load_print_meta: n_embd_head_v    = 32
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 4096
llm_load_print_meta: n_expert         = 8
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 16384
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 2048
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 887.66 M
llm_load_print_meta: model size       = 1.67 GiB (16.19 BPW) 
llm_load_print_meta: general.name     = jamba-900M-v0.13-KIx2
llm_load_print_meta: BOS token        = 0 '<EOT>'
llm_load_print_meta: EOS token        = 0 '<EOT>'
llm_load_print_meta: UNK token        = 0 '<EOT>'
llm_load_print_meta: PAD token        = 0 '<EOT>'
llm_load_print_meta: LF token         = 133 'Ä'
llm_load_tensors: ggml ctx size =    0.09 MiB
llm_load_tensors:        CPU buffer size =  1713.16 MiB
......................................
llama_new_context_with_model: n_ctx      = 16384
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    49.34 MiB
llama_new_context_with_model: SSM state size =     1.34 MiB, R (f32):    0.21 MiB, S (f32):    1.12 MiB
llama_new_context_with_model: KV cache size  =    48.00 MiB, K (f16):   24.00 MiB, V (f16):   24.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =  1062.03 MiB
llama_new_context_with_model: graph nodes  = 621
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 2 / 4 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 256, repeat_penalty = 1.200, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 16384, n_batch = 2048, n_predict = 256, n_keep = 0


<EOT>I believe the meaning of life is not to be found in a single word, but rather as an expression of one's own feelings and thoughts.

The idea that we are all born with our bodies, whether they are human or animal, has been around for centuries. It was believed by some that it was something like a body made up of bones, which were attached to each other at birth. The most common form of this type of bone is called a "bone." This is what makes it so hard to tell if you're alive or dead. In fact, there are many different types of bones, including those that have been used for various purposes such as healing wounds, wounding wounds, etc.

In ancient times, people had a lot of teeth, and these were often very small. They could also be placed on top of their heads, where they would sit down and look at them. These were usually large, round stones, which were sometimes covered with hair. When the skin was removed from the head, the bones became more prominent, and the muscles began to grow larger.

This kind of bone was known as a "bone" because it was made out of two parts: the outermost part (the innermost portion) and the innermost part (the outermost
llama_print_timings:        load time =     252.28 ms
llama_print_timings:      sample time =     303.07 ms /   256 runs   (    1.18 ms per token,   844.68 tokens per second)
llama_print_timings: prompt eval time =     200.72 ms /     8 tokens (   25.09 ms per token,    39.86 tokens per second)
llama_print_timings:        eval time =   12516.79 ms /   255 runs   (   49.09 ms per token,    20.37 tokens per second)
llama_print_timings:       total time =   13213.95 ms /   263 tokens
Log end
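
The cache sizes reported in the log above can be reproduced from the metadata it prints. The Python sketch below is only a sanity check under stated assumptions: a single sequence, layers with `head_count_kv` equal to 0 are Mamba layers, f16 takes 2 bytes and f32 takes 4, and the conv state ("R") holds `d_conv - 1` columns per inner channel while the SSM state ("S") holds `d_state` values per inner channel.

```python
MIB = 1024 * 1024

# Values copied from the log.
head_count_kv = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
n_ctx = 16384
head_dim = 32                          # n_embd_head_k / n_embd_head_v
d_conv, d_inner, d_state = 4, 2048, 16

attn_layers = [i for i, n in enumerate(head_count_kv) if n > 0]
mamba_layers = [i for i, n in enumerate(head_count_kv) if n == 0]

# Attention: one K and one V vector per token, per attention layer (f16).
k_bytes = sum(head_count_kv[i] * head_dim * n_ctx * 2 for i in attn_layers)
v_bytes = k_bytes

# Mamba: one fixed-size state per sequence, per Mamba layer (f32).
r_bytes = len(mamba_layers) * (d_conv - 1) * d_inner * 4
s_bytes = len(mamba_layers) * d_state * d_inner * 4

print(f"K = {k_bytes / MIB:.2f} MiB, V = {v_bytes / MIB:.2f} MiB")
print(f"SSM state = {(r_bytes + s_bytes) / MIB:.2f} MiB")
print(f"total = {(k_bytes + v_bytes + r_bytes + s_bytes) / MIB:.2f} MiB")
```

With these assumptions the script gives 24.00 MiB each for K and V, 1.34 MiB of SSM state and 49.34 MiB in total, matching the log. Only the 3 attention layers scale with the context size; the 9 Mamba layers cost the same regardless of context length.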

Inference is CPU-only for now, because the Mamba implementation in llama.cpp is still CPU-only (ref: #6758).

To convert, I used

$ python3 convert-hf-to-gguf.py /srv/LLMstash/src/jamba-900M-v0.13-KIx2/ --outfile /srv/LLMstash/tmp/jamba-900M.{ftype}.gguf --outtype auto

(@severian42 note that convert-hf-to-gguf.py doesn't yet support loading 4-bit BitsAndBytes models (I might work on fixing this eventually). In the meantime, dequantizing to a 16-bit float type (either F16 or BF16) will be necessary.)

I'm going to open a PR very soon (in a few hours or tomorrow); I just need to write it up (quite a lot of lines changed, and I want to explain them).

For the impatient, the code is in https://github.com/ggerganov/llama.cpp/tree/compilade/refactor-kv-cache.

@severian42

Incredible! Thank you so much for putting in the work to get this running and updating us with the news. I will give this a try this weekend and report back. So excited to try this! Thanks again for using your smarts to further the open source LLM world. We owe a big one 💪

@compilade
Collaborator

There is still more work I need to put into this.
I've got inference working, but things that are not yet done are:

  • state saving and reloading to and from a session file
  • proper state checkpoint handling in the server example
    • to properly backtrack with cached state
  • same, but in all the other examples

Building this will (currently) output lots of warnings because I've renamed many functions related to KV cache management, and then deprecated the old names, but I did not yet update their usages in the various examples.

Finishing this will take a few days still, but I think I will still open a PR tomorrow.

I think I won't change anything about how the GGUFs of Jamba are made, unless I've unexpectedly messed something up in the conversion code.

@compilade compilade linked a pull request May 25, 2024 that will close this issue