update input names
This commit is contained in:
parent
7dc71897b3
commit
eef2327a47
|
@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
|
|||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x, t):
|
||||
def forward(self, x, timesteps):
|
||||
assert x.shape[2] == x.shape[3] == self.resolution
|
||||
|
||||
if not torch.is_tensor(t):
|
||||
t = torch.tensor([t], dtype=torch.long, device=x.device)
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
|
||||
|
||||
# timestep embedding
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = get_timestep_embedding(timesteps, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
|
|
Loading…
Reference in New Issue