diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 565e342df..fbd913005 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -14,6 +14,7 @@ re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") +re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) paste_fields = {} bind_list = [] @@ -139,6 +140,30 @@ def run_bind(): ) +def find_hypernetwork_key(hypernet_name, hypernet_hash=None): + """Determines the config parameter name to use for the hypernet based on the parameters in the infotext. + + Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config + parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to. + + If the infotext has no hash, then a hypernet with the same name will be selected instead. + """ + hypernet_name = hypernet_name.lower() + if hypernet_hash is not None: + # Try to match the hash in the name + for hypernet_key in shared.hypernetworks.keys(): + result = re_hypernet_hash.search(hypernet_key) + if result is not None and result[1] == hypernet_hash: + return hypernet_key + else: + # Fall back to a hypernet with the same name + for hypernet_key in shared.hypernetworks.keys(): + if hypernet_key.lower().startswith(hypernet_name): + return hypernet_key + + return None + + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -188,6 +213,14 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Clip skip" not in res: res["Clip skip"] = "1" + if "Hypernet strength" not in res: + res["Hypernet strength"] = "1" + + if "Hypernet" in res: + hypernet_name = res["Hypernet"] + hypernet_hash = res.get("Hypernet hash", None) + res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + return res diff --git a/modules/processing.py b/modules/processing.py index d2288f26e..4a4060844 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -318,7 +318,7 @@ class Processed: return json.dumps(obj) - def infotext(self, p: StableDiffusionProcessing, index): + def infotext(self, p: StableDiffusionProcessing, index): return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) @@ -433,6 +433,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), + "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)), "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), @@ -450,7 +451,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" + negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()