Fix for Unet NaNs
This commit is contained in:
parent
5ab7f213be
commit
7aab389d6f
|
@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
if q.device.type == 'mps':
|
||||
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
|
Loading…
Reference in New Issue