@@ -51,37 +51,40 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
51
51
return converted_state_dict
52
52
53
53
54
- def main ():
55
- parser = argparse .ArgumentParser (
56
- description = "Convert Phi-4-mini weights to Meta format."
57
- )
58
- parser .add_argument (
59
- "input_dir" ,
60
- type = str ,
61
- help = "Path to directory containing checkpoint files" ,
62
- )
63
- parser .add_argument ("output" , type = str , help = "Path to the output checkpoint" )
64
-
65
- args = parser .parse_args ()
66
-
54
def convert_weights(input_dir: str, output_file: str) -> None:
    """Convert a sharded Phi-4-mini checkpoint and save it as a single file.

    Loads the two ``.safetensors`` shards found in ``input_dir`` through
    ``FullModelHFCheckpointer``, remaps the loaded ``"model"`` state dict with
    ``phi_4_tune_to_meta``, and writes the result with ``torch.save``.

    Args:
        input_dir: Directory containing the checkpoint shard files.
        output_file: Path the converted checkpoint is written to.
    """
    # Don't necessarily need to use TorchTune checkpointer, can just aggregate
    # checkpoint files by ourselves.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=input_dir,
        checkpoint_files=[
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ],
        output_dir=".",
        # NOTE: the value is looked up by name on the torchtune side, so it
        # must be the bare identifier with no surrounding whitespace.
        model_type="PHI4",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = phi_4_tune_to_meta(sd["model"])

    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")
82
73
83
- torch .save (sd , args .output )
84
- print (f"Checkpoint saved to { args .output } " )
74
+
75
def main():
    """Parse the command-line arguments and run the weight conversion.

    Expects two positional arguments, ``input_dir`` and ``output``, and
    forwards them to ``convert_weights``.
    """
    cli = argparse.ArgumentParser(
        description="Convert Phi-4-mini weights to Meta format."
    )

    # Positional arguments, registered in the order they appear on the CLI.
    positionals = (
        ("input_dir", "Path to directory containing checkpoint files"),
        ("output", "Path to the output checkpoint"),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, type=str, help=arg_help)

    parsed = cli.parse_args()
    convert_weights(parsed.input_dir, parsed.output)
85
88
86
89
87
90
if __name__ == "__main__" :
0 commit comments