fixed weight conversion
This commit is contained in:
parent
a8520d6661
commit
8803d04cfe
|
@ -430,6 +430,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||
else:
|
||||
output_block_list[layer_id] = [layer_name]
|
||||
|
||||
output_block_list = {x : sorted(y) for x, y in output_block_list}
|
||||
|
||||
if len(output_block_list) > 1:
|
||||
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||
|
@ -442,8 +444,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue