transformer-guided glide sampling
This commit is contained in:
parent
07ffe73f79
commit
d754ce5f3b
|
@ -22,8 +22,7 @@ config = CLIPTextConfig(
|
|||
use_padding_embeddings=True,
|
||||
)
|
||||
model = CLIPTextModel(config).eval()
|
||||
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
|
||||
# tokenizer.save_pretrained("./glide-base")
|
||||
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
|
||||
|
||||
hf_encoder = model.text_model
|
||||
|
||||
|
@ -52,12 +51,6 @@ for layer_idx in range(config.num_hidden_layers):
|
|||
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
|
||||
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
|
||||
|
||||
# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
|
||||
# with torch.no_grad():
|
||||
# outputs = model(**inputs)
|
||||
|
||||
# model.save_pretrained("./glide-base")
|
||||
|
||||
### Convert the UNet
|
||||
|
||||
unet_model = UNetGLIDEModel(
|
||||
|
@ -73,6 +66,7 @@ unet_model = UNetGLIDEModel(
|
|||
num_heads_upsample=1,
|
||||
use_scale_shift_norm=True,
|
||||
resblock_updown=True,
|
||||
transformer_dim=512,
|
||||
)
|
||||
|
||||
unet_model.load_state_dict(state_dict, strict=False)
|
||||
|
|
|
@ -130,21 +130,37 @@ class GLIDE(DiffusionPipeline):
|
|||
self.unet.to(torch_device)
|
||||
self.text_encoder.to(torch_device)
|
||||
|
||||
# Create a classifier-free guidance sampling function
|
||||
guidance_scale = 3.0
|
||||
|
||||
def model_fn(x_t, ts, transformer_out, **kwargs):
|
||||
half = x_t[: len(x_t) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.unet(combined, ts, transformer_out, **kwargs)
|
||||
eps, rest = model_out[:, :3], model_out[:, 3:]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
# 1. Sample gaussian noise
|
||||
batch_size = 2 # second image is empty for classifier-free guidance
|
||||
image = self.noise_scheduler.sample_noise(
|
||||
(1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
|
||||
(batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
|
||||
)
|
||||
|
||||
# 2. Encode tokens
|
||||
# an empty input is needed to guide the model away from (
|
||||
inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
|
||||
transformer_out = self.text_encoder(**inputs).last_hidden_state
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
attention_mask = inputs["attention_mask"].to(torch_device)
|
||||
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
|
||||
|
||||
num_timesteps = len(self.noise_scheduler)
|
||||
for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
|
||||
t = torch.tensor([i] * image.shape[0], device=torch_device)
|
||||
mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
|
||||
noise = self.noise_scheduler.sample_noise(image.shape)
|
||||
mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
|
||||
noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
|
||||
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
|
||||
|
||||
|
|
|
@ -9,6 +9,6 @@ generator = generator.manual_seed(0)
|
|||
# 1. Load models
|
||||
pipeline = GLIDE.from_pretrained("fusing/glide-base")
|
||||
|
||||
img = pipeline(generator)
|
||||
img = pipeline("an oil painting of a corgi", generator)
|
||||
|
||||
print(img)
|
||||
|
|
|
@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
encoder_channels=None,
|
||||
transformer_dim=512,
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
|
@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
encoder_channels=encoder_channels,
|
||||
transformer_dim=transformer_dim,
|
||||
)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
|
@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
|
||||
|
||||
ch = input_ch = int(channel_mult[0] * model_channels)
|
||||
self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
|
||||
self._feature_size = ch
|
||||
|
@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=encoder_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=encoder_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
|
@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=encoder_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
|
@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
|
||||
hs = []
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
|
@ -655,13 +653,13 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
|
|||
|
||||
emb = emb + transformer_proj.to(emb)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, transformer_out)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, transformer_out)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
other = hs.pop()
|
||||
h = torch.cat([h, other], dim=1)
|
||||
h = module(h, emb, transformer_out)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
|
|
@ -65,14 +65,14 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
|||
|
||||
if beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
betas = betas_for_alpha_bar(
|
||||
self.betas = betas_for_alpha_bar(
|
||||
timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
||||
|
||||
|
@ -81,12 +81,12 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
|||
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.posterior_log_variance_clipped = np.log(
|
||||
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
||||
)
|
||||
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
||||
|
||||
def sample_noise(self, shape, device, generator=None):
|
||||
|
|
Loading…
Reference in New Issue