diff --git a/modules/api/api.py b/modules/api/api.py index efedafa42..090838747 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin,Image -from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases +from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -541,12 +540,12 @@ class Api: return {} def unloadapi(self): - unload_model_weights() + sd_models.unload_model_weights() return {} def reloadapi(self): - reload_model_weights() + sd_models.send_model_to_device(shared.sd_model) return {} @@ -566,7 +565,7 @@ class Api: def set_config(self, req: dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: + if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): diff --git a/modules/sd_models.py b/modules/sd_models.py index c8efeedca..3b6cdea18 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,7 +1,6 @@ import collections import os.path import sys -import gc import threading import torch @@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None): - timer = Timer() - - if model_data.sd_model: - model_data.sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) - model_data.sd_model = None - sd_model = None - gc.collect() - devices.torch_gc() - - print(f"Unloaded weights {timer.summary()}.") + send_model_to_cpu(sd_model or shared.sd_model) return sd_model diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 74a3aef32..e054d00ab 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo +from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer from modules.call_queue import wrap_gradio_call from modules.shared import opts from modules.ui_components import FormRow @@ -177,8 +177,8 @@ class UiSettings: download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") with gr.Row(): - unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") - reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model") with gr.Row(): calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash") calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1) @@ -194,16 +194,26 @@ class UiSettings: self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + def call_func_and_return_text(func, text): + def handler(): + t = timer.Timer() + func() + t.record(text) + + return f'{text} in {t.total:.1f}s' + + return handler + unload_sd_model.click( - fn=sd_models.unload_model_weights, + fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'), inputs=[], - outputs=[] + outputs=[self.result] ) reload_sd_model.click( - fn=sd_models.reload_model_weights, + fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'), inputs=[], - outputs=[] + outputs=[self.result] ) request_notifications.click(