Merge branch 'main' of https://github.com/huggingface/diffusers
This commit is contained in:
commit
bd9c9fbfbe
|
@ -5,7 +5,6 @@ import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import einops
|
import einops
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
|
@ -13,7 +12,6 @@ except:
|
||||||
print("Einops is not installed")
|
print("Einops is not installed")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin
|
from ..configuration_utils import ConfigMixin
|
||||||
from ..modeling_utils import ModelMixin
|
from ..modeling_utils import ModelMixin
|
||||||
|
|
||||||
|
@ -107,14 +105,21 @@ class ResidualTemporalBlock(nn.Module):
|
||||||
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
horizon,
|
training_horizon,
|
||||||
transition_dim,
|
transition_dim,
|
||||||
cond_dim,
|
cond_dim,
|
||||||
|
predict_epsilon=False,
|
||||||
|
clip_denoised=True,
|
||||||
dim=32,
|
dim=32,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.transition_dim = transition_dim
|
||||||
|
self.cond_dim = cond_dim
|
||||||
|
self.predict_epsilon = predict_epsilon
|
||||||
|
self.clip_denoised = clip_denoised
|
||||||
|
|
||||||
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
|
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
|
||||||
in_out = list(zip(dims[:-1], dims[1:]))
|
in_out = list(zip(dims[:-1], dims[1:]))
|
||||||
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
||||||
|
@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||||
self.downs.append(
|
self.downs.append(
|
||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
|
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
|
||||||
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
|
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
|
||||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_last:
|
if not is_last:
|
||||||
horizon = horizon // 2
|
training_horizon = training_horizon // 2
|
||||||
|
|
||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
|
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
|
||||||
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
|
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
|
||||||
|
|
||||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
is_last = ind >= (num_resolutions - 1)
|
is_last = ind >= (num_resolutions - 1)
|
||||||
|
@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||||
self.ups.append(
|
self.ups.append(
|
||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
|
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
|
||||||
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
|
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
|
||||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_last:
|
if not is_last:
|
||||||
horizon = horizon * 2
|
training_horizon = training_horizon * 2
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
Conv1dBlock(dim, dim, kernel_size=5),
|
Conv1dBlock(dim, dim, kernel_size=5),
|
||||||
|
@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
|
||||||
|
|
||||||
print(in_out)
|
print(in_out)
|
||||||
for dim_in, dim_out in in_out:
|
for dim_in, dim_out in in_out:
|
||||||
|
|
||||||
self.blocks.append(
|
self.blocks.append(
|
||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
|
Loading…
Reference in New Issue