diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef1..120838480 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] - self.title = name + self.name = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) self.shorthash = self.sha256[0:10] if self.sha256 else None - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: self.ids += [self.shorthash, self.sha256] self.register() + self.title = f'{self.name} [{self.shorthash}]' + return self.shorthash @@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo): + title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + if checkpoint_info.title != title: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title cache_enabled = shared.opts.sd_checkpoint_cache > 0 diff --git a/modules/ui.py b/modules/ui.py index 0c5ba3585..13d80ae27 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -439,7 +439,7 @@ def apply_setting(key, value): opts.data_labels[key].onchange() opts.save(shared.config_filename) - return value + return getattr(opts, key) def update_generation_info(generation_info, html_info, img_index): @@ -597,6 +597,16 @@ def ordered_ui_categories(): yield category +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + def create_ui(): import modules.img2img import modules.txt2img @@ -1600,7 +1610,7 @@ def create_ui(): opts.save(shared.config_filename) - return gr.update(value=value), opts.dumpjson() + return get_value_for_setting(key), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Row(): @@ -1771,15 +1781,6 @@ def create_ui(): component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - def get_value_for_setting(key): - value = getattr(opts, key) - - info = opts.data_labels[key] - args = info.component_args() if callable(info.component_args) else info.component_args or {} - args = {k: v for k, v in args.items() if k not in {'precision'}} - - return gr.update(value=value, **args) - def get_settings_values(): return [get_value_for_setting(key) for key in component_keys]