Removing `.float()` (`autocast` in fp16 will discard this (I think)). (#495)
This commit is contained in:
parent
ab7a78e8f1
commit
7c4b38baca
|
@@ -333,7 +333,7 @@ class ResnetBlock2D(nn.Module):
         # make sure hidden states is in float32
         # when running in half-precision
-        hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
+        hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
         hidden_states = self.nonlinearity(hidden_states)

         if self.upsample is not None:
@@ -351,7 +351,7 @@ class ResnetBlock2D(nn.Module):
         # make sure hidden states is in float32
         # when running in half-precision
-        hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
+        hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
         hidden_states = self.nonlinearity(hidden_states)

         hidden_states = self.dropout(hidden_states)
|
Loading…
Reference in New Issue