fix T5

parent 34b4443cc3
commit a65dd315ad
@@ -11,6 +11,18 @@ from transformers import CLIPTokenizer, T5TokenizerFast
 #################################################################################################
 
 
+class AutocastLinear(nn.Linear):
+    """Same as the usual linear layer, but casts its weights to whatever dtype the input has.
+
+    This is different from torch.autocast in that a float16 layer processing float32 input
+    will return float16 with autocast on, and float32 with this. T5 breaks
+    if you run it in full float16 (returning almost all zeros in the final output).
+    """
+
+    def forward(self, x):
+        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+
+
 def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
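A minimal sketch (not part of the commit) of the dtype behaviour the docstring above describes: under torch.autocast a low-precision layer returns low-precision activations even for float32 input, while a weight-casting layer like AutocastLinear follows the input dtype. The sketch uses bfloat16 on CPU purely so it runs without a GPU; the commit itself targets float16 T5 weights.

    import torch
    import torch.nn as nn


    class AutocastLinear(nn.Linear):
        """Casts its own weights/bias to the input dtype instead of relying on torch.autocast."""
        def forward(self, x):
            return torch.nn.functional.linear(
                x,
                self.weight.to(x.dtype),
                self.bias.to(x.dtype) if self.bias is not None else None,
            )


    x = torch.randn(2, 8)  # float32 activations

    plain = nn.Linear(8, 8, dtype=torch.bfloat16)
    with torch.autocast("cpu", dtype=torch.bfloat16):
        print(plain(x).dtype)   # torch.bfloat16: autocast downcasts the float32 input

    casting = AutocastLinear(8, 8, dtype=torch.bfloat16)
    print(casting(x).dtype)     # torch.float32: the layer upcasts its weights to match x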
@@ -27,9 +39,9 @@ class Mlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+        self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
         self.act = act_layer
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+        self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -297,7 +309,6 @@ class T5XXLModel(SDClipModel):
 ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
 #################################################################################################
 
-
 class T5XXLTokenizer(SDTokenizer):
     """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
     def __init__(self):
@@ -319,9 +330,9 @@ class T5LayerNorm(torch.nn.Module):
 class T5DenseGatedActDense(torch.nn.Module):
     def __init__(self, model_dim, ff_dim, dtype, device):
         super().__init__()
-        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
-        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
+        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
 
     def forward(self, x):
         hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
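The hunk above cuts off after the first line of forward. For context, here is a self-contained sketch of the standard T5 v1.1 gated feed-forward that this class follows (gelu(wi_0(x)) gating wi_1(x), projected back through wo). The body is an assumption based on the upstream T5 design, not copied from the file.

    import torch
    import torch.nn as nn


    class GatedGeluFFN(nn.Module):
        """Sketch of a T5 v1.1-style gated feed-forward block (names wi_0/wi_1/wo mirror the diff)."""
        def __init__(self, model_dim, ff_dim):
            super().__init__()
            self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False)
            self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False)
            self.wo = nn.Linear(ff_dim, model_dim, bias=False)

        def forward(self, x):
            hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")  # gate branch
            hidden_linear = self.wi_1(x)                                              # value branch
            return self.wo(hidden_gelu * hidden_linear)                               # gated product, projected back


    print(GatedGeluFFN(16, 64)(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])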
@@ -348,10 +359,10 @@ class T5Attention(torch.nn.Module):
     def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
         super().__init__()
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
+        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
         self.num_heads = num_heads
         self.relative_attention_bias = None
         if relative_attention_bias:
@@ -421,11 +432,16 @@ class T5Attention(torch.nn.Module):
         q = self.q(x)
         k = self.k(x)
         v = self.v(x)
 
         if self.relative_attention_bias is not None:
             past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
         if past_bias is not None:
             mask = past_bias
-        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
+        else:
+            mask = None
+
+        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
 
         return self.o(out), past_bias
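A side note on the last hunk, as a hedged sketch rather than anything taken from the file: with the projections now returning float32, an additive relative-attention bias that still lives in float16 has to be cast to the activation dtype. Assuming the attention wrapper ultimately calls torch.nn.functional.scaled_dot_product_attention, a floating-point attn_mask is generally expected to match the query dtype, which is what the mask.to(x.dtype) cast guards against.

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 4, 8, 16)                 # float32 activations (batch, heads, seq, head_dim)
    bias = torch.zeros(1, 4, 8, 8, dtype=torch.float16)  # additive attention bias kept in float16

    # F.scaled_dot_product_attention(q, k, v, attn_mask=bias)  # mixed float dtypes here can error out
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias.to(q.dtype))
    print(out.dtype)  # torch.float32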