import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer


# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
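# Wrap every tensor in an nn.Parameter so it can be assigned directly onto module attributes below.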
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}

### Convert the text encoder
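# GLIDE's text transformer maps onto a CLIP-style text model: GPT-2 BPE vocab (50257 tokens),
# 128-token context, width 512 with a 2048-wide MLP, 16 layers of 8 heads, plus GLIDE's padding embedding.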
config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
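# The tokenizer is GPT-2 style BPE; the end-of-text token doubles as the padding token.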
tokenizer = GPT2Tokenizer(
    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)

hf_encoder = model.text_model
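# Copy the embedding tables: token, positional, and GLIDE's learned padding embedding.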
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]

hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
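# Each GLIDE transformer residual block maps onto one CLIP encoder layer:
# fused qkv projection, output projection, the two layer norms, and the two-layer MLP.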
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]

    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]

    hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
    hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
    hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
    hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]

    hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
    hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
### Convert the Text-to-Image UNet
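# Base 64x64 text-to-image UNet: width 192 with channel multipliers (1, 2, 3, 4), attention over the
# 512-dim text features, and out_channels=6 because it predicts both the noise and the variance.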
text2im_model = GlideTextToImageUNetModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    transformer_dim=512,
)
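# strict=False: base.pt also holds the text-transformer weights converted above, which have no counterpart in the UNet.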
text2im_model.load_state_dict(state_dict, strict=False)
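# The base stage samples with classifier-free guidance over a 1000-step cosine noise schedule.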
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
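# 256x256 super-resolution UNet: in_channels=6 because the noisy image is concatenated channel-wise
# with the low-resolution conditioning image; out_channels=6 again covers noise plus variance.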
superres_model = GlideSuperResUNetModel(
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)

superres_model.load_state_dict(ups_state_dict, strict=False)
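# The upsampling stage uses DDIM sampling with a linear beta schedule.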
upscale_scheduler = DDIMScheduler(
    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)
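# Bundle both stages, their schedulers, the text encoder, and the tokenizer into a single pipeline.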
glide = Glide(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)
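# Serialize every component to ./glide-base so the converted pipeline can be reloaded later.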
glide.save_pretrained("./glide-base")