Action to calculate all SD checkpoint hashes

This commit is contained in:
w-e-w 2023-09-01 00:55:56 +09:00
parent d39440bfb9
commit 348c6022f3
1 changed files with 19 additions and 0 deletions

View File

@ -5,6 +5,7 @@ from modules.call_queue import wrap_gradio_call
from modules.shared import opts from modules.shared import opts
from modules.ui_components import FormRow from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript from modules.ui_gradio_extensions import reload_javascript
from concurrent.futures import ThreadPoolExecutor, as_completed
def get_value_for_setting(key): def get_value_for_setting(key):
@ -175,6 +176,9 @@ class UiSettings:
with gr.Row(): with gr.Row():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.Row():
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
@ -241,6 +245,21 @@ class UiSettings:
outputs=[sysinfo_check_output], outputs=[sysinfo_check_output],
) )
def calculate_all_checkpoint_hash_fn(max_thread):
checkpoints_list = sd_models.checkpoints_list.values()
with ThreadPoolExecutor(max_workers=max_thread) as executor:
futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
completed = 0
for _ in as_completed(futures):
completed += 1
print(f"{completed} / {len(checkpoints_list)} ")
print("Finish calculating hash for all checkpoints")
calculate_all_checkpoint_hash.click(
fn=calculate_all_checkpoint_hash_fn,
inputs=[calculate_all_checkpoint_hash_threads],
)
self.interface = settings_interface self.interface = settings_interface
def add_quicksettings(self): def add_quicksettings(self):