From 11cfe0dd054926b5df81632f9e2b2a78738ccf95 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 7 Jul 2024 16:36:53 +0300 Subject: [PATCH] sd3 TI support --- modules/models/sd3/other_impls.py | 8 +++++--- modules/models/sd3/sd3_cond.py | 6 +++++- modules/sd_hijack.py | 17 ++++++++++++++++- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index f992db9bd..78c1dc687 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -5,6 +5,8 @@ import math from torch import nn from transformers import CLIPTokenizer, T5TokenizerFast +from modules import sd_hijack + ################################################################################################# ### Core/Utility @@ -110,9 +112,9 @@ class CLIPEncoder(torch.nn.Module): class CLIPEmbeddings(torch.nn.Module): - def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"): super().__init__() - self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key) self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) def forward(self, input_tokens): @@ -127,7 +129,7 @@ class CLIPTextModel_(torch.nn.Module): intermediate_size = config_dict["intermediate_size"] intermediate_activation = config_dict["hidden_act"] super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l')) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py index bade90ba1..325c512d5 100644 --- a/modules/models/sd3/sd3_cond.py +++ b/modules/models/sd3/sd3_cond.py @@ -40,6 +40,7 @@ CLIPG_CONFIG = { "intermediate_size": 5120, "num_attention_heads": 20, "num_hidden_layers": 32, + "textual_inversion_key": "clip_g", } T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors" @@ -204,7 +205,10 @@ class SD3Cond(torch.nn.Module): self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) def encode_embedding_init_text(self, init_text, nvpt): - return torch.tensor([[0]], device=devices.device) # XXX + return self.model_lg.encode_embedding_init_text(init_text, nvpt) + + def tokenize(self, texts): + return self.model_lg.tokenize(texts) def medvram_modules(self): return [self.clip_g, self.clip_l, self.t5xxl] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d5b2989f4..0de830541 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -359,13 +359,28 @@ class EmbeddingsWithFixes(torch.nn.Module): vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec emb = devices.cond_cast_unet(vec) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) vecs.append(tensor) return torch.stack(vecs) +class TextualInversionEmbeddings(torch.nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + + self.embeddings = model_hijack + self.textual_inversion_key = textual_inversion_key + + @property + def wrapped(self): + return super().forward + + def forward(self, input_ids): + return EmbeddingsWithFixes.forward(self, input_ids) + + def add_circular_option_to_conv_2d(): conv2d_constructor = torch.nn.Conv2d.__init__