repair unload sd checkpoint button
This commit is contained in:
parent
0d65d0eabd
commit
282903bb67
|
@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
|
||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin, Image
|
||||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
|
@ -541,12 +540,12 @@ class Api:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def unloadapi(self):
|
def unloadapi(self):
|
||||||
unload_model_weights()
|
sd_models.unload_model_weights()
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reloadapi(self):
|
def reloadapi(self):
|
||||||
reload_model_weights()
|
sd_models.send_model_to_device(shared.sd_model)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -566,7 +565,7 @@ class Api:
|
||||||
|
|
||||||
def set_config(self, req: dict[str, Any]):
|
def set_config(self, req: dict[str, Any]):
|
||||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import collections
|
import collections
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
timer = Timer()
|
send_model_to_cpu(sd_model or shared.sd_model)
|
||||||
|
|
||||||
if model_data.sd_model:
|
|
||||||
model_data.sd_model.to(devices.cpu)
|
|
||||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
|
||||||
model_data.sd_model = None
|
|
||||||
sd_model = None
|
|
||||||
gc.collect()
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
|
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
|
||||||
from modules.call_queue import wrap_gradio_call
|
from modules.call_queue import wrap_gradio_call
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.ui_components import FormRow
|
from modules.ui_components import FormRow
|
||||||
|
@ -177,8 +177,8 @@ class UiSettings:
|
||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
|
||||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
|
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
|
||||||
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
|
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
|
||||||
|
@ -194,16 +194,26 @@ class UiSettings:
|
||||||
|
|
||||||
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||||
|
|
||||||
|
def call_func_and_return_text(func, text):
|
||||||
|
def handler():
|
||||||
|
t = timer.Timer()
|
||||||
|
func()
|
||||||
|
t.record(text)
|
||||||
|
|
||||||
|
return f'{text} in {t.total:.1f}s'
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
unload_sd_model.click(
|
unload_sd_model.click(
|
||||||
fn=sd_models.unload_model_weights,
|
fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[self.result]
|
||||||
)
|
)
|
||||||
|
|
||||||
reload_sd_model.click(
|
reload_sd_model.click(
|
||||||
fn=sd_models.reload_model_weights,
|
fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[self.result]
|
||||||
)
|
)
|
||||||
|
|
||||||
request_notifications.click(
|
request_notifications.click(
|
||||||
|
|
Loading…
Reference in New Issue