From 246c269af87757998f57bb27ddda59fdc7cff976 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 23 Apr 2024 03:08:09 +0900 Subject: [PATCH 1/2] add option to check file hash after download if the sha256 hash does not match it will be automatically deleted --- modules/modelloader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index 115415c8e..5421e59b0 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -23,6 +23,7 @@ def load_file_from_url( model_dir: str, progress: bool = True, file_name: str | None = None, + hash_prefix: str | None = None, ) -> str: """Download a file from `url` into `model_dir`, using the file present if possible. @@ -36,11 +37,11 @@ def load_file_from_url( if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') from torch.hub import download_url_to_file - download_url_to_file(url, cached_file, progress=progress) + download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix) return cached_file -def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -49,6 +50,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None @param model_path: The location to store/find models in. @param command_path: A command-line argument to search for models in first. @param ext_filter: An optional list of filename extensions to filter by + @param hash_prefix: the expected sha256 of the model_url @return: A list of paths containing the desired model(s) """ output = [] @@ -78,7 +80,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name, hash_prefix=hash_prefix)) else: output.append(model_url) From c69773d7e8f23f8b6c46a8e177b50386e1f1b8e8 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 23 Apr 2024 03:08:57 +0900 Subject: [PATCH 2/2] ensure integrity for initial sd model download --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a6..35d5952af 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -149,10 +149,12 @@ def list_models(): cmd_ckpt = shared.cmd_opts.ckpt if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt): model_url = None + expected_sha256 = None else: model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors" + expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa' - model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"], hash_prefix=expected_sha256) if os.path.exists(cmd_ckpt): checkpoint_info = CheckpointInfo(cmd_ckpt)