Merge branch 'main' of https://github.com/huggingface/diffusers
This commit is contained in:
commit
bd9c9fbfbe
|
@ -5,7 +5,6 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
try:
|
||||
import einops
|
||||
from einops.layers.torch import Rearrange
|
||||
|
@ -13,7 +12,6 @@ except:
|
|||
print("Einops is not installed")
|
||||
pass
|
||||
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module):
|
|||
|
||||
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
dim=32,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
self,
|
||||
training_horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
predict_epsilon=False,
|
||||
clip_denoised=True,
|
||||
dim=32,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.transition_dim = transition_dim
|
||||
self.cond_dim = cond_dim
|
||||
self.predict_epsilon = predict_epsilon
|
||||
self.clip_denoised = clip_denoised
|
||||
|
||||
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
||||
|
@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
|||
self.downs.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
|
||||
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
|
||||
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
|
||||
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if not is_last:
|
||||
horizon = horizon // 2
|
||||
training_horizon = training_horizon // 2
|
||||
|
||||
mid_dim = dims[-1]
|
||||
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
|
||||
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
|
||||
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
|
||||
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
|
@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
|||
self.ups.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
|
||||
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
|
||||
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
|
||||
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if not is_last:
|
||||
horizon = horizon * 2
|
||||
training_horizon = training_horizon * 2
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
Conv1dBlock(dim, dim, kernel_size=5),
|
||||
|
@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
|||
|
||||
class TemporalValue(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
dim=32,
|
||||
time_dim=None,
|
||||
out_dim=1,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
self,
|
||||
horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
dim=32,
|
||||
time_dim=None,
|
||||
out_dim=1,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
|
|||
|
||||
print(in_out)
|
||||
for dim_in, dim_out in in_out:
|
||||
|
||||
self.blocks.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue