From 7b052eb70eb2a35ce4f776b1e2ab1389802a41b5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 10:07:02 +0300 Subject: [PATCH] add resolution calculation from buckets for lora user metadata page --- extensions-builtin/Lora/lora.py | 1 - .../Lora/ui_edit_user_metadata.py | 28 +++++++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index c87109221..467ad65f2 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -86,7 +86,6 @@ class LoraOnDisk: if self.is_safetensors: try: - #self.metadata = sd_models.read_metadata_from_safetensors(filename) self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) except Exception as e: errors.display(e, f"reading lora {filename}") diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 6db63b095..354a1d686 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -65,17 +65,33 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) item = self.page.items.get(name, {}) metadata = item.get("metadata") or {} - keys = [ - ('ss_sd_model_name', "Model:"), - ('ss_resolution', "Resolution:"), - ('ss_clip_skip', "Clip skip:"), - ] + keys = { + 'ss_sd_model_name': "Model:", + 'ss_clip_skip': "Clip skip:", + } - for key, label in keys: + for key, label in keys.items(): value = metadata.get(key, None) if value is not None and str(value) != "None": table.append((label, html.escape(value))) + ss_bucket_info = metadata.get("ss_bucket_info") + if ss_bucket_info and "buckets" in ss_bucket_info: + resolutions = {} + for _, bucket in ss_bucket_info["buckets"].items(): + resolution = bucket["resolution"] + resolution = f'{resolution[1]}x{resolution[0]}' + + resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"]) + + resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True) + resolutions_text = html.escape(", ".join(resolutions_list[0:4])) + if len(resolutions) > 4: + resolutions_text += ", ..." + resolutions_text = f"{resolutions_text}" + + table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text)) + image_count = 0 for _, params in metadata.get("ss_dataset_dirs", {}).items(): image_count += int(params.get("img_count", 0))