Merge pull request #41 from huggingface/fix_comments

[Resnets] Fix comments
This commit is contained in:
Patrick von Platen 2022-06-29 13:47:06 +02:00 committed by GitHub
commit 814133ec9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 4 additions and 398 deletions

View File

@ -167,8 +167,8 @@ class Downsample(nn.Module):
# class GlideUpsample(nn.Module): # class GlideUpsample(nn.Module):
# """ # """
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param # An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If # use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
3D, then # upsampling occurs in the inner-two dimensions. #""" # 3D, then # upsampling occurs in the inner-two dimensions. #"""
# #
# def __init__(self, channels, use_conv, dims=2, out_channels=None): # def __init__(self, channels, use_conv, dims=2, out_channels=None):
# super().__init__() # super().__init__()
@ -193,8 +193,8 @@ use_conv: a bool determining if a convolution is # applied. :param dims: determi
# class LDMUpsample(nn.Module): # class LDMUpsample(nn.Module):
# """ # """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # # An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If # use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
3D, then # upsampling occurs in the inner-two dimensions. #""" # 3D, then # upsampling occurs in the inner-two dimensions. #"""
# #
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): # def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
# super().__init__() # super().__init__()

View File

@ -1,5 +1,3 @@
from abc import abstractmethod
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -96,16 +94,6 @@ def zero_module(module):
return module return module
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock): class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
""" """
A sequential module that passes timestep embeddings to the children that support it as an extra input. A sequential module that passes timestep embeddings to the children that support it as an extra input.
@ -122,101 +110,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x return x
# class ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. # # :param channels: the number of input
channels. :param emb_channels: the number of timestep embedding channels. # :param dropout: the rate of dropout. :param
out_channels: if specified, the number of out channels. :param # use_conv: if True and out_channels is specified, use a
spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels, swish=1.0),
# nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# """
# Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings. # :return: an [N x C x ...] Tensor of outputs. #"""
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
class GlideUNetModel(ModelMixin, ConfigMixin): class GlideUNetModel(ModelMixin, ConfigMixin):
""" """
The full UNet model with attention and timestep embedding. The full UNet model with attention and timestep embedding.

View File

@ -36,26 +36,6 @@ class Block(torch.nn.Module):
return output * mask return output * mask
# class ResnetBlock(torch.nn.Module):
# def __init__(self, dim, dim_out, time_emb_dim, groups=8):
# super(ResnetBlock, self).__init__()
# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
# self.block1 = Block(dim, dim_out, groups=groups)
# self.block2 = Block(dim_out, dim_out, groups=groups)
# if dim != dim_out:
# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
# else:
# self.res_conv = torch.nn.Identity()
#
# def forward(self, x, mask, time_emb):
# h = self.block1(x, mask)
# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
# h = self.block2(h, mask)
# output = h + self.res_conv(x * mask)
# return output
class Residual(torch.nn.Module): class Residual(torch.nn.Module):
def __init__(self, fn): def __init__(self, fn):
super(Residual, self).__init__() super(Residual, self).__init__()

View File

@ -1,5 +1,4 @@
import math import math
from abc import abstractmethod
from inspect import isfunction from inspect import isfunction
import numpy as np import numpy as np
@ -328,7 +327,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
## go
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):
""" """
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@ -359,16 +357,6 @@ class AttentionPool2d(nn.Module):
return x[:, :, 0] return x[:, :, 0]
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument. #"""
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings. #"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock): class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
""" """
A sequential module that passes timestep embeddings to the children that support it as an extra input. A sequential module that passes timestep embeddings to the children that support it as an extra input.
@ -385,99 +373,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x return x
# class A_ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels. :param channels: the number of input channels. #
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param #
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use # a
spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels),
# nn.SiLU(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels),
# nn.SiLU(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
#
class QKVAttention(nn.Module): class QKVAttention(nn.Module):
""" """
A module which performs QKV attention and splits in a different order. A module which performs QKV attention and splits in a different order.

View File

@ -73,37 +73,6 @@ class Conv1dBlock(nn.Module):
return self.block(x) return self.block(x)
# class ResidualTemporalBlock(nn.Module):
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
# super().__init__()
#
# self.blocks = nn.ModuleList(
# [
# Conv1dBlock(inp_channels, out_channels, kernel_size),
# Conv1dBlock(out_channels, out_channels, kernel_size),
# ]
# )
#
# self.time_mlp = nn.Sequential(
# nn.Mish(),
# nn.Linear(embed_dim, out_channels),
# RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
# )
#
# self.residual_conv = (
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
# )
#
# def forward(self, x, t):
# """
# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x #
out_channels x horizon ] #"""
# out = self.blocks[0](x) + self.time_mlp(t)
# out = self.blocks[1](out)
# return out + self.residual_conv(x)
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__( def __init__(
self, self,

View File

@ -491,137 +491,6 @@ class Downsample(nn.Module):
return x return x
# class ResnetBlockDDPMpp(nn.Module):
# """ResBlock adapted from DDPM."""
#
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# conv_shortcut=False,
# dropout=0.1,
# skip_rescale=False,
# init_scale=0.0,
# ):
# super().__init__()
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
# nn.init.zeros_(self.Dense_0.bias)
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch:
# if conv_shortcut:
# self.Conv_2 = conv3x3(in_ch, out_ch)
# else:
# self.NIN_0 = NIN(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.out_ch = out_ch
# self.conv_shortcut = conv_shortcut
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
# h = self.Conv_0(h)
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
# if x.shape[1] != self.out_ch:
# if self.conv_shortcut:
# x = self.Conv_2(x)
# else:
# x = self.NIN_0(x)
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
# class ResnetBlockBigGANpp(nn.Module):
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# up=False,
# down=False,
# dropout=0.1,
# fir=False,
# fir_kernel=(1, 3, 3, 1),
# skip_rescale=True,
# init_scale=0.0,
# ):
# super().__init__()
#
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.up = up
# self.down = down
# self.fir = fir
# self.fir_kernel = fir_kernel
#
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
# nn.init.zeros_(self.Dense_0.bias)
#
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch or up or down:
# self.Conv_2 = conv1x1(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.in_ch = in_ch
# self.out_ch = out_ch
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
#
# if self.up:
# if self.fir:
# h = upsample_2d(h, self.fir_kernel, factor=2)
# x = upsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_upsample_2d(h, factor=2)
# x = naive_upsample_2d(x, factor=2)
# elif self.down:
# if self.fir:
# h = downsample_2d(h, self.fir_kernel, factor=2)
# x = downsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_downsample_2d(h, factor=2)
# x = naive_downsample_2d(x, factor=2)
#
# h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
#
# if self.in_ch != self.out_ch or self.up or self.down:
# x = self.Conv_2(x)
#
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
class NCSNpp(ModelMixin, ConfigMixin): class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model""" """NCSN++ model"""