This commit is contained in:
Patrick von Platen 2022-06-22 23:16:05 +02:00
commit bd9c9fbfbe
1 changed file with 29 additions and 25 deletions

View File

@@ -5,7 +5,6 @@ import math
import torch
import torch.nn as nn
try:
import einops
from einops.layers.torch import Rearrange
@@ -13,7 +12,6 @@ except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
@@ -107,14 +105,21 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__(
self,
horizon,
training_horizon,
transition_dim,
cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
self.transition_dim = transition_dim
self.cond_dim = cond_dim
self.predict_epsilon = predict_epsilon
self.clip_denoised = clip_denoised
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.downs.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon // 2
training_horizon = training_horizon // 2
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon * 2
training_horizon = training_horizon * 2
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5),
@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
print(in_out)
for dim_in, dim_out in in_out:
self.blocks.append(
nn.ModuleList(
[