This commit is contained in:
Patrick von Platen 2022-06-22 23:16:05 +02:00
commit bd9c9fbfbe
1 changed files with 29 additions and 25 deletions

View File

@ -5,7 +5,6 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
try: try:
import einops import einops
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
@ -13,7 +12,6 @@ except:
print("Einops is not installed") print("Einops is not installed")
pass pass
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
@ -107,14 +105,21 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__( def __init__(
self, self,
horizon, training_horizon,
transition_dim, transition_dim,
cond_dim, cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32, dim=32,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
): ):
super().__init__() super().__init__()
self.transition_dim = transition_dim
self.cond_dim = cond_dim
self.predict_epsilon = predict_epsilon
self.clip_denoised = clip_denoised
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}') # print(f'[ models/temporal ] Channel dimensions: {in_out}')
@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.downs.append( self.downs.append(
nn.ModuleList( nn.ModuleList(
[ [
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(), Downsample1d(dim_out) if not is_last else nn.Identity(),
] ]
) )
) )
if not is_last: if not is_last:
horizon = horizon // 2 training_horizon = training_horizon // 2
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups.append( self.ups.append(
nn.ModuleList( nn.ModuleList(
[ [
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(), Upsample1d(dim_in) if not is_last else nn.Identity(),
] ]
) )
) )
if not is_last: if not is_last:
horizon = horizon * 2 training_horizon = training_horizon * 2
self.final_conv = nn.Sequential( self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5), Conv1dBlock(dim, dim, kernel_size=5),
@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
print(in_out) print(in_out)
for dim_in, dim_out in in_out: for dim_in, dim_out in in_out:
self.blocks.append( self.blocks.append(
nn.ModuleList( nn.ModuleList(
[ [