From a65dd315adcab0467d652160b26a95604573530c Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 24 Jun 2024 09:06:10 +0300
Subject: [PATCH] fix T5

---
 modules/models/sd3/other_impls.py | 38 ++++++++++++++++++++++---------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py
index 6e4c5d10d..d7b9b2621 100644
--- a/modules/models/sd3/other_impls.py
+++ b/modules/models/sd3/other_impls.py
@@ -11,6 +11,18 @@ from transformers import CLIPTokenizer, T5TokenizerFast
 #################################################################################################
 
 
+class AutocastLinear(nn.Linear):
+    """Same as usual linear layer, but casts its weights to whatever the parameter type is.
+
+    This is different from torch.autocast in a way that float16 layer processing float32 input
+    will return float16 with autocast on, and float32 with this. T5 seems to be fucked
+    if you do it in full float16 (returning almost all zeros in the final output).
+    """
+
+    def forward(self, x):
+        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+
 def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
@@ -27,9 +39,9 @@ class Mlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+        self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
         self.act = act_layer
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+        self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -297,7 +309,6 @@ class T5XXLModel(SDClipModel):
 ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
 #################################################################################################
 
-
 class T5XXLTokenizer(SDTokenizer):
     """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
     def __init__(self):
@@ -319,9 +330,9 @@ class T5LayerNorm(torch.nn.Module):
 class T5DenseGatedActDense(torch.nn.Module):
     def __init__(self, model_dim, ff_dim, dtype, device):
         super().__init__()
-        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
 
     def forward(self, x):
         hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
@@ -348,10 +359,10 @@ class T5Attention(torch.nn.Module):
     def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
         super().__init__()
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
         self.num_heads = num_heads
         self.relative_attention_bias = None
         if relative_attention_bias:
@@ -421,11 +432,16 @@ class T5Attention(torch.nn.Module):
         q = self.q(x)
         k = self.k(x)
         v = self.v(x)
+
         if self.relative_attention_bias is not None:
             past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+
         if past_bias is not None:
             mask = past_bias
-        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
+        else:
+            mask = None
+
+        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
+
         return self.o(out), past_bias
 
 
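
Note (not part of the patch): a minimal sketch of the dtype behaviour the new docstring describes, assuming only stock PyTorch. bfloat16 is used so the comparison runs on CPU; the same contrast applies to a float16 layer on GPU. Under torch.autocast the output follows the autocast dtype, while an AutocastLinear-style layer keeps the input's dtype.

# Illustrative sketch only -- not part of the patch. Shows that a weight-casting
# linear layer returns the input's dtype, while torch.autocast returns the
# autocast dtype.
import torch
from torch import nn


class AutocastLinear(nn.Linear):
    """Casts its own weight/bias to the input's dtype before the matmul."""
    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)


x = torch.randn(2, 8)                             # float32 activations

low_precision = AutocastLinear(8, 4).bfloat16()   # "low precision" weights
print(low_precision(x).dtype)                     # torch.float32 -- follows the input

plain = nn.Linear(8, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    print(plain(x).dtype)                         # torch.bfloat16 -- follows the autocast dtype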