From 8803d04cfeb7574fe90c697df405015a3289771a Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 20 Apr 2023 17:07:07 +0300 Subject: [PATCH] fixed weight conversion --- utils/convert_original_stable_diffusion_to_diffusers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/utils/convert_original_stable_diffusion_to_diffusers.py b/utils/convert_original_stable_diffusion_to_diffusers.py index 0828e21..4078116 100644 --- a/utils/convert_original_stable_diffusion_to_diffusers.py +++ b/utils/convert_original_stable_diffusion_to_diffusers.py @@ -430,6 +430,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False else: output_block_list[layer_id] = [layer_name] + output_block_list = {x : sorted(y) for x, y in output_block_list} + if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] @@ -442,8 +444,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) - if ["conv.weight", "conv.bias"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ]