diff --git a/models/vision/glide/run_glide.py b/models/vision/glide/run_glide.py
index e69de29b..dce2dfa8 100644
--- a/models/vision/glide/run_glide.py
+++ b/models/vision/glide/run_glide.py
@@ -0,0 +1,16 @@
+import torch
+from .modeling_glide import GLIDE
+from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
+
+generator = torch.Generator()
+generator = generator.manual_seed(0)
+
+# 1. Load models
+scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
+model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
+
+pipeline = GLIDE(model, scheduler)
+
+img = pipeline(generator)
+
+print(img)
diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py
index b7b351ba..5a3dc91e 100644
--- a/src/diffusers/models/unet_glide.py
+++ b/src/diffusers/models/unet_glide.py
@@ -1,10 +1,13 @@
 import math
 from abc import abstractmethod
 
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel
+
 
 def convert_module_to_f16(l):
     """
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
 
 
@@ -298,7 +301,7 @@ class ResBlock(TimestepBlock):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
@@ -376,16 +379,16 @@ class QKVAttention(nn.Module):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
 
 
-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
 
@@ -435,6 +438,25 @@ class UNetGLIDEModel(nn.Module):
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
 
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
@@ -448,7 +470,7 @@ class UNetGLIDEModel(nn.Module):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -637,7 +659,7 @@ class UNetGLIDEModel(nn.Module):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
diff --git a/src/diffusers/schedulers/gaussian_ddpm.py b/src/diffusers/schedulers/gaussian_ddpm.py
index 38a93a6c..4fcdfdf2 100644
--- a/src/diffusers/schedulers/gaussian_ddpm.py
+++ b/src/diffusers/schedulers/gaussian_ddpm.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
 from torch import nn
 
 from ..configuration_utils import ConfigMixin
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
 
 
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float64)
+
+
 class GaussianDDPMScheduler(nn.Module, ConfigMixin):
     config_name = SAMPLING_CONFIG_NAME
 
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
 
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
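Note (not part of the patch): the snippet below is a minimal standalone sketch of the "squaredcos_cap_v2" cosine schedule introduced in the last hunk, so its betas can be inspected without installing diffusers. It duplicates the betas_for_alpha_bar function from the patch verbatim; the 1000-step count is an arbitrary example value, not something fixed by the diff.

# Standalone sketch of the GLIDE cosine beta schedule added above.
import math

import torch


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Discretize alpha_bar(t) on [0, 1] into per-step betas, capping each at max_beta.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float64)


# alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2, exactly as in the new elif branch.
betas = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

print(betas[:3])           # betas start very small at t = 0
print(betas[-3:])          # they grow near t = 1; the final value hits the max_beta=0.999 cap
print(alphas_cumprod[-1])  # the cumulative signal level decays towards zero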