fixed weight conversion

This commit is contained in:
Alexander 2023-04-20 17:07:07 +03:00
parent a8520d6661
commit 8803d04cfe
1 changed files with 4 additions and 2 deletions

View File

@ -430,6 +430,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
else:
output_block_list[layer_id] = [layer_name]
output_block_list = {x : sorted(y) for x, y in output_block_list}
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
@ -442,8 +444,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]