Renamed x -> hidden_states in resnet.py (#676)
renamed x to hidden_states
This commit is contained in:
parent
3dacbb94ca
commit
a7058f42e1
|
@ -34,21 +34,21 @@ class Upsample2D(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.Conv2d_0 = conv
|
self.Conv2d_0 = conv
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, hidden_states):
|
||||||
assert x.shape[1] == self.channels
|
assert hidden_states.shape[1] == self.channels
|
||||||
if self.use_conv_transpose:
|
if self.use_conv_transpose:
|
||||||
return self.conv(x)
|
return self.conv(hidden_states)
|
||||||
|
|
||||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||||
|
|
||||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||||
if self.use_conv:
|
if self.use_conv:
|
||||||
if self.name == "conv":
|
if self.name == "conv":
|
||||||
x = self.conv(x)
|
hidden_states = self.conv(hidden_states)
|
||||||
else:
|
else:
|
||||||
x = self.Conv2d_0(x)
|
hidden_states = self.Conv2d_0(hidden_states)
|
||||||
|
|
||||||
return x
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Downsample2D(nn.Module):
|
class Downsample2D(nn.Module):
|
||||||
|
@ -84,16 +84,16 @@ class Downsample2D(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.conv = conv
|
self.conv = conv
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, hidden_states):
|
||||||
assert x.shape[1] == self.channels
|
assert hidden_states.shape[1] == self.channels
|
||||||
if self.use_conv and self.padding == 0:
|
if self.use_conv and self.padding == 0:
|
||||||
pad = (0, 1, 0, 1)
|
pad = (0, 1, 0, 1)
|
||||||
x = F.pad(x, pad, mode="constant", value=0)
|
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||||
|
|
||||||
assert x.shape[1] == self.channels
|
assert hidden_states.shape[1] == self.channels
|
||||||
x = self.conv(x)
|
hidden_states = self.conv(hidden_states)
|
||||||
|
|
||||||
return x
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FirUpsample2D(nn.Module):
|
class FirUpsample2D(nn.Module):
|
||||||
|
@ -174,12 +174,12 @@ class FirUpsample2D(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, hidden_states):
|
||||||
if self.use_conv:
|
if self.use_conv:
|
||||||
height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||||
else:
|
else:
|
||||||
height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
|
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||||
|
|
||||||
return height
|
return height
|
||||||
|
|
||||||
|
@ -236,14 +236,14 @@ class FirDownsample2D(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, hidden_states):
|
||||||
if self.use_conv:
|
if self.use_conv:
|
||||||
x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||||
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||||
else:
|
else:
|
||||||
x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
|
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||||
|
|
||||||
return x
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock2D(nn.Module):
|
class ResnetBlock2D(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue