diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08e..fea84630c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") @@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + shared.sd_model = load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5cca168a1..eff0c942a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) + p.sd_model = shared.sd_model def confirm_checkpoints(p, xs):