298 lines
12 KiB
Python
298 lines
12 KiB
Python
import argparse
|
|
|
|
import huggingface_hub
|
|
import k_diffusion as K
|
|
import torch
|
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
|
UPSCALER_REPO = "pcuenq/k-upscaler"
|
|
|
|
|
|
def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
|
rv = {
|
|
# norm1
|
|
f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
|
|
f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
|
|
# conv1
|
|
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
|
|
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
|
|
# norm2
|
|
f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
|
|
f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
|
|
# conv2
|
|
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
|
|
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
|
|
}
|
|
|
|
if resnet.conv_shortcut is not None:
|
|
rv.update(
|
|
{
|
|
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
|
|
}
|
|
)
|
|
|
|
return rv
|
|
|
|
|
|
def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
|
weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
|
|
bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
|
|
rv = {
|
|
# norm
|
|
f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
|
|
f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
|
|
# to_q
|
|
f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
|
|
# to_k
|
|
f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
|
|
# to_v
|
|
f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
|
|
# to_out
|
|
f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
|
|
.squeeze(-1)
|
|
.squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
|
|
}
|
|
|
|
return rv
|
|
|
|
|
|
def cross_attn_to_diffusers_checkpoint(
|
|
checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
|
|
):
|
|
weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
|
|
bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
|
|
|
|
rv = {
|
|
# norm2 (ada groupnorm)
|
|
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
|
|
f"{attention_prefix}.norm_dec.mapper.weight"
|
|
],
|
|
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
|
|
f"{attention_prefix}.norm_dec.mapper.bias"
|
|
],
|
|
# layernorm on encoder_hidden_state
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
|
|
f"{attention_prefix}.norm_enc.weight"
|
|
],
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
|
|
f"{attention_prefix}.norm_enc.bias"
|
|
],
|
|
# to_q
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
|
|
f"{attention_prefix}.q_proj.weight"
|
|
]
|
|
.squeeze(-1)
|
|
.squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
|
|
f"{attention_prefix}.q_proj.bias"
|
|
],
|
|
# to_k
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
|
|
# to_v
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
|
|
# to_out
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
|
|
f"{attention_prefix}.out_proj.weight"
|
|
]
|
|
.squeeze(-1)
|
|
.squeeze(-1),
|
|
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
|
|
f"{attention_prefix}.out_proj.bias"
|
|
],
|
|
}
|
|
|
|
return rv
|
|
|
|
|
|
def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
|
|
block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
|
|
block_prefix = f"{block_prefix}.{block_idx}"
|
|
|
|
diffusers_checkpoint = {}
|
|
|
|
if not hasattr(block, "attentions"):
|
|
n = 1 # resnet only
|
|
elif not block.attentions[0].add_self_attention:
|
|
n = 2 # resnet -> cross-attention
|
|
else:
|
|
n = 3 # resnet -> self-attention -> cross-attention)
|
|
|
|
for resnet_idx, resnet in enumerate(block.resnets):
|
|
# diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
|
|
diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
|
|
idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
|
|
resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}"
|
|
|
|
diffusers_checkpoint.update(
|
|
resnet_to_diffusers_checkpoint(
|
|
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
|
)
|
|
)
|
|
|
|
if hasattr(block, "attentions"):
|
|
for attention_idx, attention in enumerate(block.attentions):
|
|
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
|
|
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
|
|
self_attention_prefix = f"{block_prefix}.{idx}"
|
|
cross_attention_prefix = f"{block_prefix}.{idx }"
|
|
cross_attention_index = 1 if not attention.add_self_attention else 2
|
|
idx = (
|
|
n * attention_idx + cross_attention_index
|
|
if block_type == "up"
|
|
else n * attention_idx + cross_attention_index + 1
|
|
)
|
|
cross_attention_prefix = f"{block_prefix}.{idx }"
|
|
|
|
diffusers_checkpoint.update(
|
|
cross_attn_to_diffusers_checkpoint(
|
|
checkpoint,
|
|
diffusers_attention_prefix=diffusers_attention_prefix,
|
|
diffusers_attention_index=2,
|
|
attention_prefix=cross_attention_prefix,
|
|
)
|
|
)
|
|
|
|
if attention.add_self_attention is True:
|
|
diffusers_checkpoint.update(
|
|
self_attn_to_diffusers_checkpoint(
|
|
checkpoint,
|
|
diffusers_attention_prefix=diffusers_attention_prefix,
|
|
attention_prefix=self_attention_prefix,
|
|
)
|
|
)
|
|
|
|
return diffusers_checkpoint
|
|
|
|
|
|
def unet_to_diffusers_checkpoint(model, checkpoint):
|
|
diffusers_checkpoint = {}
|
|
|
|
# pre-processing
|
|
diffusers_checkpoint.update(
|
|
{
|
|
"conv_in.weight": checkpoint["inner_model.proj_in.weight"],
|
|
"conv_in.bias": checkpoint["inner_model.proj_in.bias"],
|
|
}
|
|
)
|
|
|
|
# timestep and class embedding
|
|
diffusers_checkpoint.update(
|
|
{
|
|
"time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
|
|
"time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
|
|
"time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
|
|
"time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
|
|
"time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
|
|
"time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
|
|
}
|
|
)
|
|
|
|
# down_blocks
|
|
for down_block_idx, down_block in enumerate(model.down_blocks):
|
|
diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
|
|
|
|
# up_blocks
|
|
for up_block_idx, up_block in enumerate(model.up_blocks):
|
|
diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
|
|
|
|
# post-processing
|
|
diffusers_checkpoint.update(
|
|
{
|
|
"conv_out.weight": checkpoint["inner_model.proj_out.weight"],
|
|
"conv_out.bias": checkpoint["inner_model.proj_out.bias"],
|
|
}
|
|
)
|
|
|
|
return diffusers_checkpoint
|
|
|
|
|
|
def unet_model_from_original_config(original_config):
|
|
in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
|
|
out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
|
|
|
|
block_out_channels = original_config["channels"]
|
|
|
|
assert (
|
|
len(set(original_config["depths"])) == 1
|
|
), "UNet2DConditionModel currently do not support blocks with different number of layers"
|
|
layers_per_block = original_config["depths"][0]
|
|
|
|
class_labels_dim = original_config["mapping_cond_dim"]
|
|
cross_attention_dim = original_config["cross_cond_dim"]
|
|
|
|
attn1_types = []
|
|
attn2_types = []
|
|
for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
|
|
if s:
|
|
a1 = "self"
|
|
a2 = "cross" if c else None
|
|
elif c:
|
|
a1 = "cross"
|
|
a2 = None
|
|
else:
|
|
a1 = None
|
|
a2 = None
|
|
attn1_types.append(a1)
|
|
attn2_types.append(a2)
|
|
|
|
unet = UNet2DConditionModel(
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
|
|
mid_block_type=None,
|
|
up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
|
|
block_out_channels=block_out_channels,
|
|
layers_per_block=layers_per_block,
|
|
act_fn="gelu",
|
|
norm_num_groups=None,
|
|
cross_attention_dim=cross_attention_dim,
|
|
attention_head_dim=64,
|
|
time_cond_proj_dim=class_labels_dim,
|
|
resnet_time_scale_shift="scale_shift",
|
|
time_embedding_type="fourier",
|
|
timestep_post_act="gelu",
|
|
conv_in_kernel=1,
|
|
conv_out_kernel=1,
|
|
)
|
|
|
|
return unet
|
|
|
|
|
|
def main(args):
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
|
|
orig_weights_path = huggingface_hub.hf_hub_download(
|
|
UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
|
|
)
|
|
print(f"loading original model configuration from {orig_config_path}")
|
|
print(f"loading original model checkpoint from {orig_weights_path}")
|
|
|
|
print("converting to diffusers unet")
|
|
orig_config = K.config.load_config(open(orig_config_path))["model"]
|
|
model = unet_model_from_original_config(orig_config)
|
|
|
|
orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
|
|
converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
|
|
|
|
model.load_state_dict(converted_checkpoint, strict=True)
|
|
model.save_pretrained(args.dump_path)
|
|
print(f"saving converted unet model in {args.dump_path}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
|
args = parser.parse_args()
|
|
|
|
main(args)
|