diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 97f3c02a..43c00fdf 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -34,21 +34,21 @@ class Upsample2D(nn.Module): else: self.Conv2d_0 = conv - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: - return self.conv(x) + return self.conv(hidden_states) - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - x = self.conv(x) + hidden_states = self.conv(hidden_states) else: - x = self.Conv2d_0(x) + hidden_states = self.Conv2d_0(hidden_states) - return x + return hidden_states class Downsample2D(nn.Module): @@ -84,16 +84,16 @@ class Downsample2D(nn.Module): else: self.conv = conv - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - assert x.shape[1] == self.channels - x = self.conv(x) + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) - return x + return hidden_states class FirUpsample2D(nn.Module): @@ -174,12 +174,12 @@ class FirUpsample2D(nn.Module): return x - def forward(self, x): + def forward(self, hidden_states): if self.use_conv: - height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) return height @@ -236,14 +236,14 @@ class FirDownsample2D(nn.Module): return x - def forward(self, x): + def forward(self, hidden_states): if self.use_conv: - x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) + hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1) else: - x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) + hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - return x + return hidden_states class ResnetBlock2D(nn.Module):