Renamed x -> hidden_states in resnet.py (#676)
renamed x to hidden_states
This commit is contained in:
parent
3dacbb94ca
commit
a7058f42e1
|
@ -34,21 +34,21 @@ class Upsample2D(nn.Module):
|
|||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(x)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
x = self.conv(x)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return x
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
|
@ -84,16 +84,16 @@ class Downsample2D(nn.Module):
|
|||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert x.shape[1] == self.channels
|
||||
x = self.conv(x)
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return x
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirUpsample2D(nn.Module):
|
||||
|
@ -174,12 +174,12 @@ class FirUpsample2D(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, hidden_states):
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
|
||||
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return height
|
||||
|
||||
|
@ -236,14 +236,14 @@ class FirDownsample2D(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, hidden_states):
|
||||
if self.use_conv:
|
||||
x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return x
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue