sd3 TI support
This commit is contained in:
parent
1da4907927
commit
11cfe0dd05
|
@ -5,6 +5,8 @@ import math
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
from modules import sd_hijack
|
||||||
|
|
||||||
|
|
||||||
#################################################################################################
|
#################################################################################################
|
||||||
### Core/Utility
|
### Core/Utility
|
||||||
|
@ -110,9 +112,9 @@ class CLIPEncoder(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class CLIPEmbeddings(torch.nn.Module):
|
class CLIPEmbeddings(torch.nn.Module):
|
||||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
|
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
|
||||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, input_tokens):
|
def forward(self, input_tokens):
|
||||||
|
@ -127,7 +129,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||||
intermediate_size = config_dict["intermediate_size"]
|
intermediate_size = config_dict["intermediate_size"]
|
||||||
intermediate_activation = config_dict["hidden_act"]
|
intermediate_activation = config_dict["hidden_act"]
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
|
||||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
|
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
|
||||||
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,7 @@ CLIPG_CONFIG = {
|
||||||
"intermediate_size": 5120,
|
"intermediate_size": 5120,
|
||||||
"num_attention_heads": 20,
|
"num_attention_heads": 20,
|
||||||
"num_hidden_layers": 32,
|
"num_hidden_layers": 32,
|
||||||
|
"textual_inversion_key": "clip_g",
|
||||||
}
|
}
|
||||||
|
|
||||||
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
||||||
|
@ -204,7 +205,10 @@ class SD3Cond(torch.nn.Module):
|
||||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||||
|
|
||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
return torch.tensor([[0]], device=devices.device) # XXX
|
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
return self.model_lg.tokenize(texts)
|
||||||
|
|
||||||
def medvram_modules(self):
|
def medvram_modules(self):
|
||||||
return [self.clip_g, self.clip_l, self.t5xxl]
|
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||||
|
|
|
@ -359,13 +359,28 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||||
emb = devices.cond_cast_unet(vec)
|
emb = devices.cond_cast_unet(vec)
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||||
|
|
||||||
vecs.append(tensor)
|
vecs.append(tensor)
|
||||||
|
|
||||||
return torch.stack(vecs)
|
return torch.stack(vecs)
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionEmbeddings(torch.nn.Embedding):
|
||||||
|
def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
||||||
|
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||||
|
|
||||||
|
self.embeddings = model_hijack
|
||||||
|
self.textual_inversion_key = textual_inversion_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wrapped(self):
|
||||||
|
return super().forward
|
||||||
|
|
||||||
|
def forward(self, input_ids):
|
||||||
|
return EmbeddingsWithFixes.forward(self, input_ids)
|
||||||
|
|
||||||
|
|
||||||
def add_circular_option_to_conv_2d():
|
def add_circular_option_to_conv_2d():
|
||||||
conv2d_constructor = torch.nn.Conv2d.__init__
|
conv2d_constructor = torch.nn.Conv2d.__init__
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue