Skip to content

Commit 53e1f5c

Browse files
younesbelkada authored and sgugger committed
[Trainer] Correct behavior of _load_best_model for PEFT models (#24103)
* v1
* some refactor
* add ST format as well
* fix
* add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
1 parent 17db177 commit 53e1f5c

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

src/transformers/trainer.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@
134134
)
135135
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
136136
from .utils import (
137+
ADAPTER_SAFE_WEIGHTS_NAME,
138+
ADAPTER_WEIGHTS_NAME,
137139
CONFIG_NAME,
138140
SAFE_WEIGHTS_INDEX_NAME,
139141
SAFE_WEIGHTS_NAME,
@@ -2177,11 +2179,20 @@ def _load_best_model(self):
21772179
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
21782180
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
21792181
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
2182+
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
2183+
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
2184+
21802185
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
2181-
if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
2186+
if (
2187+
os.path.exists(best_model_path)
2188+
or os.path.exists(best_safe_model_path)
2189+
or os.path.exists(best_adapter_model_path)
2190+
or os.path.exists(best_safe_adapter_model_path)
2191+
):
21822192
if self.is_deepspeed_enabled:
21832193
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
21842194
else:
2195+
has_been_loaded = True
21852196
if is_sagemaker_mp_enabled():
21862197
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
21872198
# If the 'user_content.pt' file exists, load with the new smp api.
@@ -2207,10 +2218,10 @@ def _load_best_model(self):
22072218
self.accelerator, model, self.state.best_model_checkpoint
22082219
)
22092220
else:
2210-
if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
2211-
# If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
2221+
if is_peft_available() and isinstance(model, PeftModel):
2222+
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
22122223
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
2213-
if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
2224+
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
22142225
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
22152226
# Load_adapter has no return value present, modify it when appropriate.
22162227
from torch.nn.modules.module import _IncompatibleKeys
@@ -2219,12 +2230,13 @@ def _load_best_model(self):
22192230
else:
22202231
logger.warning(
22212232
"The intermediate checkpoints of PEFT may not be saved correctly, "
2222-
"using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
2233+
f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
22232234
"here are some examples https://github.com./huggingface/peft/issues/96"
22242235
)
2236+
has_been_loaded = False
22252237
else:
2226-
# We can't do pure 8bit training using transformers.
2227-
logger.warning("Could not loading a quantized checkpoint.")
2238+
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
2239+
has_been_loaded = False
22282240
else:
22292241
# We load the model state dict on the CPU to avoid an OOM error.
22302242
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
@@ -2236,7 +2248,7 @@ def _load_best_model(self):
22362248
# workaround for FSDP bug https://github.com./pytorch/pytorch/issues/82963
22372249
# which takes *args instead of **kwargs
22382250
load_result = model.load_state_dict(state_dict, False)
2239-
if not is_sagemaker_mp_enabled():
2251+
if not is_sagemaker_mp_enabled() and has_been_loaded:
22402252
self._issue_warnings_after_load(load_result)
22412253
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
22422254
load_result = load_sharded_checkpoint(

src/transformers/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@
177177

178178
WEIGHTS_NAME = "pytorch_model.bin"
179179
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
180+
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
181+
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
180182
TF2_WEIGHTS_NAME = "tf_model.h5"
181183
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
182184
TF_WEIGHTS_NAME = "model.ckpt"

0 commit comments

Comments (0)