 )
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
     CONFIG_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
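For context, the two new imports are file-name constants for PEFT adapter checkpoints. A hedged sketch of what they are expected to resolve to in `transformers.utils` (the exact values are an assumption based on PEFT's checkpoint naming and are not shown in this diff):

```python
# Assumed values of the new constants (not part of this diff); they mirror the
# file names PEFT uses when saving an adapter without and with safetensors.
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
```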
@@ -2177,11 +2179,20 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
+        if (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(best_adapter_model_path)
+            or os.path.exists(best_safe_adapter_model_path)
+        ):
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
             else:
+                has_been_loaded = True
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                         # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2207,10 +2218,10 @@ def _load_best_model(self):
                         self.accelerator, model, self.state.best_model_checkpoint
                     )
                 else:
-                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
-                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
+                    if is_peft_available() and isinstance(model, PeftModel):
+                        # If training a model with PEFT & LoRA, assume that the adapter has been saved properly.
                         if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
+                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                                 model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                 # load_adapter has no return value; adjust this when appropriate.
                                 from torch.nn.modules.module import _IncompatibleKeys
@@ -2219,12 +2230,13 @@ def _load_best_model(self):
                             else:
                                 logger.warning(
                                     "The intermediate checkpoints of PEFT may not be saved correctly, "
-                                    "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
+                                    f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                                     "here are some examples https://github.com./huggingface/peft/issues/96"
                                 )
+                                has_been_loaded = False
                         else:
-                            # We can't do pure 8bit training using transformers.
-                            logger.warning("Could not loading a quantized checkpoint.")
+                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                            has_been_loaded = False
                     else:
                         # We load the model state dict on the CPU to avoid an OOM error.
                         if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
@@ -2236,7 +2248,7 @@ def _load_best_model(self):
                         # workaround for FSDP bug https://github.com./pytorch/pytorch/issues/82963
                         # which takes *args instead of **kwargs
                         load_result = model.load_state_dict(state_dict, False)
-                    if not is_sagemaker_mp_enabled():
+                    if not is_sagemaker_mp_enabled() and has_been_loaded:
                         self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
             load_result = load_sharded_checkpoint(
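Taken together, the change makes `_load_best_model` recognize checkpoints that contain only a PEFT adapter and load them via `PeftModel.load_adapter`, while the new `has_been_loaded` flag suppresses the post-load warning check when nothing was actually loaded. A minimal, self-contained sketch of the same resolution logic follows; it is not the Trainer's implementation, it assumes `peft` and `safetensors` are installed, and the helper `load_best_checkpoint` and the hard-coded file names are illustrative assumptions only.

```python
import os

import torch
from peft import PeftModel
from safetensors.torch import load_file

# Checkpoint file names; the adapter names are assumptions matching the sketch above.
WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_NAME = "model.safetensors"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"


def load_best_checkpoint(model, checkpoint_dir: str) -> bool:
    """Load full weights or a PEFT adapter from `checkpoint_dir`; return True if something was loaded."""
    adapter_paths = [os.path.join(checkpoint_dir, n) for n in (ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME)]
    safe_path = os.path.join(checkpoint_dir, SAFE_WEIGHTS_NAME)
    bin_path = os.path.join(checkpoint_dir, WEIGHTS_NAME)

    if isinstance(model, PeftModel) and any(os.path.exists(p) for p in adapter_paths):
        # Adapter-only checkpoint: restore just the adapter weights and leave the
        # (possibly quantized) base model untouched.
        model.load_adapter(checkpoint_dir, model.active_adapter)
        return True

    if os.path.exists(safe_path):
        state_dict = load_file(safe_path, device="cpu")
    elif os.path.exists(bin_path):
        state_dict = torch.load(bin_path, map_location="cpu")
    else:
        # Mirrors `has_been_loaded = False` in the diff: nothing usable was found,
        # so the caller should warn instead of inspecting a load result.
        return False

    # Load on CPU first to avoid an OOM error, as the Trainer does.
    model.load_state_dict(state_dict, strict=False)
    return True
```

The Trainer keeps all of these branches inline alongside the DeepSpeed, SageMaker MP, and FSDP paths; the sketch only illustrates the order in which the checkpoint files are considered.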