llama : Support llama 4 text-only #12791

Merged: 21 commits merged into ggml-org:master on Apr 7, 2025

Conversation

@ngxson (Collaborator) commented Apr 7, 2025

Resolves #12774

This PR targets Llama-4-Scout-17B-16E-Instruct. I don't (yet?) have a powerful enough system to work with the bigger model.

But Son, you are GPU-poor, how can you test a model that big? Well, indeed, I started development with this smaller random-weight model (I planned to train it with this dataset, but it seems TRL does not yet support Llama 4?)

I keep a copy of the transformers library locally and add a bunch of print() statements to see the intermediate activations. On the llama.cpp side, I use eval-callback made by @phymbert to print out the intermediate results. The goal is to match these activations (and thus eventually the logits) before trying the big model.

Then once I'm happy with the result, I can "borrow" a big machine in our office. This is the output from Llama-4-Scout-17B-16E-Instruct quantized to Q8_0, running on a Mac Studio M3 Ultra with 512GB RAM:

Hi! How are you today? Is there something I can help you with, or would you like to chat?

> who are u
I'm an AI assistant designed to simulate conversation, answer questions, and provide information on a wide range of topics. I'm a computer program that uses natural language processing (NLP) to understand and respond to human input.

I don't have a personal identity or emotions like humans do, but I'm here to help and chat with you! I'm often referred to as a "chatbot" or "virtual assistant." I can provide information, answer questions, and even engage in simple conversations.

Would you like to know more about me or is there something specific you'd like to talk about?

>

Performance is good too:

llama_perf_sampler_print:    sampling time =       6.94 ms /   126 runs   (    0.06 ms per token, 18147.77 tokens per second)
llama_perf_context_print:        load time =    2948.98 ms
llama_perf_context_print: prompt eval time =     431.23 ms /    23 tokens (   18.75 ms per token,    53.34 tokens per second)
llama_perf_context_print:        eval time =    4951.49 ms /   136 runs   (   36.41 ms per token,    27.47 tokens per second)
llama_perf_context_print:       total time =  705735.53 ms /   159 tokens

However, it fails on the ultimate AGI test (not sure if that's a problem with the model or with my implementation)

> how many r's in the word raspberry
There are 2 R's in the word "raspberry".

Design considerations

The differences between Llama 3 and Llama 4 are described in this comment: #12774 (comment)

  • A new arch llama4 is added because some new hparams are introduced. This is also for future-proofing: we can easily modify the default hparams / cgraph even after this PR gets merged, without introducing a breaking change
  • It reuses the same cgraph as the llama arch, with some modifications in the MoE part
  • attn_scaling (used by attn_temperature_tuning) needs to take a different approach than the pytorch or mlx implementations: due to the missing ggml_cast between F32 <--> I32 and the missing ggml_floor, we calculate the attn_scaling array (which has the same dim as inp_pos) on the CPU and use it in the cgraph as an input node (see the sketch below)
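
For illustration, a rough sketch of how this CPU-side attn_scaling array could be built. This is not the PR's code: the formula follows the reference transformers implementation, and the floor_scale / attn_scale values below are only illustrative defaults.

import math

# per-position attention scale, computed on the CPU and passed to the cgraph as an input node
def build_attn_scale(positions, floor_scale=8192.0, attn_scale=0.1):
    return [math.log(math.floor((p + 1) / floor_scale) + 1) * attn_scale + 1.0
            for p in positions]

print(build_attn_scale([0, 100, 8190, 8191, 20000]))
# positions 0..8190 give exactly 1.0, so the scale only starts to matter past an ~8k context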

TODO:

  • clean up, read metadata from GGUF
  • add attn_scaling
  • add chunked mask

@github-actions bot added the python (python script changes) label on Apr 7, 2025
@yeahdongcn (Contributor) commented Apr 7, 2025

I was trying to reproduce your setup locally, and here’s what I did:

cd /
git clone https://huggingface.co/ngxson/TEST-Tiny-Llama4
cd /ws/llama.cpp
gh pr checkout 12791
# build
...
python convert_hf_to_gguf.py /TEST-Tiny-Llama4/
./build/bin/llama-cli -m /TEST-Tiny-Llama4-4x301M-F16.gguf -ngl 999

The model metadata seems to load correctly:

llama_model_loader: loaded meta data with 32 key-value pairs and 70 tensors from ./tinyllama4.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama4
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = TEST Tiny Llama4
llama_model_loader: - kv   3:                         general.size_label str              = 4x301M
llama_model_loader: - kv   4:                         llama4.block_count u32              = 6
llama_model_loader: - kv   5:                      llama4.context_length u32              = 8192
llama_model_loader: - kv   6:                    llama4.embedding_length u32              = 512
llama_model_loader: - kv   7:                 llama4.feed_forward_length u32              = 16384
...

However, it fails during tensor loading with the following error:

load_tensors: loading model tensors, this can take a while... (mmap = true)
llama_model_load: error loading model: missing tensor 'blk.0.ffn_gate_inp.weight'
llama_model_load_from_file_impl: failed to load model
common_init_from_params: failed to load model './tinyllama4.gguf'
main: error: unable to load model

Am I missing a step, or is there something wrong with the converted model?

@ngxson (Collaborator, Author) commented Apr 7, 2025

@yeahdongcn the interleave_moe_layer_step was hard-coded for Scout; I updated it to read from the GGUF in the last commit.

Please note that I still don't get a 1-to-1 logits match with transformers; I think my impl is still missing some small things (not the attn_scale, since for n_ctx < 8192 it's always 1.0).

@ngxson (Collaborator, Author) commented Apr 7, 2025

Ok great news, I achieved logits matching with the transformers output (diff within an acceptable range):

Transformers:

l_out 0 torch.Size([1, 2, 512]) 
 tensor([[[ 0.4376, -0.2017,  1.2169,  ..., -1.2795, -1.6025,  0.7410],
         [ 1.1159, -0.9820,  2.1374,  ..., -0.7017,  0.5766,  1.0569]]])
l_out 1 torch.Size([1, 2, 512]) 
 tensor([[[ 1.0468, -0.6761,  0.4704,  ..., -0.0632, -2.9443,  1.8128],
         [ 0.1152, -1.7986,  0.4008,  ...,  0.3008, -0.8128,  2.2085]]])
l_out 4 torch.Size([1, 2, 512]) 
 tensor([[[ 2.4727,  1.1518,  1.3572,  ..., -3.5305, -0.9363,  1.3458],
         [ 0.7240,  1.4138,  2.3581,  ..., -1.5514,  0.6930,  1.9394]]])
output_norm torch.Size([1, 2, 512])  (only showing the last output row here)
         [ 0.0821,  0.1869,  0.2380,  ..., -0.0081, -0.0036,  0.1431]]])

llama.cpp:

ggml_debug:                  l_out-0 = (f32)        ADD(ffn_out-0{512, 2, 1, 1}, ffn_inp-0{512, 2, 1, 1}}) = {512, 2, 1, 1}

[
 [      0.4376,      -0.2017,       1.2169, ...,      -1.2800,      -1.6023,       0.7408],
 [      1.1159,      -0.9818,       2.1377, ...,      -0.7020,       0.5768,       1.0564],
],
ggml_debug:                  l_out-1 = (f32)        ADD(ffn_moe_out_merged-1{512, 2, 1, 1}, ffn_inp-1{512, 2, 1, 1}}) = {512, 2, 1, 1}
[
 [
  [      1.1061,      -0.6223,       0.4975, ...,      -0.0365,      -2.9750,       1.7490],
  [      0.2483,      -1.7931,       0.3854, ...,       0.2194,      -0.8326,       2.2311],
 ],
]
ggml_debug:                  l_out-4 = (f32)        ADD(ffn_out-4{512, 2, 1, 1}, ffn_inp-4{512, 2, 1, 1}}) = {512, 2, 1, 1}
[
 [
  [      2.5240,       1.1264,       1.3811, ...,      -3.7016,      -0.9681,       1.2990],
  [      0.7606,       1.2619,       2.2235, ...,      -1.6872,       0.7537,       2.0083],
 ],
]
ggml_debug:              result_norm = (f32)        MUL(norm{512, 1, 1, 1}, output_norm.weight{512, 1, 1, 1}}) = {512, 1, 1, 1}
[
 [
  [      0.0858,       0.1796,       0.2276, ...,      -0.0093,      -0.0037,       0.1489],
 ],
]

@ngxson (Collaborator, Author) commented Apr 7, 2025

Our Mac Studio is literally commuting by train to the office, so unfortunately I can't test with the big model rn..

@ggerganov (Member) commented:

I can test the Scout - anything specific you want to check?

@ngxson (Collaborator, Author) commented Apr 7, 2025

Yes, can you please see if it counts the number of R's in strawberry or raspberry correctly now? Thanks!

I also want to see if it can generate past 4k tokens. The current implementation allows generating up to an 8k context; after that we will need the chunked mask and attn scale.

@ngxson (Collaborator, Author) commented Apr 7, 2025

Ok, I found why the activations in my comment above don't match. Turns out I forgot to apply the MoE router weights to the gate. It should match almost 100% now (minus some floating point inaccuracy):

ggml_debug:                  l_out-0 = (f32)        ADD(ffn_out-0{512, 2, 1, 1}, ffn_inp-0{512, 2, 1, 1}}) = {512, 2, 1, 1}

[
 [
  [      0.4376,      -0.2017,       1.2169, ...,      -1.2800,      -1.6023,       0.7408],
  [      1.1159,      -0.9818,       2.1377, ...,      -0.7020,       0.5768,       1.0564],
 ],
]
ggml_debug:                  l_out-1 = (f32)        ADD(ffn_moe_out_merged-1{512, 2, 1, 1}, ffn_inp-1{512, 2, 1, 1}}) = {512, 2, 1, 1}
[
 [
  [      1.0473,      -0.6755,       0.4704, ...,      -0.0631,      -2.9447,       1.8131],
  [      0.1157,      -1.7980,       0.4012, ...,       0.3008,      -0.8128,       2.2083],
 ],
]
ggml_debug:                  l_out-4 = (f32)        ADD(ffn_out-4{512, 2, 1, 1}, ffn_inp-4{512, 2, 1, 1}}) = {512, 2, 1, 1}
[
 [
  [      2.4755,       1.1525,       1.3574, ...,      -3.5293,      -0.9357,       1.3466],
  [      0.7257,       1.4144,       2.3586, ...,      -1.5513,       0.6944,       1.9391],
 ],
]
ggml_debug:              result_norm = (f32)        MUL(norm{512, 1, 1, 1}, output_norm.weight{512, 1, 1, 1}}) = {512, 1, 1, 1}
[
 [
  [      0.0822,       0.1869,       0.2380, ...,      -0.0081,      -0.0036,       0.1431],
 ],
]

Forgot to mention, here is the command to generate the activations:

llama-eval-callback -m ../models/TEST-Tiny-Llama4/model.gguf -n 1 -p "Hi" --top-k 1 > eval.log

On transformers:

# load the HF model and run a single greedy decoding step to compare with llama.cpp
from transformers import Llama4ForConditionalGeneration, AutoProcessor, GenerationConfig

model = Llama4ForConditionalGeneration.from_pretrained("/Users/ngxson/work/models/TEST-Tiny-Llama4")
processor = AutoProcessor.from_pretrained("/Users/ngxson/work/models/TEST-Tiny-Llama4")

print(model)
input_ids = processor(text="Hi", return_tensors="pt").input_ids

print(input_ids)
output = model.generate(input_ids=input_ids, max_new_tokens=1, generation_config=GenerationConfig(do_sample=False))

print(output)
print(processor.decode(output[0], skip_special_tokens=True))

Comment on lines 878 to 883
if (weight_before_ffn) {
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
cur = ggml_mul(ctx0, repeated, weights);
cb(cur, "ffn_moe_weighted", il);
}
@ngxson (Collaborator, Author):

Not sure if this is the best way to repeat a dim. What I want to do here is to turn from:

[n_embd, 1, n_tokens]

to:

[n_embd, n_expert_used, n_tokens]

Member:

In general, usage of ggml_new_tensor_xd should be avoided as much as possible because it can increase the required compute graph memory (i.e. allocate separate tensors for each layer, without reusing the memory). However, in this specific case, I can't think of an alternative way to do this.

Member:

Can't you rely on broadcasting support in ggml_mul to do this?

Member:

It seems to require:

cur = ggml_mul(ctx0, [n_embd, 1, n_tokens], [1, n_expert_used, n_tokens]);

Not sure if we support this type of broadcast.

@ngxson (Collaborator, Author):

I tried ggml_mul directly but it failed the can_repeat check

@slaren (Member) commented Apr 7, 2025

Um yeah we don't support broadcasting in both tensors at the same time. In that case we should add a ggml_repeat_4d that just takes the desired dimension directly, instead of a tensor, to avoid creating a new tensor.

@ngxson (Collaborator, Author):

An alternative approach for this particular case: we are currently doing:

cur = repeat(cur)
cur = mul(cur, weights)
gate = mm(cur, gate_w)
gate = silu(gate)
up = mm(cur, up_w)

But this also works; it basically replaces one repeat with two muls:

gate = mul(mm(cur, gate_w), weights)
gate = silu(gate)
up = mul(mm(cur, up_w), weights)

But I think two ggml_mul calls should be slower than one repeat and one mul.
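
For what it's worth, a small numpy check (my own illustration, not ggml code) of why both orderings give the same result: the router weight is a single scalar per (token, expert), so it commutes with the linear gate/up projections.

import numpy as np

n_embd, n_ff, n_expert_used, n_tokens = 8, 16, 2, 3
rng = np.random.default_rng(0)

x      = rng.standard_normal((n_tokens, n_embd))              # cur: one row per token
w      = rng.standard_normal((n_tokens, n_expert_used))       # router weights
gate_w = rng.standard_normal((n_expert_used, n_ff, n_embd))   # per-expert gate matrices

# option 1: weight the input before the expert matmul (repeat + mul, weight_before_ffn)
out1 = np.einsum('eij,te,tj->tei', gate_w, w, x)
# option 2: do the matmul first, then weight the output
out2 = np.einsum('eij,tj->tei', gate_w, x) * w[:, :, None]

assert np.allclose(out1, out2)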

@ngxson (Collaborator, Author) commented Apr 7, 2025

Ok, both attn temp and attn chunked mask should work now. Here is the visualization for a chunk size of 5 tokens (a small sketch of the rule follows the visualization):

x...................
xx..................
xxx.................
xxxx................
xxxxx...............
.....x..............
.....xx.............
.....xxx............
.....xxxx...........
.....xxxxx..........
..........x.........
..........xx........
..........xxx.......
..........xxxx......
..........xxxxx.....
...............x....
...............xx...
...............xxx..
...............xxxx.
...............xxxxx
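
A minimal Python sketch of the rule this pattern encodes (my illustration, not the llama.cpp code): token i can attend to token j only when j <= i and both positions fall in the same chunk.

chunk_size, n_tokens = 5, 20

for i in range(n_tokens):
    row = ""
    for j in range(n_tokens):
        same_chunk = (i // chunk_size) == (j // chunk_size)
        row += "x" if (j <= i and same_chunk) else "."
    print(row)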

@ngxson marked this pull request as ready for review on April 7, 2025 at 17:34
Comment on lines 4279 to 4288
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
if (arch == LLM_ARCH_LLAMA4) {
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
ggml_set_input(inp_attn_scale);
inp->attn_scale = inp_attn_scale;
res->add_input(std::move(inp));
}

Member:

Move this into a helper function for consistency:

Suggested change
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
if (arch == LLM_ARCH_LLAMA4) {
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
ggml_set_input(inp_attn_scale);
inp->attn_scale = inp_attn_scale;
res->add_input(std::move(inp));
}
// temperature tuning
ggml_tensor * inp_attn_scale = build_inp_attn_scale();

@ngxson (Collaborator, Author):

Fixed in af1968c

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
// hack: we use SWA to store the chunked attn mask
Member:

Yes, SWA -> AUX makes sense.

Btw, we should soon implement actual SWA / chunked attention that uses less memory. It shouldn't be a big change and will improve memory usage significantly for such models.

@ngxson (Collaborator, Author):

The rename requires quite a few more changes than I expected, so I think I'll do it in another PR to test it more thoroughly. Here I'll only edit my comment to make it clearer that I'm using the "swa" variable to store the chunked mask.

@ngxson (Collaborator, Author):

I'll do this after having the automated logits-matching test. The problem is that renaming n_swa to n_pattern_aux makes the existing check if (n_swa) then use SWA invalid; it would have to become if (n_pattern_aux && is_swa) then use SWA.

I think it's better to add an enum called llama_mask_aux with 3 values: none, swa, chunked, so that the code becomes clearer.

@justinjja2 commented:

> how many r's in the word raspberry

Don't know if this is helpful, but both Llama 4 models get the same answer on Lambda Cloud.

llama-4-maverick-17b-128e-instruct-fp8:
There are 2 R's in the word "raspberry".

llama-4-scout-17b-16e-instruct:
There are 2 R's in the word "raspberry".

@bartowski1182 (Contributor) commented Apr 7, 2025

Not sure if it's relevant to this PR, but for some reason conversion is taking way longer with Scout than with other ~70-111B models.

It's been going for over an hour loading tensors and nothing has been written to disk yet; a 70B model took 10 min to convert, a 111B model took 15 min.

Converting to bf16

Did you see something similar, @justinjja2?

(It's definitely possible that it's because my CPU is busy with other stuff, but I don't think that usually makes it this much slower; I can test again later with nothing else running.)

@ngxson (Collaborator, Author) commented Apr 7, 2025

@bartowski1182 It (this model) typically takes 20-30 min when I convert to Q8_0. Yes, it's slow because the original model has the MoE gate and up tensors "fused" into one; the conversion script needs to split and then transpose them, so you should see a CPU spike for each layer converted.

@ngxson (Collaborator, Author) commented Apr 7, 2025

Alright! Since my last commits are mostly comment changes, there is no need to re-review them.

Will merge this PR once the CI is green 🚀 (except for the build-linux-cross which is not relevant to this PR)

@ngxson changed the title from "llama : Support llama 4 text-only (WIP)" to "llama : Support llama 4 text-only" on Apr 7, 2025
@nikhil-arm commented:

> @bartowski1182 It (this model) typically takes 20-30 min when I convert to Q8_0. Yes, it's slow because the original model has the MoE gate and up tensors "fused" into one; the conversion script needs to split and then transpose them, so you should see a CPU spike for each layer converted.

Would it be possible for you to upload the GGUF F16 model?

@ngxson merged commit 1466621 into ggml-org:master on Apr 7, 2025 (51 of 54 checks passed)
@ddh0 (Contributor) commented Apr 7, 2025

I'm converting to GGUF now and noticing the same as @bartowski1182; it seems to take an unusually long time to convert, probably due to some of the tensors being 16x the expected size (?)

@ngxson (Collaborator, Author) commented Apr 7, 2025

I'm uploading an F16 on my side just in case; it should be available in ~10 min: https://huggingface.co/ngxson/Llama-4-Scout-17B-16E-Instruct-GGUF

@ngxson (Collaborator, Author) commented Apr 7, 2025

Btw, if someone has 2TB of disk space and >400GB of RAM, feel free to test Llama 4 Maverick and share your results here.

# split the fused gate_up_proj tensor into separate gate and up tensors
name_up = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
dim_half = data_torch.shape[-1] // 2
# after the transpose, the fused gate+up dimension sits at dim -2 and is split in half
gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
@compilade (Collaborator) commented Apr 7, 2025

Lazy evaluation doesn't support splitting yet, so this will always eagerly evaluate (and so it will take more RAM than ideal during conversion).

This may or may not explain the conversion slowness others are noticing.

This can be fixed in gguf-py/gguf/lazy.py by handling tuples of tensors as output values. I have the necessary changes somewhere; I'll open a PR once I find them.

(EDIT: see #12809)

@ngxson (Collaborator, Author) commented Apr 7, 2025

Off-topic question: is it possible to somehow extend LazyTorchTensor to load a tensor remotely? FYI, the Hugging Face backend supports byte ranges, so an idea could be to read the tensors one by one entirely in RAM, without having to download them to disk.

@compilade (Collaborator) commented Apr 7, 2025

> Off-topic question: is it possible to somehow extend LazyTorchTensor to load a tensor remotely? FYI, the Hugging Face backend supports byte ranges, so an idea could be to read the tensors one by one entirely in RAM, without having to download them to disk.

Huh, I never thought of it, but yes, technically this should totally be possible. Lazy tensors only need the name, shape, and type of the tensor for the fake tensors, and then a way to turn the fake tensors into real tensors.

The hardest part of this wouldn't necessarily be the lazy tensors, but how the remote paths would be specified, how that would interact with the default output path and default name of the output file, how the tensors would be enumerated, and how the config file and the tokenizer would be fetched.

There's a lot of tokenizer-related code which assumes local files.
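
As a rough illustration of the byte-range idea, here is a minimal sketch (my own assumption of how it could look, not code from this PR; the URL and tensor name are placeholders):

import json
import struct

import numpy as np
import requests

url = "https://huggingface.co/some-org/some-model/resolve/main/model-00001-of-00055.safetensors"  # placeholder

def fetch_range(start: int, end: int) -> bytes:
    # end is inclusive, per the HTTP Range header convention
    r = requests.get(url, headers={"Range": f"bytes={start}-{end}"})
    r.raise_for_status()
    return r.content

# safetensors layout: 8-byte little-endian header length, then a JSON header, then the data
header_len = struct.unpack("<Q", fetch_range(0, 7))[0]
header = json.loads(fetch_range(8, 8 + header_len - 1))

name = "model.layers.0.self_attn.q_proj.weight"  # placeholder tensor name
begin, end = header[name]["data_offsets"]        # offsets relative to the end of the header
data = fetch_range(8 + header_len + begin, 8 + header_len + end - 1)

# dtype mapping kept minimal for the sketch
dtype = {"F32": np.float32, "F16": np.float16}[header[name]["dtype"]]
tensor = np.frombuffer(data, dtype=dtype).reshape(header[name]["shape"])
print(tensor.shape)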

@ngxson (Collaborator, Author):

I think we can rely on AutoTokenizer.from_pretrained, which will download the tokenizer files to a temporary directory. Will have a look at how it works.

We can alternatively rely on huggingface_hub.download(), which accepts a pattern of file names to download (so, for example, we can disallow downloading safetensors).

In my case, loading safetensors remotely could be very useful. I couldn't test the 409B Maverick model as it requires 1.5TB in total to store both the HF model + GGUF, but HF Spaces only provides at most 1TB of storage.

Contributor:

@ngxson I just tested Maverick; sadly it doesn't work on BF16 conversion. It's because of "interleave_moe_layer_step": 2, so every 2nd / odd layer is MoE, whilst the rest are FFN.

Error:

INFO:hf-to-gguf:gguf: loading model part 'model-00002-of-00055.safetensors'
INFO:hf-to-gguf:blk.1.ffn_down_exps.weight,   torch.bfloat16 --> BF16, shape = {8192, 5120, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00003-of-00050.safetensors'
INFO:hf-to-gguf:blk.1.ffn_down_exps.weight,   torch.bfloat16 --> BF16, shape = {8192, 5120, 16}
Traceback (most recent call last):
ValueError: Duplicated tensor name 'blk.1.ffn_down_exps.weight'

Contributor:

I just downloaded unsloth/Llama-4-Maverick-17B-128E-Instruct from HF. Everything looks good so far. (On a machine with 2 TB RAM and 13 TB SSD.)

It took me about 50 minutes to reach this point.

INFO:hf-to-gguf:gguf: loading model part 'model-00030-of-00055.safetensors'
INFO:hf-to-gguf:blk.27.ffn_gate_exps.weight,  torch.bfloat16 --> F16, shape = {5120, 8192, 128}
INFO:hf-to-gguf:blk.27.ffn_up_exps.weight,    torch.bfloat16 --> F16, shape = {5120, 8192, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00031-of-00055.safetensors'
INFO:hf-to-gguf:blk.27.ffn_down_exps.weight,  torch.bfloat16 --> F16, shape = {8192, 5120, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00032-of-00055.safetensors'
INFO:hf-to-gguf:blk.29.ffn_gate_exps.weight,  torch.bfloat16 --> F16, shape = {5120, 8192, 128}

@ngxson (Collaborator, Author):

Leaving this comment here for visibility: we discussed via DM, and it turns out Daniel was using the wrong directory 😂

+1 reason to support converting HF --> gguf without downloading to disk

Member:

> +1 reason to support converting HF --> gguf without downloading to disk

This feature will be 🔥

Reply:

Got past the ValueError: Duplicated tensor name 'blk.1.ffn_down_exps.weight' error, thanks!:

INFO:gguf.gguf_writer:Writing the following files:
INFO:gguf.gguf_writer://Volumes/storage 1/models/LLaMA-4-Maverick-17B.gguf: n_tensors = 531, total_size = 801.5G
Writing:   0%|                                                                                                     | 0.00/801G [00:00<?, ?byte/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Writing:   1%|█▏    

Now I wait!

@yeahdongcn (Contributor) commented Apr 9, 2025

Update: More Information on Llama 4 Maverick

🛠️ Model Conversion

root@bb22ebf4525a:/ws# time python convert_hf_to_gguf.py model_path
...
INFO:hf-to-gguf:Set model quantization version  
INFO:gguf.gguf_writer:Writing the following files:  
INFO:gguf.gguf_writer:/model_path/840ed22d9bc7731246bc119cca026a48a0ff8ec6-128x17B-840ed22d9bc7731246bc119cca026a48a0ff8ec6-F16.gguf: n_tensors = 531, total_size = 801.5G  
...
real    302m33.600s  
user    285m44.179s  
sys     113m0.490s  

🔧 Quantization

root@bb22ebf4525a:/ws# time ./build/bin/llama-quantize model_path llama4_maverick_q4_k_m.gguf Q4_K_M
...
llama_model_quantize_impl: model size  = 764328.14 MB  
llama_model_quantize_impl: quant size  = 231508.31 MB  

main: quantize time = 1372408.18 ms  
main:    total time = 1372408.18 ms  
./build/bin/llama-quantize   Q4_K_M  55951.43s user 4741.92s system 4421% cpu 22:52.55 total  

🧪 Tested with MUSA backend:

> Hi    
Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?

> How many 'R' in strawberry?
There are 2 'R's in the word "strawberry".

> 

@yeahdongcn (Contributor) commented:

Tested with the MUSA backend — the model you provided loads successfully and runs inference correctly on MTGPU. 👍

@csabakecskemeti (Contributor) commented Apr 8, 2025

Is there any resolution on why the conversion is slow?

Maybe related:
I found that the convert_hf_to_gguf step loads the whole model into memory!
I tried to convert Scout on a 128GB machine and it errored out; my other machine with 256GB managed it. I wasn't sure what caused it, but now when I tried to convert Maverick, the RAM filled up and it failed.
(haven't looked at the code yet)

@csabakecskemeti (Contributor) commented:

Just saw @compilade has answered this already:

> Lazy evaluation doesn't support splitting yet, so this will always eagerly evaluate (and so it will take more RAM than ideal during conversion).

tastelikefeet added a commit to tastelikefeet/llama.cpp that referenced this pull request Apr 10, 2025
* master: (123 commits)
  cuda : add f32 to bf16 copy op (ggml-org#12806)
  llava: improve clip_ctx destructor to not memleak load_image_size (ggml-org#12834)
  llama : fix FA when KV cache is not used (i.e. embeddings) (ggml-org#12825)
  server : fix thread.join() on exit (ggml-org#12831)
  llava: add more helper functions to check projector types in clip context (ggml-org#12824)
  arg : Including limits file on AIX (ggml-org#12822)
  server : webui : Improve Chat Input with Auto-Sizing Textarea (ggml-org#12785)
  Revert "sycl:remove redundant memcopy in function ggml_backend_sycl_buffer_set_tensor" (ggml-org#12812)
  gguf-py : support lazy tensor splitting (ggml-org#12809)
  llama : Support llama 4 text-only (ggml-org#12791)
  opencl: better identify Adreno GPU (ggml-org#12760)
  hellaswag: display estimated score confidence interval (ggml-org#12797)
  cuda : fix HIP and MUSA BF16 (#0)
  sync : ggml
  ggml : simplify Arm fp16 CPU logic (ggml/1177)
  CUDA: don't convert BF16 weights to FP32 (ggml/1174)
  cpu: move all the operators into a separate c++ file (except mul_mat) (ggml/1167)
  sycl: remove redundant memcopy in function ggml_backend_sycl_buffer_set_tensor (ggml-org#12734)
  ci : no curl on ggml-ci (ggml-org#12796)
  cmake : enable curl by default (ggml-org#12761)
  ...

# Conflicts:
#	common/arg.cpp
#	common/common.cpp
#	common/common.h
colout pushed a commit to colout/llama.cpp that referenced this pull request Apr 21, 2025
* llama4 conversion

* initial support, no chat template

* clean up a bit

* fix tokenizer conversion

* correct hparams

* try this

* fix shexp

* ffn_inp_normed

* chat template

* clean up model conversion

* add_bos

* add scale_before_ffn

* fix order

* weight_before_ffn

* llm_graph_input_attn_temp

* add chunk attn mask

* build_inp_attn_scale()

* add comment about ggml_repeat

* clarify comments

* fix build
Labels: python (python script changes)
Development

Successfully merging this pull request may close these issues.

Feature Request: llama 4