Fix upcast attention dtype error.
Without this fix, enabling the "Upcast cross attention layer to float32" option while also using `--opt-sdp-attention` breaks generation with an error: ``` File "/ext3/automatic1111/stable-diffusion-webui/modules/sd_hijack_optimizations.py", line 612, in sdp_attnblock_forward out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::Half instead. ``` The fix is to make sure to upcast the value tensor too.
This commit is contained in:
parent
baf6946e06
commit
d9cc0910c8
|
@ -605,7 +605,7 @@ def sdp_attnblock_forward(self, x):
|
||||||
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k, v = q.float(), k.float(), v.float()
|
||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
|
|
Loading…
Reference in New Issue