From 3f7edc5f724862cce8d43bca1b531d962e963a3a Mon Sep 17 00:00:00 2001
From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com>
Date: Wed, 9 Nov 2022 18:08:30 +0700
Subject: [PATCH] Fix layer names convert LDM script (#1206)

fix script convert LDM
---
 ...rt_ldm_original_checkpoint_to_diffusers.py | 27 ++++++++++---------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/scripts/convert_ldm_original_checkpoint_to_diffusers.py
index 52865792..f547e96f 100644
--- a/scripts/convert_ldm_original_checkpoint_to_diffusers.py
+++ b/scripts/convert_ldm_original_checkpoint_to_diffusers.py
@@ -112,9 +112,9 @@ def assign_to_checkpoint(
             continue
 
         # Global renaming happens here
-        new_path = new_path.replace("middle_block.0", "mid.resnets.0")
-        new_path = new_path.replace("middle_block.1", "mid.attentions.0")
-        new_path = new_path.replace("middle_block.2", "mid.resnets.1")
+        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
 
         if additional_replacements is not None:
             for replacement in additional_replacements:
@@ -175,15 +175,16 @@ def convert_ldm_checkpoint(checkpoint, config):
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
 
         if f"input_blocks.{i}.0.op.weight" in checkpoint:
-            new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
                 f"input_blocks.{i}.0.op.weight"
             ]
-            new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
                 f"input_blocks.{i}.0.op.bias"
             ]
+            continue
 
         paths = renew_resnet_paths(resnets)
-        meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
         resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
         assign_to_checkpoint(
             paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
@@ -193,18 +194,18 @@
         paths = renew_attention_paths(attentions)
         meta_path = {
             "old": f"input_blocks.{i}.1",
-            "new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
+            "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
         }
         to_split = {
             f"input_blocks.{i}.1.qkv.bias": {
-                "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
-                "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
-                "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
+                "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
+                "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
+                "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
             },
             f"input_blocks.{i}.1.qkv.weight": {
-                "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
-                "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
-                "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
+                "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
+                "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
+                "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
             },
         }
         assign_to_checkpoint(