update input names

This commit is contained in:
patil-suraj 2022-06-17 16:36:35 +02:00
parent 7dc71897b3
commit eef2327a47
1 changed files with 4 additions and 4 deletions

View File

@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t):
def forward(self, x, timesteps):
assert x.shape[2] == x.shape[3] == self.resolution
if not torch.is_tensor(t):
t = torch.tensor([t], dtype=torch.long, device=x.device)
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
# timestep embedding
temb = get_timestep_embedding(t, self.ch)
temb = get_timestep_embedding(timesteps, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)