Merge pull request #2 from huggingface/add-glide
+ cosine schedule and unet config
commit 7f6a36c3b1
@@ -0,0 +1,16 @@
+import torch
+from .modeling_glide import GLIDE
+from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
+
+generator = torch.Generator()
+generator = generator.manual_seed(0)
+
+# 1. Load models
+scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
+model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
+
+pipeline = GLIDE(model, scheduler)
+
+img = pipeline(generator)
+
+print(img)
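The new example script drives generation through GLIDE(model, scheduler) with a seeded torch.Generator, loading the scheduler from its config and the UNet from pretrained weights on fusing/glide-base. The pipeline internals are not part of this diff; the snippet below is only a self-contained toy of the kind of ancestral-sampling loop such a pipeline is expected to run, with a stand-in model and an assumed linear beta schedule, not the actual diffusers code.

# Purely illustrative toy of a DDPM-style reverse loop; toy_model, the linear
# schedule, and all shapes are assumptions standing in for the real pipeline.
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # assumed schedule for the toy
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def toy_model(x, t):                           # stands in for the noise-predicting UNet
    return torch.zeros_like(x)

generator = torch.Generator().manual_seed(0)
image = torch.randn(1, 3, 8, 8, generator=generator)   # start from pure noise

for t in reversed(range(T)):
    eps = toy_model(image, torch.tensor([t]))  # predicted noise at step t
    # DDPM posterior mean: remove the predicted noise contribution, rescale
    mean = (image - betas[t] / (1 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
    noise = torch.randn_like(image) if t > 0 else torch.zeros_like(image)
    image = mean + betas[t].sqrt() * noise     # simplified sigma_t = sqrt(beta_t)

print(image.shape)

Seeding the generator before handing it to the pipeline is what makes the returned img reproducible across runs.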
@@ -1,10 +1,13 @@
 import math
 from abc import abstractmethod
 
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel
+
 
 def convert_module_to_f16(l):
     """
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
 
 
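The edit here is a mechanical th -> torch rename; the function itself is the standard sinusoidal embedding applied to diffusion timesteps. For reference, a self-contained copy with a quick shape and value check:

# Copy of the embedding above (torch-only), plus a sanity check on shapes/values.
import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

print(timestep_embedding(torch.arange(4), 8).shape)   # torch.Size([4, 8]), i.e. [N x dim]
print(timestep_embedding(torch.tensor([0]), 4))       # tensor([[1., 1., 0., 0.]]): cos terms first, then sin terms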
@@ -298,7 +301,7 @@ class ResBlock(TimestepBlock):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
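Only the th.chunk call changes here, but the surrounding lines are the scale-shift conditioning path: the projected timestep embedding is split into a per-channel scale and shift that modulate the normalized activations. A toy illustration of that formula, with the normalization layer and all shapes assumed for the example:

# Toy scale-shift conditioning; the GroupNorm choice and shapes are assumptions.
import torch
import torch.nn as nn

C = 8
h = torch.randn(2, C, 16, 16)            # feature map
emb_out = torch.randn(2, 2 * C, 1, 1)    # projected timestep embedding, broadcast over H and W

out_norm = nn.GroupNorm(4, C)
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift    # same modulation as in the hunk above
print(h.shape)                           # torch.Size([2, 8, 16, 16])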
@@ -376,16 +379,16 @@ class QKVAttention(nn.Module):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
 
 
-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
 
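Apart from the rename, the attention math is unchanged: the softmax scaling is split across q and k (1/sqrt(sqrt(ch)) each) for fp16 stability, and encoder keys/values, when present, are prepended along the sequence axis. A minimal reproduction of the core einsum pattern on toy shapes, without the encoder branch:

# Mirrors the einsum attention above on toy tensors; shapes are illustrative.
import math
import torch

bs_heads, ch, length = 2, 8, 16                 # (batch * heads, channels per head, sequence length)
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)
v = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))            # split the scaling between q and k
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v)
print(out.shape)                                # torch.Size([2, 8, 16])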
@@ -435,6 +438,25 @@ class UNetGLIDEModel(nn.Module):
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
 
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
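UNetGLIDEModel now inherits from PreTrainedModel and Config, and self.register(...) records every constructor argument. The Config implementation in configuration_utils is not part of this diff, so the sketch below is only a guess at the general idea: stash the hyperparameters so the model can later be rebuilt from a saved config, which is what the from_pretrained / from_config calls in the new example script rely on.

# Hypothetical sketch of what a register() helper could do; not the actual
# diffusers configuration_utils code.
class ConfigSketch:
    def register(self, **kwargs):
        # keep every passed hyperparameter; a real implementation would also
        # serialize this dict (e.g. to JSON) and restore it on loading
        self._config = dict(kwargs)

    def config_dict(self):
        return dict(self._config)

class TinyModel(ConfigSketch):
    def __init__(self, in_channels=3, model_channels=64):
        self.register(in_channels=in_channels, model_channels=model_channels)

print(TinyModel().config_dict())   # {'in_channels': 3, 'model_channels': 64}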
@@ -448,7 +470,7 @@ class UNetGLIDEModel(nn.Module):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -637,7 +659,7 @@ class UNetGLIDEModel(nn.Module):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
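The torch.cat in this hunk is the U-Net skip connection: activations appended to hs on the way down are popped and concatenated channel-wise on the way up. A toy version with illustrative shapes:

# Toy skip-connection bookkeeping; block counts and channel sizes are made up.
import torch

hs = []
h = torch.randn(1, 4, 8, 8)
for _ in range(3):               # "down" path: remember each block's output
    hs.append(h)

for _ in range(3):               # "up" path: concatenate the matching skip tensor
    h = torch.cat([h, hs.pop()], dim=1)

print(h.shape)                   # torch.Size([1, 16, 8, 8])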
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
 from torch import nn
 
 from ..configuration_utils import ConfigMixin
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
 
 
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float64)
+
+
 class GaussianDDPMScheduler(nn.Module, ConfigMixin):
 
     config_name = SAMPLING_CONFIG_NAME
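betas_for_alpha_bar discretizes a continuous alpha_bar(t) into per-step betas through the telescoping ratio 1 - alpha_bar(t2) / alpha_bar(t1), capped at max_beta so the final steps stay non-singular. A self-contained check (the function restated so the snippet runs on its own) with the GLIDE cosine alpha_bar:

# Restates the function above so this runs standalone; same logic as the diff.
import math
import torch

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float64)

alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
betas = betas_for_alpha_bar(1000, alpha_bar)
print(betas[0].item(), betas[-1].item())   # tiny near t = 0, capped at max_beta near t = 1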
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
 
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
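With the new squaredcos_cap_v2 branch the scheduler supports both schedules. A rough comparison of the cumulative alpha products they imply; the linear endpoints (1e-4, 0.02) are common DDPM defaults assumed for this sketch, not values taken from this diff:

# Compares how much signal each schedule preserves halfway through the trajectory.
import math
import torch

T = 1000

def alpha_bar(t):
    # the GLIDE cosine alpha_bar used for "squaredcos_cap_v2"
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

cosine_betas = torch.tensor(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)],
    dtype=torch.float64,
)
linear_betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)   # assumed defaults

for name, betas in [("linear", linear_betas), ("squaredcos_cap_v2", cosine_betas)]:
    cum_alpha = torch.cumprod(1.0 - betas, dim=0)
    print(name, round(cum_alpha[T // 2].item(), 3))   # cosine keeps far more signal mid-trajectory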