add metadata to checkpoint merger

This commit is contained in:
AUTOMATIC1111 2023-08-01 08:27:54 +03:00
parent 6d3a0c9506
commit 07be13caa3
3 changed files with 52 additions and 9 deletions

View File

@ -7,7 +7,7 @@ import json
import torch import torch
import tqdm import tqdm
from modules import shared, images, sd_models, sd_vae, sd_models_config from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html from modules.ui_common import plaintext_to_html
import gradio as gr import gradio as gr
import safetensors.torch import safetensors.torch
@ -72,7 +72,20 @@ def to_half(tensor, enable):
return tensor return tensor
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
metadata = {}
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
if checkpoint_info is None:
continue
metadata.update(checkpoint_info.metadata)
return json.dumps(metadata, indent=4, ensure_ascii=False)
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
shared.state.begin(job="model-merge") shared.state.begin(job="model-merge")
def fail(message): def fail(message):
@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving" shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
metadata = None metadata = {}
if save_metadata and copy_metadata_fields:
if primary_model_info:
metadata.update(primary_model_info.metadata)
if secondary_model_info:
metadata.update(secondary_model_info.metadata)
if tertiary_model_info:
metadata.update(tertiary_model_info.metadata)
if save_metadata: if save_metadata:
metadata = {"format": "pt"} try:
metadata.update(json.loads(metadata_json))
except Exception as e:
errors.display(e, "readin metadata from json")
metadata["format"] = "pt"
if save_metadata and add_merge_recipe:
merge_recipe = { merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger "type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256, "primary_model_hash": primary_model_info.sha256,
@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
"is_inpainting": result_is_inpainting_model, "is_inpainting": result_is_inpainting_model,
"is_instruct_pix2pix": result_is_instruct_pix2pix_model "is_instruct_pix2pix": result_is_instruct_pix2pix_model
} }
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
sd_merge_models = {} sd_merge_models = {}
@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info: if tertiary_model_info:
add_model_metadata(tertiary_model_info) add_model_metadata(tertiary_model_info)
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
metadata["sd_merge_models"] = json.dumps(sd_merge_models) metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname) _, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata) safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
else: else:
torch.save(theta_0, output_modelname) torch.save(theta_0, output_modelname)

View File

@ -85,7 +85,7 @@ class CheckpointInfo:
if self.shorthash not in self.ids: if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
checkpoints_list.pop(self.title) checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]' self.title = f'{self.name} [{self.shorthash}]'
self.register() self.register()

View File

@ -51,7 +51,6 @@ class UiCheckpointMerger:
with FormRow(): with FormRow():
self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
self.save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
with FormRow(): with FormRow():
with gr.Column(): with gr.Column():
@ -65,16 +64,30 @@ class UiCheckpointMerger:
with FormRow(): with FormRow():
self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
with gr.Row(): with gr.Accordion("Metadata", open=False) as metadata_editor:
with FormRow():
self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
self.read_metadata = gr.Button("Read metadata from selected checkpoints")
with FormRow():
self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='compact', elem_id="modelmerger_results_container"): with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
with gr.Group(elem_id="modelmerger_results_panel"): with gr.Group(elem_id="modelmerger_results_panel"):
self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
self.metadata_editor = metadata_editor
self.blocks = modelmerger_interface self.blocks = modelmerger_interface
def setup_ui(self, dummy_component, sd_model_checkpoint_component): def setup_ui(self, dummy_component, sd_model_checkpoint_component):
self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result]) self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
self.modelmerger_merge.click( self.modelmerger_merge.click(
fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
@ -93,6 +106,9 @@ class UiCheckpointMerger:
self.bake_in_vae, self.bake_in_vae,
self.discard_weights, self.discard_weights,
self.save_metadata, self.save_metadata,
self.add_merge_recipe,
self.copy_metadata_fields,
self.metadata_json,
], ],
outputs=[ outputs=[
self.primary_model_name, self.primary_model_name,