show metadata for SD checkpoints in the extra networks UI
This commit is contained in:
parent
c10633f93a
commit
4b43480fe8
|
@ -14,7 +14,7 @@ import ldm.modules.midas as midas
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
@ -33,6 +33,8 @@ class CheckpointInfo:
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
abspath = os.path.abspath(filename)
|
abspath = os.path.abspath(filename)
|
||||||
|
|
||||||
|
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||||
|
|
||||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||||
elif abspath.startswith(model_path):
|
elif abspath.startswith(model_path):
|
||||||
|
@ -43,6 +45,19 @@ class CheckpointInfo:
|
||||||
if name.startswith("\\") or name.startswith("/"):
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
||||||
|
def read_metadata():
|
||||||
|
metadata = read_metadata_from_safetensors(filename)
|
||||||
|
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
self.metadata = {}
|
||||||
|
if self.is_safetensors:
|
||||||
|
try:
|
||||||
|
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading metadata for {filename}")
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
|
@ -55,15 +70,6 @@ class CheckpointInfo:
|
||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
self.metadata = {}
|
|
||||||
|
|
||||||
_, ext = os.path.splitext(self.filename)
|
|
||||||
if ext.lower() == ".safetensors":
|
|
||||||
try:
|
|
||||||
self.metadata = read_metadata_from_safetensors(filename)
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, f"reading checkpoint metadata: {filename}")
|
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
for id in self.ids:
|
for id in self.ids:
|
||||||
|
|
|
@ -23,6 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
|
"metadata": checkpoint.metadata,
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue