Renamed x -> hidden_states in resnet.py (#676)

renamed x to hidden_states
This commit is contained in:
Partho 2022-09-30 00:49:09 +05:30 committed by GitHub
parent 3dacbb94ca
commit a7058f42e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 21 deletions

View File

@ -34,21 +34,21 @@ class Upsample2D(nn.Module):
else: else:
self.Conv2d_0 = conv self.Conv2d_0 = conv
def forward(self, x): def forward(self, hidden_states):
assert x.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose: if self.use_conv_transpose:
return self.conv(x) return self.conv(hidden_states)
x = F.interpolate(x, scale_factor=2.0, mode="nearest") hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv: if self.use_conv:
if self.name == "conv": if self.name == "conv":
x = self.conv(x) hidden_states = self.conv(hidden_states)
else: else:
x = self.Conv2d_0(x) hidden_states = self.Conv2d_0(hidden_states)
return x return hidden_states
class Downsample2D(nn.Module): class Downsample2D(nn.Module):
@ -84,16 +84,16 @@ class Downsample2D(nn.Module):
else: else:
self.conv = conv self.conv = conv
def forward(self, x): def forward(self, hidden_states):
assert x.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0: if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert x.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
x = self.conv(x) hidden_states = self.conv(hidden_states)
return x return hidden_states
class FirUpsample2D(nn.Module): class FirUpsample2D(nn.Module):
@ -174,12 +174,12 @@ class FirUpsample2D(nn.Module):
return x return x
def forward(self, x): def forward(self, hidden_states):
if self.use_conv: if self.use_conv:
height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel) height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2) height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return height return height
@ -236,14 +236,14 @@ class FirDownsample2D(nn.Module):
return x return x
def forward(self, x): def forward(self, hidden_states):
if self.use_conv: if self.use_conv:
x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1) hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2) hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return x return hidden_states
class ResnetBlock2D(nn.Module): class ResnetBlock2D(nn.Module):