add an option to copy config from one of models in checkpoint merger

This commit is contained in:
AUTOMATIC 2023-01-11 09:10:07 +03:00
parent 3e20244b0f
commit 954091697f
2 changed files with 35 additions and 4 deletions

View File

@ -3,6 +3,7 @@ import math
import os import os
import sys import sys
import traceback import traceback
import shutil
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -248,7 +249,32 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): def create_config(ckpt_result, config_source, a, b, c):
def config(x):
return sd_models.find_checkpoint_config(x) if x else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
elif config_source == 1:
cfg = config(b)
elif config_source == 2:
cfg = config(c)
else:
cfg = None
if cfg is None:
return
filename, _ = os.path.splitext(ckpt_result)
checkpoint_filename = filename + ".yaml"
print("Copying config:")
print(" from:", cfg)
print(" to:", checkpoint_filename)
shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin() shared.state.begin()
shared.state.job = 'model-merge' shared.state.job = 'model-merge'
@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models() sd_models.list_models()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
print("Checkpoint saved.") print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end() shared.state.end()

View File

@ -1129,7 +1129,7 @@ def create_ui():
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>") gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row(): with FormRow():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@ -1143,11 +1143,13 @@ def create_ui():
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with gr.Row(): with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
@ -1703,6 +1705,7 @@ def create_ui():
save_as_half, save_as_half,
custom_name, custom_name,
checkpoint_format, checkpoint_format,
config_source,
], ],
outputs=[ outputs=[
submit_result, submit_result,