diff --git a/modules/extras.py b/modules/extras.py index fe701a0e6..d03f976e3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -278,6 +278,13 @@ def create_config(ckpt_result, config_source, a, b, c): chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): shared.state.begin() shared.state.job = 'model-merge' @@ -400,8 +407,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ else: theta_0[key] = theta_func2(a, b, multiplier) - if save_as_half: - theta_0[key] = theta_0[key].half() + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 @@ -416,10 +422,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in vae_dict.keys(): theta_0_key = 'first_stage_model.' + key if theta_0_key in theta_0: - theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key] + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) del vae_dict + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path filename = filename_generator() if custom_name == '' else custom_name