fix T5
This commit is contained in:
parent
34b4443cc3
commit
a65dd315ad
|
@ -11,6 +11,18 @@ from transformers import CLIPTokenizer, T5TokenizerFast
|
|||
#################################################################################################
|
||||
|
||||
|
||||
class AutocastLinear(nn.Linear):
|
||||
"""Same as usual linear layer, but casts its weights to whatever the parameter type is.
|
||||
|
||||
This is different from torch.autocast in a way that float16 layer processing float32 input
|
||||
will return float16 with autocast on, and float32 with this. T5 seems to be fucked
|
||||
if you do it in full float16 (returning almost all zeros in the final output).
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
|
||||
|
||||
def attention(q, k, v, heads, mask=None):
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
|
@ -27,9 +39,9 @@ class Mlp(nn.Module):
|
|||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.act = act_layer
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
|
@ -297,7 +309,6 @@ class T5XXLModel(SDClipModel):
|
|||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
def __init__(self):
|
||||
|
@ -319,9 +330,9 @@ class T5LayerNorm(torch.nn.Module):
|
|||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
|
@ -348,10 +359,10 @@ class T5Attention(torch.nn.Module):
|
|||
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
|
||||
super().__init__()
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
|
@ -421,11 +432,16 @@ class T5Attention(torch.nn.Module):
|
|||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
if past_bias is not None:
|
||||
mask = past_bias
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
|
||||
|
||||
return self.o(out), past_bias
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue