Predicting from checkpoint with CLI #20715
Replies: 2 comments
-
After giving this some more thought, I suppose the generated
-
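My current workaround is:

```python
import logging
from lightning.pytorch.cli import LightningCLI
from my_project.models import MyModel  # hypothetical import path for your own LightningModule

logger = logging.getLogger(__name__)


class MyLightningCLI(LightningCLI):
    def after_instantiate_classes(self) -> None:
        if self.subcommand == "predict" and self.config["predict"]["ckpt_path"]:
            logger.warning("Hijacking model after instantiate.")
            # Re-parse the command line without defaults, so only explicitly given flags remain
            config_from_cli = self.parser.parse_args(self.parser.args, defaults=False)
            model_args = config_from_cli["predict"].get("model", {})
            logger.debug(f"{model_args=}")
            # Rebuild the model from the checkpoint (saved hparams, if any), overridden by explicit flags
            self.model = MyModel.load_from_checkpoint(
                self.config["predict"]["ckpt_path"], **model_args
            )
            logger.debug(f"Model loaded from checkpoint: {self.model}")
```

This is not great, but works for my current use-case. I suppose the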
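The `defaults=False` re-parse in the workaround exists to tell flags the user actually typed apart from parser defaults, so that defaults cannot overwrite values restored from the checkpoint. A stdlib-only sketch of that precedence (class defaults, then checkpoint hyperparameters, then explicitly given flags) is below. It uses plain `argparse` with `default=argparse.SUPPRESS` as a stand-in for the jsonargparse parser LightningCLI uses; the flag names, the stored values, and both helper functions are hypothetical.

```python
import argparse


def parse_explicit_overrides(argv: list[str]) -> dict[str, object]:
    """Return only the --model.* flags actually present on the command line (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    # default=argparse.SUPPRESS: an omitted flag leaves no attribute on the namespace,
    # so "not given" is distinguishable from "given with the default value".
    parser.add_argument("--model.some_layer_size", type=int, default=argparse.SUPPRESS)
    parser.add_argument("--model.dropout", type=float, default=argparse.SUPPRESS)
    parser.add_argument("--ckpt_path", default=None)
    ns = parser.parse_args(argv)
    prefix = "model."
    # argparse keeps the dot in the dest, e.g. "model.some_layer_size"
    return {k[len(prefix):]: v for k, v in vars(ns).items() if k.startswith(prefix)}


def resolve_init_args(defaults: dict, checkpoint_hparams: dict, explicit: dict) -> dict:
    # Lowest to highest priority: class defaults, checkpoint values, explicit CLI flags.
    return {**defaults, **checkpoint_hparams, **explicit}


DEFAULTS = {"some_layer_size": 2, "dropout": 0.1}           # hypothetical class defaults
checkpoint_hparams = {"some_layer_size": 4, "dropout": 0.1}  # hypothetical values saved at training time

explicit = parse_explicit_overrides(["--ckpt_path", "last.ckpt", "--model.dropout", "0.0"])
print(resolve_init_args(DEFAULTS, checkpoint_hparams, explicit))
# prints {'some_layer_size': 4, 'dropout': 0.0}
```

The layer size comes from the checkpoint while the explicitly passed dropout still wins, which is the behaviour the workaround is after.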
-
I want to predict with a model for which I have a checkpoint. The arguments used to train the model (which produced the checkpoint) give the model different layer sizes than the default values. When using `ckpt_path`, loading then fails due to a size mismatch. Two obvious workarounds are explicitly passing all the model arguments in the `predict` call (e.g. `--model.some_layer_size 4`), or loading the model with `LightningModule.load_from_checkpoint`.

Explicitly giving all the flags becomes very cumbersome when there are many. The second approach might be a good one, though I haven't yet tried implementing it.
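As a framework-free illustration of why the load fails and why the second workaround can help: Lightning's `load_from_checkpoint` rebuilds the module from the init arguments saved by `self.save_hyperparameters()` (stored under the checkpoint's `hyper_parameters` key), and extra keyword arguments override those saved values. The sketch below only mimics that idea with a hypothetical `ToyModel` whose parameter "shapes" depend on a constructor argument, plus hypothetical helpers `save_checkpoint` and `load_from_checkpoint_sketch`; it is not Lightning's implementation.

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyModel:
    """Hypothetical model: the weight shape depends on a constructor argument."""
    in_features: int = 8
    some_layer_size: int = 2

    def state_shapes(self) -> dict[str, tuple[int, int]]:
        return {"layer.weight": (self.some_layer_size, self.in_features)}

    def load_state(self, shapes: dict[str, tuple[int, int]]) -> None:
        # Mimics a strict state-dict load: every stored shape must match the model's.
        for name, shape in shapes.items():
            expected = self.state_shapes()[name]
            if shape != expected:
                raise ValueError(
                    f"size mismatch for {name}: checkpoint {shape}, model {expected}"
                )


def save_checkpoint(model: ToyModel) -> dict:
    # Store the constructor arguments next to the "weights".
    return {"hyper_parameters": asdict(model), "state_shapes": model.state_shapes()}


def load_from_checkpoint_sketch(ckpt: dict, **overrides) -> ToyModel:
    # Saved init args first, explicit keyword arguments override them.
    hparams = {**ckpt["hyper_parameters"], **overrides}
    model = ToyModel(**hparams)
    model.load_state(ckpt["state_shapes"])
    return model


ckpt = save_checkpoint(ToyModel(some_layer_size=4))  # "trained" with a non-default size

# Building the model from defaults and then loading fails, as described above.
try:
    ToyModel().load_state(ckpt["state_shapes"])
except ValueError as err:
    print(err)  # size mismatch for layer.weight: checkpoint (4, 8), model (2, 8)

# Rebuilding from the stored hyperparameters succeeds.
restored = load_from_checkpoint_sketch(ckpt)
print(restored.some_layer_size)  # 4
```

In real Lightning code this only works if the module calls `self.save_hyperparameters()` in its `__init__`; otherwise the checkpoint carries no init arguments to rebuild from.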
Is there something I am overlooking here? What is the recommended way to predict from a checkpoint like this?
Any input is greatly appreciated :)