From 681c450ecd8f0999cbaf562c5e734c7105320ad9 Mon Sep 17 00:00:00 2001 From: Mackerel Date: Sun, 4 Dec 2022 01:13:36 -0500 Subject: [PATCH] extras.py: use as little RAM as possible, misc fixes maximum of 2 models loaded at once. delete unneeded model before next step. fix 'teritary' -> 'tertiary'. gracefully fail when "add difference" is selected without a tertiary model --- modules/extras.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index bc349d5ea..0ad8deec5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -62,7 +62,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ # Also keep track of original file names imageNameArr = [] outputs = [] - + if extras_mode == 1: #convert file to pillow image for img in image_folder: @@ -234,7 +234,7 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): +def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -246,30 +246,25 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) + tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None) result_is_inpainting_model = False - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - - if teritary_model_info is not None: - print(f"Loading {teritary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu') - else: - theta_2 = None - theta_funcs = { "Weighted sum": (None, weighted_sum), "Add difference": (get_difference, add_difference), } theta_func1, theta_func2 = theta_funcs[interp_method] - print(f"Merging...") + if theta_func1 and not tertiary_model_info: + return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') if theta_func1: + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + for key in tqdm.tqdm(theta_1.keys()): if 'model' in key: if key in theta_2: @@ -277,7 +272,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam theta_1[key] = theta_func1(theta_1[key], t2) else: theta_1[key] = torch.zeros_like(theta_1[key]) - del theta_2 + del theta_2 + + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: @@ -307,6 +307,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam theta_0[key] = theta_1[key] if save_as_half: theta_0[key] = theta_0[key].half() + del theta_1 ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path @@ -332,5 +333,5 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam sd_models.list_models() - print(f"Checkpoint saved.") + print("Checkpoint saved.") return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]