Fix a small typo of a variable name (#1063)
Fix a small typo in `models/attention.py`: weight -> width.
This commit is contained in:
parent 4e59bcc680
commit 1216a3b122
models/attention.py
@@ -165,15 +165,15 @@ class SpatialTransformer(nn.Module):
     def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        batch, channel, height, weight = hidden_states.shape
+        batch, channel, height, width = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
         inner_dim = hidden_states.shape[1]
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
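For context, a minimal sketch (not from the commit; shapes and the skipped norm/proj layers are placeholders) of the reshape round trip the renamed variable is part of: the spatial feature map is flattened into height * width tokens before the transformer blocks and restored to (batch, channel, height, width) afterwards, which is why the last dimension is a width, not a weight.

import torch

# Toy shapes standing in for a real feature map; values are illustrative only.
batch, channel, height, width = 2, 8, 4, 6
hidden_states = torch.randn(batch, channel, height, width)

# (batch, channel, height, width) -> (batch, height * width, channel)
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)

# ... the transformer blocks would operate on these height * width tokens ...

# (batch, height * width, channel) -> (batch, channel, height, width)
restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2)
assert torch.equal(restored, hidden_states)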