Skip to content

Commit e38c4ec

Browse files
committed
Download checkpoint from HuggingFace
1 parent e58ce7e commit e38c4ec

File tree

6 files changed

+96
-37
lines changed

6 files changed

+96
-37
lines changed

examples/models/llama/export_llama_lib.py

+50
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@
9898
"phi_4_mini",
9999
]
100100
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
101+
HUGGING_FACE_REPO_IDS = {
102+
"qwen2_5": "Qwen/Qwen2.5-1.5B",
103+
"phi_4_mini": "microsoft/Phi-4-mini-instruct",
104+
}
101105

102106

103107
class WeightType(Enum):
@@ -519,7 +523,53 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
519523
return return_val
520524

521525

526+
def download_and_convert_hf_checkpoint(modelname: str) -> str:
    """
    Download a HuggingFace checkpoint for ``modelname`` and convert it to
    Meta format, caching the converted file under ``~/.cache/meta_checkpoints``.

    Args:
        modelname: Key into ``HUGGING_FACE_REPO_IDS`` identifying the model.

    Returns:
        Path (as a string) to the converted ``.pth`` checkpoint.

    Raises:
        KeyError: If ``modelname`` is not in ``HUGGING_FACE_REPO_IDS``.
        ValueError: If no weight converter is registered for ``modelname``.
    """
    # Build cache path for converted checkpoints.
    cache_dir = Path.home() / ".cache" / "meta_checkpoints"
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Use the repo name to name the converted file ("/" is not valid in
    # a file name).
    repo_id = HUGGING_FACE_REPO_IDS[modelname]
    model_name = repo_id.replace("/", "_")
    converted_path = cache_dir / f"{model_name}.pth"

    if converted_path.exists():
        print(f"✔ Using cached converted model: {converted_path}")
        # BUG FIX: annotated return type is str, but a Path was returned;
        # convert explicitly so callers get what the signature promises.
        return str(converted_path)

    # 1. Download weights from Hugging Face.
    print("⬇ Downloading and converting checkpoint...")
    from huggingface_hub import snapshot_download

    checkpoint_path = snapshot_download(
        repo_id=repo_id,
    )

    # 2. Convert weights to Meta format. Import lazily so the heavy model
    # packages are only loaded for the model actually requested.
    if modelname == "qwen2_5":
        from executorch.examples.models.qwen2_5 import convert_weights
    elif modelname == "phi_4_mini":
        from executorch.examples.models.phi_4_mini import convert_weights
    else:
        # BUG FIX: a dead "smollm2" branch used to `pass` here and then
        # return a path to a file that was never written; fail loudly for
        # any model without a registered converter instead.
        raise ValueError(f"No weight converter registered for {modelname!r}")
    convert_weights(checkpoint_path, converted_path)

    return str(converted_path)
567+
568+
522569
def export_llama(args) -> str:
570+
if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
571+
args.checkpoint = download_and_convert_hf_checkpoint(args.model)
572+
523573
if args.profile_path is not None:
524574
try:
525575
from executorch.util.python_profiler import CProfilerFlameGraph

examples/models/llama/install_requirements.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Install tokenizers for hf .json tokenizer.
1111
# Install snakeviz for cProfile flamegraph
1212
# Install lm-eval for Model Evaluation with lm-evaluation-harness.
13-
pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
13+
pip install tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
1414

1515
# Call the install helper for further setup
1616
python examples/models/llama/install_requirement_helper.py

examples/models/phi_4_mini/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# LICENSE file in the root directory of this source tree.
33

44
from executorch.examples.models.llama.model import Llama2Model
5+
from executorch.examples.models.phi_4_mini.convert_weights import convert_weights
56

67

78
class Phi4MiniModel(Llama2Model):
@@ -11,4 +12,5 @@ def __init__(self, **kwargs):
1112

1213
__all__ = [
1314
"Phi4MiniModel",
15+
"convert_weights",
1416
]

examples/models/phi_4_mini/convert_weights.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -51,37 +51,40 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
5151
return converted_state_dict
5252

5353

54-
def main():
55-
parser = argparse.ArgumentParser(
56-
description="Convert Phi-4-mini weights to Meta format."
57-
)
58-
parser.add_argument(
59-
"input_dir",
60-
type=str,
61-
help="Path to directory containing checkpoint files",
62-
)
63-
parser.add_argument("output", type=str, help="Path to the output checkpoint")
64-
65-
args = parser.parse_args()
66-
54+
def convert_weights(input_dir: str, output_file: str) -> None:
    """
    Convert a Phi-4-mini HuggingFace checkpoint to Meta format.

    Args:
        input_dir: Directory containing the HF safetensors shards.
        output_file: Destination path for the converted ``.pth`` file.
    """
    # Don't necessarily need to use TorchTune checkpointer, can just
    # aggregate checkpoint files by ourselves.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=input_dir,
        checkpoint_files=[
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ],
        output_dir=".",
        model_type="PHI4",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()
    print("Converting checkpoint...")
    # Remap TorchTune parameter names to the Meta/llama naming scheme.
    sd = phi_4_tune_to_meta(sd["model"])
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    """CLI entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert Phi-4-mini weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)
8588

8689

8790
if __name__ == "__main__":

examples/models/qwen2_5/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# This source code is licensed under the BSD-style license found in the
22
# LICENSE file in the root directory of this source tree.
33

4-
from executorch.example.models.llama.model import Llama2Model
4+
from executorch.examples.models.llama.model import Llama2Model
5+
from executorch.examples.models.qwen2_5.convert_weights import convert_weights
56

67

78
class Qwen2_5Model(Llama2Model):
@@ -11,4 +12,5 @@ def __init__(self, **kwargs):
1112

1213
__all__ = [
1314
"Qwen2_5Model",
15+
"convert_weights",
1416
]

examples/models/qwen2_5/convert_weights.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -53,35 +53,37 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
5353
return converted_state_dict
5454

5555

56-
def main():
57-
parser = argparse.ArgumentParser(
58-
description="Convert Qwen2 weights to Meta format."
59-
)
60-
parser.add_argument(
61-
"input_dir",
62-
type=str,
63-
help="Path to directory containing checkpoint files",
64-
)
65-
parser.add_argument("output", type=str, help="Path to the output checkpoint")
66-
67-
args = parser.parse_args()
68-
56+
def convert_weights(input_dir: str, output_file: str) -> None:
    """
    Convert a Qwen2.5 HuggingFace checkpoint to Meta format.

    Args:
        input_dir: Directory containing the HF ``model.safetensors`` file.
        output_file: Destination path for the converted ``.pth`` file.
    """
    # Don't necessarily need to use TorchTune checkpointer, can just
    # aggregate checkpoint files by ourselves.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="QWEN2",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()
    print("Converting checkpoint...")
    # Remap TorchTune parameter names to the Meta/llama naming scheme.
    sd = qwen_2_tune_to_meta(sd["model"])
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    """CLI entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert Qwen2 weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)
8587

8688

8789
if __name__ == "__main__":

0 commit comments

Comments
 (0)