From d9cc0910c8aca481f294009526897152901c32b9 Mon Sep 17 00:00:00 2001
From: Alexander Ljungberg
Date: Tue, 6 Jun 2023 21:45:30 +0100
Subject: [PATCH] Fix upcast attention dtype error.

Without this fix, enabling the "Upcast cross attention layer to float32" option while also using `--opt-sdp-attention` breaks generation with an error:

```
  File "/ext3/automatic1111/stable-diffusion-webui/modules/sd_hijack_optimizations.py", line 612, in sdp_attnblock_forward
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::Half instead.
```

The fix is to make sure to upcast the value tensor too.
---
 modules/sd_hijack_optimizations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 80e48a422..6464ca8ef 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -605,7 +605,7 @@ def sdp_attnblock_forward(self, x):
     q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
     dtype = q.dtype
     if shared.opts.upcast_attn:
-        q, k = q.float(), k.float()
+        q, k, v = q.float(), k.float(), v.float()
     q = q.contiguous()
     k = k.contiguous()
     v = v.contiguous()
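
For context, here is a minimal sketch (not part of the patch) of why the upcast must include `v`: `torch.nn.functional.scaled_dot_product_attention` requires the query, key, and value tensors to share a dtype, so upcasting only `q` and `k` while `v` stays in half precision triggers the RuntimeError quoted above. The tensor shapes and the `upcast_attn` flag below are illustrative stand-ins for the webui's actual values, not the real code path.

```python
import torch
import torch.nn.functional as F

# Pretend these came out of the attention block in half precision,
# as they would with a half-precision model.
q = torch.randn(1, 64, 512, dtype=torch.float16)
k = torch.randn(1, 64, 512, dtype=torch.float16)
v = torch.randn(1, 64, 512, dtype=torch.float16)

dtype = q.dtype          # remembered so the output can be cast back
upcast_attn = True       # stands in for shared.opts.upcast_attn

if upcast_attn:
    # Pre-patch code upcast only q and k, leaving v in float16 and
    # causing the dtype-mismatch RuntimeError; upcasting all three
    # (as the patch does) keeps the dtypes consistent.
    q, k, v = q.float(), k.float(), v.float()

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.to(dtype)      # cast back to the original precision
print(out.dtype)         # torch.float16
```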