diff --git a/models/vision/glide/convert_weights.py b/models/vision/glide/convert_weights.py
new file mode 100644
index 00000000..6792cdef
--- /dev/null
+++ b/models/vision/glide/convert_weights.py
@@ -0,0 +1,60 @@
+import argparse
+
+import torch
+from torch import nn
+
+from transformers import CLIPTextConfig, CLIPTextModel, GPT2Tokenizer
+
+# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
+state_dict = torch.load("base.pt", map_location="cpu")
+state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
+config = CLIPTextConfig(
+    hidden_size=512,
+    intermediate_size=2048,
+    num_hidden_layers=16,
+    num_attention_heads=8,
+    max_position_embeddings=128
+)
+model = CLIPTextModel(config).eval()
+tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
+tokenizer.save_pretrained("./glide-base")
+
+hf_encoder = model.text_model
+
+hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
+hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
+hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
+
+hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
+hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
+
+for layer_idx in range(config.num_hidden_layers):
+    hf_layer = hf_encoder.encoder.layers[layer_idx]
+    q_proj, k_proj, v_proj = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"].chunk(3, dim=0)
+    q_proj_bias, k_proj_bias, v_proj_bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"].chunk(3, dim=0)
+
+    hf_layer.self_attn.q_proj.weight.data = q_proj
+    hf_layer.self_attn.q_proj.bias.data = q_proj_bias
+    hf_layer.self_attn.k_proj.weight.data = k_proj
+    hf_layer.self_attn.k_proj.bias.data = k_proj_bias
+    hf_layer.self_attn.v_proj.weight.data = v_proj
+    hf_layer.self_attn.v_proj.bias.data = v_proj_bias
+
+    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
+    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
+
+    hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
+    hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
+    hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
+    hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]
+
+    hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
+    hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
+    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
+    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
+
+inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+
+model.save_pretrained("./glide-base")
\ No newline at end of file
diff --git a/models/vision/glide/modeling_vqvae.py.py b/models/vision/glide/modeling_vqvae.py.py
deleted file mode 100755
index e5a0d9b4..00000000
--- a/models/vision/glide/modeling_vqvae.py.py
+++ /dev/null
@@ -1 +0,0 @@
-#!/usr/bin/env python3
diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py
index dce2dfa8..23cd4e10 100644
--- a/models/vision/glide/run_glide.py
+++ b/models/vision/glide/run_glide.py
@@ -6,6 +6,7 @@
 generator = torch.Generator()
 generator = generator.manual_seed(0)
 
 # 1. Load models
+
 scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
 model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index 5a3dc91e..75401118 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -5,9 +5,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..configuration_utils import Config
-from ..modeling_utils import PreTrainedModel
-
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
 
 def convert_module_to_f16(l):
     """
@@ -388,7 +387,7 @@ class QKVAttention(nn.Module):
         return a.reshape(bs, -1, length)
 
 
-class UNetGLIDEModel(PreTrainedModel, Config):
+class UNetGLIDEModel(ModelMixin, ConfigMixin):
     """
     The full UNet model with attention and timestep embedding.
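
A note on the one non-obvious mapping in convert_weights.py: GLIDE stores the attention query, key, and value weights as a single fused c_qkv projection, while the Hugging Face CLIP attention module keeps three separate projections, so the loop recovers them with chunk(3, dim=0). A minimal sketch of that split, using a random stand-in tensor with shapes taken from the hidden_size=512 config above:

```python
# Illustration of the fused-QKV split used in the conversion loop above:
# GLIDE's c_qkv weight stacks q, k, v along dim 0, so chunk(3, dim=0)
# recovers the three per-projection matrices.
import torch

hidden_size = 512  # from the CLIPTextConfig above
c_qkv_weight = torch.randn(3 * hidden_size, hidden_size)  # stand-in for the real tensor

q, k, v = c_qkv_weight.chunk(3, dim=0)
assert q.shape == (hidden_size, hidden_size)
assert k.shape == (hidden_size, hidden_size)
assert v.shape == (hidden_size, hidden_size)
```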
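A quick way to exercise the converted checkpoint end to end, mirroring the final lines of the script. This is a sketch, not part of the patch: it assumes convert_weights.py has already written ./glide-base, and that the installed transformers build matches the one used for the conversion (the padding_embedding touched above is not part of stock CLIPTextModel, so a vanilla install would skip that weight with a warning):

```python
# Reload the converted GLIDE text encoder and run the same prompts the
# conversion script uses, to sanity-check the exported weights.
import torch
from transformers import CLIPTextModel, GPT2Tokenizer

model = CLIPTextModel.from_pretrained("./glide-base").eval()
tokenizer = GPT2Tokenizer.from_pretrained("./glide-base")

# GLIDE pads every prompt to a fixed length of 128 tokens, matching
# max_position_embeddings in the config above.
inputs = tokenizer(
    ["an oil painting of a corgi", ""],  # conditional and unconditional prompts
    padding="max_length",
    max_length=128,
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model(**inputs)

# Expected: torch.Size([2, 128, 512]) -- batch x sequence x hidden_size
print(outputs.last_hidden_state.shape)
```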