llama : Support llama 4 text-only #12791

Conversation
I was trying to reproduce your setup locally, and here's what I did:

```sh
cd /
git clone https://huggingface.co/ngxson/TEST-Tiny-Llama4
cd /ws/llama.cpp
gh pr checkout 12791
# build
...
python convert_hf_to_gguf.py /TEST-Tiny-Llama4/
./build/bin/llama-cli -m /TEST-Tiny-Llama4-4x301M-F16.gguf -ngl 999
```

The model metadata seems to load correctly. However, it fails during tensor loading with the following error:

Am I missing a step, or is there something wrong with the converted model?
@yeahdongcn Please note that I still don't get a 1-to-1 logits match with transformers; I think my impl is still missing some small things (not the attn_scale, since for n_ctx < 8192 it's always 1.0)
Ok great news, I achieved logits matching with the transformers output (diff within an acceptable range):

Transformers:

llama.cpp:
Our Mac Studio is literally commuting by train to the office, so unfortunately I can't test with the big model rn..
I can test the Scout - anything specific you want to check?
Yes, can you please see if it counts the number of R's in strawberry or raspberry correctly now? Thanks! I also want to see if it can generate past 4k tokens. The current implementation allows generating up to an 8k context; after that we will need the chunked mask and attn scale.
Ok, I found why the activations don't match in my comment above. Turns out, I forgot to apply the MoE router weights to the gate. It should match almost 100% now (minus some floating point inaccuracy):
Forgot to mention, here is the command to generate the activations:
On transformers:
src/llama-graph.cpp
Outdated
```cpp
if (weight_before_ffn) {
    ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
    repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
    cur = ggml_mul(ctx0, repeated, weights);
    cb(cur, "ffn_moe_weighted", il);
}
```
Not sure if this is the best way to repeat a dim. What I want to do here is to turn from:
[n_embd, 1, n_tokens]
to:
[n_embd, n_expert_used, n_tokens]
In general, usage of ggml_new_tensor_xd should be avoided as much as possible because it can increase the required compute graph memory (i.e. allocate separate tensors for each layer, without reusing the memory). However, in this specific case, I can't think of an alternative way to do this.
Can't you rely on broadcasting support in ggml_mul to do this?
It seems to require:
cur = ggml_mul(ctx0, [n_embd, 1, n_tokens], [1, n_expert_used, n_tokens]);
Not sure if we support this type of broadcast.
I tried ggml_mul directly but it failed the can_repeat check
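For context, a minimal sketch of the shape rule at play here, assuming ggml_mul(ctx, a, b) requires that b can be repeated into a's shape (every dimension of a divisible by the corresponding dimension of b); the values below are illustrative, not taken from the PR:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative check mirroring the assumed "can repeat" rule: b broadcasts
// into a only if every dimension of a is a multiple of the matching dim of b.
static bool can_broadcast_into(const int64_t a[4], const int64_t b[4]) {
    for (int i = 0; i < 4; ++i) {
        if (a[i] % b[i] != 0) {
            return false;
        }
    }
    return true;
}

int main() {
    const int64_t n_embd = 5120, n_expert_used = 4, n_tokens = 8; // example values
    const int64_t cur[4]     = { n_embd, 1,             n_tokens, 1 };
    const int64_t weights[4] = { 1,      n_expert_used, n_tokens, 1 };
    // dim 1: 1 % n_expert_used != 0, so the check fails -> a single ggml_mul
    // cannot broadcast both tensors at the same time.
    printf("can broadcast: %s\n", can_broadcast_into(cur, weights) ? "yes" : "no");
    return 0;
}
```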
Um yeah, we don't support broadcasting in both tensors at the same time. In that case we should add a ggml_repeat_4d that just takes the desired dimensions directly, instead of a tensor, to avoid creating a new tensor.
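A sketch of how the snippet above could look with such a helper, assuming a hypothetical signature ggml_repeat_4d(ctx, a, ne0, ne1, ne2, ne3) that takes the target dimensions directly (the final API may differ):

```cpp
// Hypothetical usage of the proposed ggml_repeat_4d: repeat cur from
// [n_embd, 1, n_tokens] to [n_embd, n_expert_used, n_tokens] without first
// allocating a template tensor with ggml_new_tensor_3d.
if (weight_before_ffn) {
    ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
    cur = ggml_mul(ctx0, repeated, weights);
    cb(cur, "ffn_moe_weighted", il);
}
```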
An alternative approach for this particular case: we are currently doing

```
cur  = repeat(cur)
cur  = mul(cur, weights)
gate = mm(cur, gate_w)
gate = silu(gate)
up   = mm(cur, up_w)
```

But this also works, basically replacing one repeat with 2 mul:

```
gate = mul(mm(cur, gate_w), weights)
gate = silu(gate)
up   = mul(mm(cur, up_w), weights)
```

But I think 2 times ggml_mul should be slower than 1 repeat and 1 mul.
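A sketch of what the second variant above could look like in ggml calls, loosely following the build_moe_ffn style; the tensor names and exact calls here are illustrative, not the code used in this PR:

```cpp
// Illustrative sketch of the "2x ggml_mul" variant: apply the router weights
// to the gate and up projections instead of repeating the FFN input.
// weights is assumed to be [1, n_expert_used, n_tokens], so it broadcasts over
// the first dimension of each matmul output.
ggml_tensor * gate = ggml_mul_mat_id(ctx0, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
gate = ggml_mul(ctx0, gate, weights);
gate = ggml_silu(ctx0, gate);

ggml_tensor * up = ggml_mul_mat_id(ctx0, up_exps, cur, selected_experts);     // [n_ff, n_expert_used, n_tokens]
up = ggml_mul(ctx0, up, weights);
```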
Ok, both attn temp and the attn chunked mask should work now. Here is the visualization for a chunk of 5 tokens:
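A minimal sketch of the chunked causal mask rule shown in that visualization, assuming position i may attend to position j only when j <= i and both positions fall in the same chunk:

```cpp
#include <cstdio>

// Prints a chunked causal attention mask for 10 tokens with a chunk size of 5:
// 'x' = position i may attend to position j, '.' = masked out.
int main() {
    const int n_tokens   = 10;
    const int chunk_size = 5;
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_tokens; ++j) {
            const bool causal     = j <= i;
            const bool same_chunk = (i / chunk_size) == (j / chunk_size);
            putchar(causal && same_chunk ? 'x' : '.');
        }
        putchar('\n');
    }
    return 0;
}
```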
src/llama-model.cpp
Outdated
```cpp
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
if (arch == LLM_ARCH_LLAMA4) {
    auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
    inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
    ggml_set_input(inp_attn_scale);
    inp->attn_scale = inp_attn_scale;
    res->add_input(std::move(inp));
}
```
Move this into a helper function for consistency:

```cpp
// temperature tuning
ggml_tensor * inp_attn_scale = build_inp_attn_scale();
```
Fixed in af1968c
src/llama-model.cpp
Outdated
```cpp
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
// hack: we use SWA to store the chunked attn mask
```
Yes, SWA -> AUX makes sense.
Btw, we should soon implement actual SWA / chunked attention that uses less memory. It shouldn't be a big change and will improve memory usage significantly for such models.
The rename requires quite a few more changes than I expected, so I think I'll do it in another PR to test it more thoroughly. Here I'll only edit my comment to make it clearer that I'm using the "swa" variable to store the chunked mask.
I'll do this after having the automated logits-matching test. The problem is that renaming n_swa to n_pattern_aux makes the code that checks "if (n_swa) then use SWA" invalid. It would now have to become "if (n_pattern_aux && is_swa) then use SWA". I think it's better to add an enum called llama_mask_aux with 3 values: none, swa, chunked, so that the code becomes clearer.
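A rough sketch of the proposed enum (the names are illustrative, not a final API):

```cpp
// Sketch only: names and values may differ in the follow-up PR.
enum llama_mask_aux {
    LLAMA_MASK_AUX_NONE,    // standard causal mask only
    LLAMA_MASK_AUX_SWA,     // sliding-window attention mask
    LLAMA_MASK_AUX_CHUNKED, // chunked attention mask (e.g. llama 4)
};
```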
Don't know if this is helpful, but both Llama 4 models get the same answer on Lambda Cloud.

llama-4-maverick-17b-128e-instruct-fp8:

llama-4-scout-17b-16e-instruct:
Not sure if it's relevant to this PR, but for some reason conversion is taking way longer with Scout than with other ~70-111B models.. it's been going for over an hour loading tensors and nothing has been written to disk yet. A 70B model took 10 min to convert, a 111B model took 15 min. Converting to bf16. Did you get similar results, @justinjja2? (It is definitely possible that it's because my CPU is busy with other stuff, but I don't think that usually makes it this much slower; I can test again later with nothing else running.)
@bartowski1182 It (this model) typically takes 20-30 min when I convert to Q8_0. Yes, it's slow because the original model has the MoE gate and up tensors "fused" into one; the conversion script needs to split and then transpose them, so you should see a CPU spike for each layer converted.
Alright! Since my last commits are mostly comment changes, there is no need to re-review them. Will merge this PR once the CI is green 🚀 (except for the …)
Would it be possible for you to upload the GGUF F16 model?
I'm converting to GGUF now and noticing the same as @bartowski1182; it seems to take an unusually long time to convert, probably due to some of the tensors being 16x the expected size (?)
I'm uploading an f16 on my side just in case, should be available
Btw, if someone has 2TB of disk space and > 400GB RAM, feel free to test the Llama 4 Maverick and share your results here.
```python
name_up   = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
dim_half  = data_torch.shape[-1] // 2
gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
```
Lazy evaluation doesn't support splitting yet, so this will always evaluate eagerly (and so it will take more RAM than ideal during conversion). This may or may not explain the conversion slowness others are noticing.
This can be fixed in gguf-py/gguf/lazy.py by handling tuples of tensors as output values. I have the necessary changes somewhere; I'll open a PR once I find them.
(EDIT: see #12809)
Off-topic question: is it possible to somehow extend LazyTorchTensor to load a tensor remotely? FYI, the Hugging Face backend supports byte ranges, so an idea could be to read the tensors one by one entirely in RAM, without having to download them to disk.
Huh, I never thought of it, but yes, technically this should totally be possible. Lazy tensors only need the name, shape and type of the tensor for the fake tensors, and then a way to turn the original fake tensors into real tensors.
The hardest part of this wouldn't necessarily be the lazy tensors, but how the remote paths would be specified, how they would interact with the default output path and default name of the output file, how the tensors would be enumerated, and how the config file and the tokenizer would be fetched.
There's a lot of tokenizer-related code which assumes local files.
I think we can rely on AutoTokenizer.from_pretrained, which will download the tokenizer files to a temporary directory. Will have a look at how it works.
We can alternatively rely on huggingface_hub.download(), which accepts a pattern of file names to download (so, for example, we can disallow downloading safetensors).
In my case, loading safetensors remotely can be very useful. I couldn't test the 409B Maverick model as it requires 1.5TB in total to store both the HF model + gguf, but HF Spaces only provide at most 1TB of storage.
@ngxson I just tested Maverick - sadly it doesn't work for BF16 conversion. It's because "interleave_moe_layer_step": 2, so every 2nd / odd layer is MoE, whilst the rest are FFN.
Error:
INFO:hf-to-gguf:gguf: loading model part 'model-00002-of-00055.safetensors'
INFO:hf-to-gguf:blk.1.ffn_down_exps.weight, torch.bfloat16 --> BF16, shape = {8192, 5120, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00003-of-00050.safetensors'
INFO:hf-to-gguf:blk.1.ffn_down_exps.weight, torch.bfloat16 --> BF16, shape = {8192, 5120, 16}
Traceback (most recent call last):
ValueError: Duplicated tensor name 'blk.1.ffn_down_exps.weight'
I just downloaded unsloth/Llama-4-Maverick-17B-128E-Instruct from HF. Everything looks good so far. (On a machine with 2 TB RAM and 13 TB SSD.)
It took me about 50 minutes to reach this point.
INFO:hf-to-gguf:gguf: loading model part 'model-00030-of-00055.safetensors'
INFO:hf-to-gguf:blk.27.ffn_gate_exps.weight, torch.bfloat16 --> F16, shape = {5120, 8192, 128}
INFO:hf-to-gguf:blk.27.ffn_up_exps.weight, torch.bfloat16 --> F16, shape = {5120, 8192, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00031-of-00055.safetensors'
INFO:hf-to-gguf:blk.27.ffn_down_exps.weight, torch.bfloat16 --> F16, shape = {8192, 5120, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00032-of-00055.safetensors'
INFO:hf-to-gguf:blk.29.ffn_gate_exps.weight, torch.bfloat16 --> F16, shape = {5120, 8192, 128}
Leaving this comment here for visibility: we discussed via DM, and it turns out Daniel was using the wrong directory 😂
+1 reason to support converting HF --> gguf without downloading to disk
> +1 reason to support converting HF --> gguf without downloading to disk
This feature will be 🔥
Got past the ValueError: Duplicated tensor name 'blk.1.ffn_down_exps.weight' error, thanks!
INFO:gguf.gguf_writer:Writing the following files:
INFO:gguf.gguf_writer://Volumes/storage 1/models/LLaMA-4-Maverick-17B.gguf: n_tensors = 531, total_size = 801.5G
Writing: 0%| | 0.00/801G [00:00<?, ?byte/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Writing: 1%|█▏
Now I wait!
Update: More Information on Llama 4 Maverick
🛠️ Model Conversion
root@bb22ebf4525a:/ws# time python convert_hf_to_gguf.py model_path
...
INFO:hf-to-gguf:Set model quantization version
INFO:gguf.gguf_writer:Writing the following files:
INFO:gguf.gguf_writer:/model_path/840ed22d9bc7731246bc119cca026a48a0ff8ec6-128x17B-840ed22d9bc7731246bc119cca026a48a0ff8ec6-F16.gguf: n_tensors = 531, total_size = 801.5G
...
real 302m33.600s
user 285m44.179s
sys 113m0.490s
🔧 Quantization
root@bb22ebf4525a:/ws# time ./build/bin/llama-quantize model_path llama4_maverick_q4_k_m.gguf Q4_K_M
...
llama_model_quantize_impl: model size = 764328.14 MB
llama_model_quantize_impl: quant size = 231508.31 MB
main: quantize time = 1372408.18 ms
main: total time = 1372408.18 ms
./build/bin/llama-quantize Q4_K_M 55951.43s user 4741.92s system 4421% cpu 22:52.55 total
🧪 Tested with MUSA backend:
> Hi
Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?
> How many 'R' in strawberry?
There are 2 'R's in the word "strawberry".
>
Tested with the
Is there any resolution on why the conversion is slow? Maybe related:
Just saw @compilade has answered this already:
```
* master: (123 commits)
  cuda : add f32 to bf16 copy op (ggml-org#12806)
  llava: improve clip_ctx destructor to not memleak load_image_size (ggml-org#12834)
  llama : fix FA when KV cache is not used (i.e. embeddings) (ggml-org#12825)
  server : fix thread.join() on exit (ggml-org#12831)
  llava: add more helper functions to check projector types in clip context (ggml-org#12824)
  arg : Including limits file on AIX (ggml-org#12822)
  server : webui : Improve Chat Input with Auto-Sizing Textarea (ggml-org#12785)
  Revert "sycl:remove redundant memcopy in function ggml_backend_sycl_buffer_set_tensor" (ggml-org#12812)
  gguf-py : support lazy tensor splitting (ggml-org#12809)
  llama : Support llama 4 text-only (ggml-org#12791)
  opencl: better identify Adreno GPU (ggml-org#12760)
  hellaswag: display estimated score confidence interval (ggml-org#12797)
  cuda : fix HIP and MUSA BF16 (#0)
  sync : ggml
  ggml : simplify Arm fp16 CPU logic (ggml/1177)
  CUDA: don't convert BF16 weights to FP32 (ggml/1174)
  cpu: move all the operators into a separate c++ file (except mul_mat) (ggml/1167)
  sycl: remove redundant memcopy in function ggml_backend_sycl_buffer_set_tensor (ggml-org#12734)
  ci : no curl on ggml-ci (ggml-org#12796)
  cmake : enable curl by default (ggml-org#12761)
  ...

# Conflicts:
#   common/arg.cpp
#   common/common.cpp
#   common/common.h
```
* llama4 conversion
* initial support, no chat template
* clean up a bit
* fix tokenizer conversion
* correct hparams
* try this
* fix shexp
* ffn_inp_normed
* chat template
* clean up model conversion
* add_bos
* add scale_before_ffn
* fix order
* weight_before_ffn
* llm_graph_input_attn_temp
* add chunk attn mask
* build_inp_attn_scale()
* add comment about ggml_repeat
* clarify comments
* fix build
Resolves #12774
This PR targets Llama-4-Scout-17B-16E-Instruct. I don't (yet?) have a powerful enough system to work with the bigger model.
But Son, you are GPU-poor, how can you test a model that big? Well, indeed, I started the development with this smaller random-weight model (I planned to train it with this dataset, but it seems that TRL does not yet support llama 4?).
I keep a copy of the transformers library locally and add a bunch of print() statements to see the intermediate activations. On the llama.cpp side, I use eval-callback made by @phymbert to print out the intermediate results. The goal is to match these activations (and thus eventually the logits) before trying the big model. Then, once I'm happy with the result, I can "borrow" a big machine in our office. This is the output from Llama-4-Scout-17B-16E-Instruct quantized to Q8_0, running on a Mac Studio M3 Ultra with 512GB RAM:
Performance is good too:
However, it fails on the ultimate AGI test (not sure if that's a problem with the model or with my implementation)
Design considerations
The differences between llama 3 and llama 4 are described in this comment: #12774 (comment)

- llama4 is added because some new hparams are added. This is also for future-proofing, as we can easily modify default hparams / cgraph even after this PR gets merged, without introducing a breaking change.
- llama arch, with some modifications in MoE.
- attn_scaling (used by attn_temperature_tuning) will need to take a different approach than the pytorch or mlx implementations: due to the missing ggml_cast between F32 <--> I32 and the missing ggml_floor, we will need to calculate the attn_scaling array (which will have the same dim as inp_pos) on the CPU and use it in the cgraph as an input node. A sketch of this computation follows the TODO list below.

TODO:
- attn_scaling
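A minimal sketch of that CPU-side attn_scaling computation, assuming the reference formula scale(pos) = log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1; the function and variable names here are illustrative, not the exact ones used in the PR:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: compute the per-position attention scale on the CPU; the result has
// the same length as inp_pos and is copied into the cgraph input tensor.
// floor_scale ~ hparams.n_attn_temp_floor_scale, scale ~ hparams.f_attn_temp_scale.
static std::vector<float> build_attn_scale(const std::vector<int32_t> & pos,
                                           float floor_scale, float scale) {
    std::vector<float> out(pos.size());
    for (size_t i = 0; i < pos.size(); ++i) {
        out[i] = std::log(std::floor((pos[i] + 1.0f) / floor_scale) + 1.0f) * scale + 1.0f;
    }
    return out;
}
```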