From 7b4e049eb00154219c025d20e2273f766c3bfc5f Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 22 Jun 2022 14:16:53 -0400 Subject: [PATCH] adding properties, formatting --- src/diffusers/models/unet_rl.py | 54 ++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 55654dc6..4fdffd33 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,7 +5,6 @@ import math import torch import torch.nn as nn - try: import einops from einops.layers.torch import Rearrange @@ -13,7 +12,6 @@ except: print("Einops is not installed") pass - from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin @@ -106,15 +104,22 @@ class ResidualTemporalBlock(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - dim_mults=(1, 2, 4, 8), + self, + training_horizon, + transition_dim, + cond_dim, + predict_epsilon=False, + clip_denoised=True, + dim=32, + dim_mults=(1, 2, 4, 8), ): super().__init__() + self.transition_dim = transition_dim + self.cond_dim = cond_dim + self.predict_epsilon = predict_epsilon + self.clip_denoised = clip_denoised + dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # print(f'[ models/temporal ] Channel dimensions: {in_out}') @@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): self.downs.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon), Downsample1d(dim_out) if not is_last else nn.Identity(), ] ) ) if not is_last: - horizon = horizon // 2 + training_horizon = training_horizon // 2 mid_dim = dims[-1] - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) @@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): self.ups.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), Upsample1d(dim_in) if not is_last else nn.Identity(), ] ) ) if not is_last: - horizon = horizon * 2 + training_horizon = training_horizon * 2 self.final_conv = nn.Sequential( Conv1dBlock(dim, dim, kernel_size=5), @@ -206,14 +211,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalValue(nn.Module): def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - time_dim=None, - out_dim=1, - dim_mults=(1, 2, 4, 8), + self, + horizon, + transition_dim, + cond_dim, + dim=32, + time_dim=None, + out_dim=1, + dim_mults=(1, 2, 4, 8), ): super().__init__() @@ -232,7 +237,6 @@ class TemporalValue(nn.Module): print(in_out) for dim_in, dim_out in in_out: - self.blocks.append( nn.ModuleList( [