Merge pull request #41 from huggingface/fix_comments
[Resnets] Fix comments
This commit is contained in:
commit
814133ec9c
|
@ -167,8 +167,8 @@ class Downsample(nn.Module):
|
||||||
# class GlideUpsample(nn.Module):
|
# class GlideUpsample(nn.Module):
|
||||||
# """
|
# """
|
||||||
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
|
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
|
||||||
use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
|
# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
|
||||||
3D, then # upsampling occurs in the inner-two dimensions. #"""
|
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
|
||||||
#
|
#
|
||||||
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||||
# super().__init__()
|
# super().__init__()
|
||||||
|
@ -193,8 +193,8 @@ use_conv: a bool determining if a convolution is # applied. :param dims: determi
|
||||||
# class LDMUpsample(nn.Module):
|
# class LDMUpsample(nn.Module):
|
||||||
# """
|
# """
|
||||||
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
|
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
|
||||||
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
|
# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
|
||||||
3D, then # upsampling occurs in the inner-two dimensions. #"""
|
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
|
||||||
#
|
#
|
||||||
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||||
# super().__init__()
|
# super().__init__()
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
from abc import abstractmethod
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -96,16 +94,6 @@ def zero_module(module):
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
# class TimestepBlock(nn.Module):
|
|
||||||
# """
|
|
||||||
# Any module where forward() takes timestep embeddings as a second argument. #"""
|
|
||||||
#
|
|
||||||
# @abstractmethod
|
|
||||||
# def forward(self, x, emb):
|
|
||||||
# """
|
|
||||||
# Apply the module to `x` given `emb` timestep embeddings. #"""
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
"""
|
"""
|
||||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||||
|
@ -122,101 +110,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# class ResBlock(TimestepBlock):
|
|
||||||
# """
|
|
||||||
# A residual block that can optionally change the number of channels. # # :param channels: the number of input
|
|
||||||
channels. :param emb_channels: the number of timestep embedding channels. # :param dropout: the rate of dropout. :param
|
|
||||||
out_channels: if specified, the number of out channels. :param # use_conv: if True and out_channels is specified, use a
|
|
||||||
spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
|
|
||||||
dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
|
|
||||||
module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
|
|
||||||
#
|
|
||||||
# def __init__(
|
|
||||||
# self,
|
|
||||||
# channels,
|
|
||||||
# emb_channels,
|
|
||||||
# dropout,
|
|
||||||
# out_channels=None,
|
|
||||||
# use_conv=False,
|
|
||||||
# use_scale_shift_norm=False,
|
|
||||||
# dims=2,
|
|
||||||
# use_checkpoint=False,
|
|
||||||
# up=False,
|
|
||||||
# down=False,
|
|
||||||
# ):
|
|
||||||
# super().__init__()
|
|
||||||
# self.channels = channels
|
|
||||||
# self.emb_channels = emb_channels
|
|
||||||
# self.dropout = dropout
|
|
||||||
# self.out_channels = out_channels or channels
|
|
||||||
# self.use_conv = use_conv
|
|
||||||
# self.use_checkpoint = use_checkpoint
|
|
||||||
# self.use_scale_shift_norm = use_scale_shift_norm
|
|
||||||
#
|
|
||||||
# self.in_layers = nn.Sequential(
|
|
||||||
# normalization(channels, swish=1.0),
|
|
||||||
# nn.Identity(),
|
|
||||||
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# self.updown = up or down
|
|
||||||
#
|
|
||||||
# if up:
|
|
||||||
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
|
||||||
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
|
||||||
# elif down:
|
|
||||||
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
|
||||||
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
|
||||||
# else:
|
|
||||||
# self.h_upd = self.x_upd = nn.Identity()
|
|
||||||
#
|
|
||||||
# self.emb_layers = nn.Sequential(
|
|
||||||
# nn.SiLU(),
|
|
||||||
# linear(
|
|
||||||
# emb_channels,
|
|
||||||
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
# self.out_layers = nn.Sequential(
|
|
||||||
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
|
|
||||||
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
|
|
||||||
# nn.Dropout(p=dropout),
|
|
||||||
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# if self.out_channels == channels:
|
|
||||||
# self.skip_connection = nn.Identity()
|
|
||||||
# elif use_conv:
|
|
||||||
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
|
|
||||||
# else:
|
|
||||||
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
|
||||||
#
|
|
||||||
# def forward(self, x, emb):
|
|
||||||
# """
|
|
||||||
# Apply the block to a Tensor, conditioned on a timestep embedding. # # :param x: an [N x C x ...] Tensor of features.
|
|
||||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings. # :return: an [N x C x ...] Tensor of outputs. #"""
|
|
||||||
# if self.updown:
|
|
||||||
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
|
||||||
# h = in_rest(x)
|
|
||||||
# h = self.h_upd(h)
|
|
||||||
# x = self.x_upd(x)
|
|
||||||
# h = in_conv(h)
|
|
||||||
# else:
|
|
||||||
# h = self.in_layers(x)
|
|
||||||
# emb_out = self.emb_layers(emb).type(h.dtype)
|
|
||||||
# while len(emb_out.shape) < len(h.shape):
|
|
||||||
# emb_out = emb_out[..., None]
|
|
||||||
# if self.use_scale_shift_norm:
|
|
||||||
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
|
||||||
# scale, shift = torch.chunk(emb_out, 2, dim=1)
|
|
||||||
# h = out_norm(h) * (1 + scale) + shift
|
|
||||||
# h = out_rest(h)
|
|
||||||
# else:
|
|
||||||
# h = h + emb_out
|
|
||||||
# h = self.out_layers(h)
|
|
||||||
# return self.skip_connection(x) + h
|
|
||||||
|
|
||||||
|
|
||||||
class GlideUNetModel(ModelMixin, ConfigMixin):
|
class GlideUNetModel(ModelMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
The full UNet model with attention and timestep embedding.
|
The full UNet model with attention and timestep embedding.
|
||||||
|
|
|
@ -36,26 +36,6 @@ class Block(torch.nn.Module):
|
||||||
return output * mask
|
return output * mask
|
||||||
|
|
||||||
|
|
||||||
# class ResnetBlock(torch.nn.Module):
|
|
||||||
# def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
|
||||||
# super(ResnetBlock, self).__init__()
|
|
||||||
# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
|
||||||
#
|
|
||||||
# self.block1 = Block(dim, dim_out, groups=groups)
|
|
||||||
# self.block2 = Block(dim_out, dim_out, groups=groups)
|
|
||||||
# if dim != dim_out:
|
|
||||||
# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
|
|
||||||
# else:
|
|
||||||
# self.res_conv = torch.nn.Identity()
|
|
||||||
#
|
|
||||||
# def forward(self, x, mask, time_emb):
|
|
||||||
# h = self.block1(x, mask)
|
|
||||||
# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
|
|
||||||
# h = self.block2(h, mask)
|
|
||||||
# output = h + self.res_conv(x * mask)
|
|
||||||
# return output
|
|
||||||
|
|
||||||
|
|
||||||
class Residual(torch.nn.Module):
|
class Residual(torch.nn.Module):
|
||||||
def __init__(self, fn):
|
def __init__(self, fn):
|
||||||
super(Residual, self).__init__()
|
super(Residual, self).__init__()
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import math
|
import math
|
||||||
from abc import abstractmethod
|
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -328,7 +327,6 @@ def normalization(channels, swish=0.0):
|
||||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||||
|
|
||||||
|
|
||||||
## go
|
|
||||||
class AttentionPool2d(nn.Module):
|
class AttentionPool2d(nn.Module):
|
||||||
"""
|
"""
|
||||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||||
|
@ -359,16 +357,6 @@ class AttentionPool2d(nn.Module):
|
||||||
return x[:, :, 0]
|
return x[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
# class TimestepBlock(nn.Module):
|
|
||||||
# """
|
|
||||||
# Any module where forward() takes timestep embeddings as a second argument. #"""
|
|
||||||
#
|
|
||||||
# @abstractmethod
|
|
||||||
# def forward(self, x, emb):
|
|
||||||
# """
|
|
||||||
# Apply the module to `x` given `emb` timestep embeddings. #"""
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
"""
|
"""
|
||||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||||
|
@ -385,99 +373,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# class A_ResBlock(TimestepBlock):
|
|
||||||
# """
|
|
||||||
# A residual block that can optionally change the number of channels. :param channels: the number of input channels. #
|
|
||||||
:param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param #
|
|
||||||
out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use # a
|
|
||||||
spatial # convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. # :param
|
|
||||||
dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing # on this
|
|
||||||
module. :param up: if True, use this block for upsampling. :param down: if True, use this block for # downsampling. #"""
|
|
||||||
#
|
|
||||||
# def __init__(
|
|
||||||
# self,
|
|
||||||
# channels,
|
|
||||||
# emb_channels,
|
|
||||||
# dropout,
|
|
||||||
# out_channels=None,
|
|
||||||
# use_conv=False,
|
|
||||||
# use_scale_shift_norm=False,
|
|
||||||
# dims=2,
|
|
||||||
# use_checkpoint=False,
|
|
||||||
# up=False,
|
|
||||||
# down=False,
|
|
||||||
# ):
|
|
||||||
# super().__init__()
|
|
||||||
# self.channels = channels
|
|
||||||
# self.emb_channels = emb_channels
|
|
||||||
# self.dropout = dropout
|
|
||||||
# self.out_channels = out_channels or channels
|
|
||||||
# self.use_conv = use_conv
|
|
||||||
# self.use_checkpoint = use_checkpoint
|
|
||||||
# self.use_scale_shift_norm = use_scale_shift_norm
|
|
||||||
#
|
|
||||||
# self.in_layers = nn.Sequential(
|
|
||||||
# normalization(channels),
|
|
||||||
# nn.SiLU(),
|
|
||||||
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# self.updown = up or down
|
|
||||||
#
|
|
||||||
# if up:
|
|
||||||
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
|
||||||
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
|
||||||
# elif down:
|
|
||||||
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
|
||||||
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
|
|
||||||
# else:
|
|
||||||
# self.h_upd = self.x_upd = nn.Identity()
|
|
||||||
#
|
|
||||||
# self.emb_layers = nn.Sequential(
|
|
||||||
# nn.SiLU(),
|
|
||||||
# linear(
|
|
||||||
# emb_channels,
|
|
||||||
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
# self.out_layers = nn.Sequential(
|
|
||||||
# normalization(self.out_channels),
|
|
||||||
# nn.SiLU(),
|
|
||||||
# nn.Dropout(p=dropout),
|
|
||||||
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# if self.out_channels == channels:
|
|
||||||
# self.skip_connection = nn.Identity()
|
|
||||||
# elif use_conv:
|
|
||||||
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
|
|
||||||
# else:
|
|
||||||
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
|
||||||
#
|
|
||||||
# def forward(self, x, emb):
|
|
||||||
# if self.updown:
|
|
||||||
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
|
||||||
# h = in_rest(x)
|
|
||||||
# h = self.h_upd(h)
|
|
||||||
# x = self.x_upd(x)
|
|
||||||
# h = in_conv(h)
|
|
||||||
# else:
|
|
||||||
# h = self.in_layers(x)
|
|
||||||
# emb_out = self.emb_layers(emb).type(h.dtype)
|
|
||||||
# while len(emb_out.shape) < len(h.shape):
|
|
||||||
# emb_out = emb_out[..., None]
|
|
||||||
# if self.use_scale_shift_norm:
|
|
||||||
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
|
||||||
# scale, shift = torch.chunk(emb_out, 2, dim=1)
|
|
||||||
# h = out_norm(h) * (1 + scale) + shift
|
|
||||||
# h = out_rest(h)
|
|
||||||
# else:
|
|
||||||
# h = h + emb_out
|
|
||||||
# h = self.out_layers(h)
|
|
||||||
# return self.skip_connection(x) + h
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
class QKVAttention(nn.Module):
|
class QKVAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
A module which performs QKV attention and splits in a different order.
|
A module which performs QKV attention and splits in a different order.
|
||||||
|
|
|
@ -73,37 +73,6 @@ class Conv1dBlock(nn.Module):
|
||||||
return self.block(x)
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
# class ResidualTemporalBlock(nn.Module):
|
|
||||||
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
|
|
||||||
# super().__init__()
|
|
||||||
#
|
|
||||||
# self.blocks = nn.ModuleList(
|
|
||||||
# [
|
|
||||||
# Conv1dBlock(inp_channels, out_channels, kernel_size),
|
|
||||||
# Conv1dBlock(out_channels, out_channels, kernel_size),
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# self.time_mlp = nn.Sequential(
|
|
||||||
# nn.Mish(),
|
|
||||||
# nn.Linear(embed_dim, out_channels),
|
|
||||||
# RearrangeDim(),
|
|
||||||
# Rearrange("batch t -> batch t 1"),
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# self.residual_conv = (
|
|
||||||
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# def forward(self, x, t):
|
|
||||||
# """
|
|
||||||
# x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x #
|
|
||||||
out_channels x horizon ] #"""
|
|
||||||
# out = self.blocks[0](x) + self.time_mlp(t)
|
|
||||||
# out = self.blocks[1](out)
|
|
||||||
# return out + self.residual_conv(x)
|
|
||||||
|
|
||||||
|
|
||||||
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -491,137 +491,6 @@ class Downsample(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# class ResnetBlockDDPMpp(nn.Module):
|
|
||||||
# """ResBlock adapted from DDPM."""
|
|
||||||
#
|
|
||||||
# def __init__(
|
|
||||||
# self,
|
|
||||||
# act,
|
|
||||||
# in_ch,
|
|
||||||
# out_ch=None,
|
|
||||||
# temb_dim=None,
|
|
||||||
# conv_shortcut=False,
|
|
||||||
# dropout=0.1,
|
|
||||||
# skip_rescale=False,
|
|
||||||
# init_scale=0.0,
|
|
||||||
# ):
|
|
||||||
# super().__init__()
|
|
||||||
# out_ch = out_ch if out_ch else in_ch
|
|
||||||
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
|
||||||
# self.Conv_0 = conv3x3(in_ch, out_ch)
|
|
||||||
# if temb_dim is not None:
|
|
||||||
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
|
||||||
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
|
||||||
# nn.init.zeros_(self.Dense_0.bias)
|
|
||||||
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
|
||||||
# self.Dropout_0 = nn.Dropout(dropout)
|
|
||||||
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
|
||||||
# if in_ch != out_ch:
|
|
||||||
# if conv_shortcut:
|
|
||||||
# self.Conv_2 = conv3x3(in_ch, out_ch)
|
|
||||||
# else:
|
|
||||||
# self.NIN_0 = NIN(in_ch, out_ch)
|
|
||||||
#
|
|
||||||
# self.skip_rescale = skip_rescale
|
|
||||||
# self.act = act
|
|
||||||
# self.out_ch = out_ch
|
|
||||||
# self.conv_shortcut = conv_shortcut
|
|
||||||
#
|
|
||||||
# def forward(self, x, temb=None):
|
|
||||||
# h = self.act(self.GroupNorm_0(x))
|
|
||||||
# h = self.Conv_0(h)
|
|
||||||
# if temb is not None:
|
|
||||||
# h += self.Dense_0(self.act(temb))[:, :, None, None]
|
|
||||||
# h = self.act(self.GroupNorm_1(h))
|
|
||||||
# h = self.Dropout_0(h)
|
|
||||||
# h = self.Conv_1(h)
|
|
||||||
# if x.shape[1] != self.out_ch:
|
|
||||||
# if self.conv_shortcut:
|
|
||||||
# x = self.Conv_2(x)
|
|
||||||
# else:
|
|
||||||
# x = self.NIN_0(x)
|
|
||||||
# if not self.skip_rescale:
|
|
||||||
# return x + h
|
|
||||||
# else:
|
|
||||||
# return (x + h) / np.sqrt(2.0)
|
|
||||||
|
|
||||||
|
|
||||||
# class ResnetBlockBigGANpp(nn.Module):
|
|
||||||
# def __init__(
|
|
||||||
# self,
|
|
||||||
# act,
|
|
||||||
# in_ch,
|
|
||||||
# out_ch=None,
|
|
||||||
# temb_dim=None,
|
|
||||||
# up=False,
|
|
||||||
# down=False,
|
|
||||||
# dropout=0.1,
|
|
||||||
# fir=False,
|
|
||||||
# fir_kernel=(1, 3, 3, 1),
|
|
||||||
# skip_rescale=True,
|
|
||||||
# init_scale=0.0,
|
|
||||||
# ):
|
|
||||||
# super().__init__()
|
|
||||||
#
|
|
||||||
# out_ch = out_ch if out_ch else in_ch
|
|
||||||
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
|
||||||
# self.up = up
|
|
||||||
# self.down = down
|
|
||||||
# self.fir = fir
|
|
||||||
# self.fir_kernel = fir_kernel
|
|
||||||
#
|
|
||||||
# self.Conv_0 = conv3x3(in_ch, out_ch)
|
|
||||||
# if temb_dim is not None:
|
|
||||||
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
|
||||||
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
|
|
||||||
# nn.init.zeros_(self.Dense_0.bias)
|
|
||||||
#
|
|
||||||
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
|
||||||
# self.Dropout_0 = nn.Dropout(dropout)
|
|
||||||
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
|
||||||
# if in_ch != out_ch or up or down:
|
|
||||||
# self.Conv_2 = conv1x1(in_ch, out_ch)
|
|
||||||
#
|
|
||||||
# self.skip_rescale = skip_rescale
|
|
||||||
# self.act = act
|
|
||||||
# self.in_ch = in_ch
|
|
||||||
# self.out_ch = out_ch
|
|
||||||
#
|
|
||||||
# def forward(self, x, temb=None):
|
|
||||||
# h = self.act(self.GroupNorm_0(x))
|
|
||||||
#
|
|
||||||
# if self.up:
|
|
||||||
# if self.fir:
|
|
||||||
# h = upsample_2d(h, self.fir_kernel, factor=2)
|
|
||||||
# x = upsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
# else:
|
|
||||||
# h = naive_upsample_2d(h, factor=2)
|
|
||||||
# x = naive_upsample_2d(x, factor=2)
|
|
||||||
# elif self.down:
|
|
||||||
# if self.fir:
|
|
||||||
# h = downsample_2d(h, self.fir_kernel, factor=2)
|
|
||||||
# x = downsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
# else:
|
|
||||||
# h = naive_downsample_2d(h, factor=2)
|
|
||||||
# x = naive_downsample_2d(x, factor=2)
|
|
||||||
#
|
|
||||||
# h = self.Conv_0(h)
|
|
||||||
# Add bias to each feature map conditioned on the time embedding
|
|
||||||
# if temb is not None:
|
|
||||||
# h += self.Dense_0(self.act(temb))[:, :, None, None]
|
|
||||||
# h = self.act(self.GroupNorm_1(h))
|
|
||||||
# h = self.Dropout_0(h)
|
|
||||||
# h = self.Conv_1(h)
|
|
||||||
#
|
|
||||||
# if self.in_ch != self.out_ch or self.up or self.down:
|
|
||||||
# x = self.Conv_2(x)
|
|
||||||
#
|
|
||||||
# if not self.skip_rescale:
|
|
||||||
# return x + h
|
|
||||||
# else:
|
|
||||||
# return (x + h) / np.sqrt(2.0)
|
|
||||||
|
|
||||||
|
|
||||||
class NCSNpp(ModelMixin, ConfigMixin):
|
class NCSNpp(ModelMixin, ConfigMixin):
|
||||||
"""NCSN++ model"""
|
"""NCSN++ model"""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue