
Add qwen 2.5 #8355

Merged · 14 commits · Feb 24, 2025
13 changes: 10 additions & 3 deletions examples/models/llama/attention.py
@@ -175,9 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.max_batch_size = args.max_batch_size
self.max_context_len = args.max_context_len
self.dim = args.dim
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.attention_qkv_bias = args.attention_qkv_bias
self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wk = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wv = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

self.layer_id = layer_id
1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
@@ -21,6 +21,7 @@ class ModelArgs:
num_experts: int = 8 # Number of experts
num_activated_experts: int = 2 # Number of experts to activate
attention_type: str = "mha" # Attention type, registered in attention.py
attention_qkv_bias: bool = False
use_kv_cache: bool = False # Use key/value cache
use_sdpa_with_kv_cache_op: bool = (
False # Use custom sdpa op that updates kv cache in-place
11 changes: 7 additions & 4 deletions examples/models/llama/rope.py
@@ -114,6 +114,7 @@ def apply_rotary_emb_to_k(
return xk_out.type_as(xk)


# Wrap apply_rotary_emb in a module so that it can be replaced via module swap.
class RotaryEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -213,14 +214,20 @@ class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params

# Choose the appropriate RoPE implementation
if self.params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.precompute_freqs_cis = partial(
precompute_freqs_cis,
use_scaled=self.params.use_scaled_rope,
scale_factor=self.params.rope_scale_factor,
)
self.apply_rotary_emb = RotaryEmbedding()

# Precompute frequencies
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
(
@@ -232,10 +239,6 @@ def __init__(self, params: ModelArgs):
)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
if self.params.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = RotaryEmbedding()

def forward(
self,
7 changes: 4 additions & 3 deletions examples/models/llama/static_attention.py
@@ -207,22 +207,23 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
self.dim = config.dim
self.head_dim = config.head_dim
self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
self.attention_qkv_bias = config.attention_qkv_bias

self.wqs = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
for _ in range(self.n_heads)
]
)
self.wks = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
for _ in range(self.n_kv_heads)
]
)
self.wvs = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
for _ in range(self.n_kv_heads)
]
)
14 changes: 14 additions & 0 deletions examples/models/qwen2_5/1_5b_config.json
@@ -0,0 +1,14 @@
{
"dim": 1536,
"ffn_dim_multiplier": 1,
"hidden_dim": 8960,
"n_heads": 12,
"n_kv_heads": 2,
"n_layers": 28,
"norm_eps": 1e-06,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 151936,
"use_hf_rope": true,
"attention_qkv_bias": true
}
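These fields mirror the `ModelArgs` dataclass in `examples/models/llama/model_args.py`, so the new `attention_qkv_bias` flag flows directly into the Q/K/V projections changed above. Below is a minimal sketch of how such a params file could be consumed, assuming `ModelArgs` accepts these fields as keyword arguments (the export flow handles this internally; this snippet is illustrative only and not part of the PR):

```python
import json

from examples.models.llama.model_args import ModelArgs

# Load the Qwen 2.5 1.5B params and build the model args from them.
with open("examples/models/qwen2_5/1_5b_config.json") as f:
    params = json.load(f)

args = ModelArgs(**params)
print(args.attention_qkv_bias)  # True -> wq/wk/wv are created with bias terms
```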
63 changes: 63 additions & 0 deletions examples/models/qwen2_5/README.md
@@ -0,0 +1,63 @@
## Summary
Qwen 2.5 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. At the moment, only the 1.5B variant is supported, with plans to add the 0.5B and 3B versions in the future.

## Instructions

Qwen 2.5 uses the same example code as Llama; only the checkpoint, model params, and tokenizer differ. Please see the [Llama README page](../llama/README.md) for details.

All commands for exporting and running Llama on the various backends also apply to Qwen 2.5 after swapping in the following args:
```
--model qwen2_5
--params examples/models/qwen2_5/1_5b_config.json
--checkpoint <path-to-meta-checkpoint>
```

### Generate the Checkpoint
The original checkpoint can be obtained from HuggingFace:
```
huggingface-cli download Qwen/Qwen2.5-1.5B
```

We then convert it to Meta's checkpoint format:
```
python examples/models/qwen2_5/convert_weights.py <path-to-checkpoint-dir> <output-path>
```
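As a quick sanity check (a sketch only, assuming the checkpoint was saved as `qwen2_5-1_5b.pth`), the converted state dict can be loaded and inspected for the Meta-format keys, including the QKV bias tensors that Qwen 2.5 requires:

```python
import torch

# Load the converted Meta-format checkpoint on CPU and inspect a few keys.
sd = torch.load("qwen2_5-1_5b.pth", map_location="cpu", weights_only=True)
print(sd["layers.0.attention.wq.bias"].shape)  # QKV bias is present for Qwen 2.5
print("output.weight" in sd)                   # True: tied to tok_embeddings.weight
```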

### Example export and run
Here is a basic example of exporting and running Qwen 2.5; please refer to the [Llama README page](../llama/README.md) for more advanced usage.

Export to XNNPACK, no quantization:
```
# No quantization
# Set these paths to point to the downloaded files
QWEN_CHECKPOINT=path/to/checkpoint.pth

python -m examples.models.llama.export_llama \
--model "qwen2_5" \
--checkpoint "${QWEN_CHECKPOINT:?}" \
--params examples/models/qwen2_5/1_5b_config.json \
-kv \
--use_sdpa_with_kv_cache \
-d fp32 \
-X \
--metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \
--output_name="qwen2_5-1_5b.pte" \
--verbose
```

Run using the executor runner:
```
# The C++ executor runner is still a work in progress; the HuggingFace JSON tokenizer needs to be enabled in C++ first.
# In the meantime, you can run the model with the example Python runner via pybindings:

python -m examples.models.llama.runner.native \
  --model qwen2_5 \
  --pte <path-to-pte> \
  -kv \
  --tokenizer <path-to-tokenizer>/tokenizer.json \
  --tokenizer_config <path-to-tokenizer>/tokenizer_config.json \
  --prompt "Who is the founder of Meta?" \
  --params examples/models/qwen2_5/1_5b_config.json \
  --max_len 64 \
  --temperature 0
```
90 changes: 90 additions & 0 deletions examples/models/qwen2_5/convert_weights.py
@@ -0,0 +1,90 @@
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META mapping of Meta weight names to torchtune names, plus additional mappings for the QKV bias weights.
_QWEN_2_FROM_META = {
"tok_embeddings.weight": "tok_embeddings.weight",
"norm.weight": "norm.scale",
"layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
"layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
"layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
"layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
"layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
"layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias",
"layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
"layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
"layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
"layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
"layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
"layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from torchtune's format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()}

for key, value in state_dict.items():
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value

# The 0.5B and 1.5B models tie the token embedding and output embedding weights; see https://github.com./QwenLM/Qwen2.5/issues/733.
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def main():
parser = argparse.ArgumentParser(
description="Convert Qwen2 weights to Meta format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()

# We don't strictly need the torchtune checkpointer; the checkpoint files could also be aggregated manually.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=args.input_dir,
checkpoint_files=["model.safetensors"],
output_dir=".",
model_type="QWEN2",
)

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()

print("Converting checkpoint...")
sd = qwen_2_tune_to_meta(sd["model"])

torch.save(sd, args.output)
print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
main()