diff --git a/configs/stable-diffusion/v1-finetune-4gpu.yaml b/configs/stable-diffusion/v1-finetune-4gpu.yaml
index 7631aae..0efc2a6 100644
--- a/configs/stable-diffusion/v1-finetune-4gpu.yaml
+++ b/configs/stable-diffusion/v1-finetune-4gpu.yaml
@@ -67,6 +67,10 @@ model:
 
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        penultimate: True
+        extended_mode: True
+        max_chunks: 3
 
 data:
   target: main.DataModuleFromConfig
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 84b793f..17152ec 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -5,6 +5,7 @@ import clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
+import numpy as np
 
 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
 
@@ -136,13 +137,15 @@ class SpatialRescaler(nn.Module):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length
-        self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
+        self.penultimate = penultimate  # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
+        self.extended_mode = extended_mode
+        self.max_chunks = max_chunks
         self.freeze()
 
     def freeze(self):
@@ -150,20 +153,59 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         for param in self.parameters():
             param.requires_grad = False
 
-    def forward(self, text):
-        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
-                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
-        tokens = batch_encoding["input_ids"].to(self.device)
+    def transform(self, tokens):
         outputs = self.transformer(input_ids=tokens, output_hidden_states=True)
 
         if self.penultimate:
-            z = outputs.hidden_states[-2] # simple enough
+            z = outputs.hidden_states[-2]  # simple enough
             z = self.transformer.text_model.final_layer_norm(z)
         else:
             z = outputs.last_hidden_state
-
+        return z
 
+    def forward(self, text):
+        if self.extended_mode:
+            max_standard_tokens = self.max_length - 2
+
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
+                                            add_special_tokens=False)
+
+            # get the max length aligned to chunk size.
+            max_len = np.ceil(max([len(x) for x in batch_encoding["input_ids"]]) / max_standard_tokens).astype(int).item() * max_standard_tokens
+            if max_len > max_standard_tokens:
+                z = None
+
+                for index, x in enumerate(batch_encoding["input_ids"]):
+                    if len(x) < max_len:
+                        # pad all tokens to the longest sentence/sequence, maybe find a torch method that can do this?
+                        batch_encoding["input_ids"][index] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
+
+                batch_t = torch.tensor(batch_encoding["input_ids"])
+                # process the tensors in vertically sliced chunks
+                chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
+                for chunk in chunks:
+                    chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
+
+                    if z is None:
+                        z = self.transform(chunk.to(self.device))
+                    else:
+                        z = torch.cat((z, self.transform(chunk.to(self.device))), dim=-2)
+
+                return z
+            else:
+                chunk = batch_encoding['input_ids']
+                for i, x in enumerate(chunk):
+                    chunk[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
+                return self.transform(torch.asarray(chunk).to(self.device))
+
+        else:
+            # default behavior
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+            tokens = batch_encoding["input_ids"].to(self.device)
+
+            return self.transform(tokens)
 
     def encode(self, text):
         return self(text)
diff --git a/scripts/txt2img_gradio.py b/scripts/txt2img_gradio.py
index 1b0b862..7bb8287 100644
--- a/scripts/txt2img_gradio.py
+++ b/scripts/txt2img_gradio.py
@@ -145,6 +145,17 @@ model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckp
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 model = model.half().to(device)
 
+def reshape_c_uc(c, uc):
+    # I have no idea how to generate an empty tensor that's valid for the model,
+    # so I'm gonna just pass in an empty prompt and hope it works!
+    padding = model.get_learned_conditioning(["" for _ in range(c.shape[0])])
+    while c.shape[1] != uc.shape[1]:
+        if c.shape[1] > uc.shape[1]:
+            uc = torch.cat([uc, padding], dim=1)
+        else:
+            c = torch.cat([c, padding], dim=1)
+    return c, uc
+
 def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
     torch.cuda.empty_cache()
 
@@ -202,6 +213,8 @@ def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta
                         prompts = list(prompts)
                     c = model.get_learned_conditioning(prompts)
                     shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                    if uc is not None:
+                        c, uc = reshape_c_uc(c, uc)
                     if sampler == 'k_lms':
                         sigmas = model_wrap.get_sigmas(ddim_steps)
                         model_wrap_cfg = CFGDenoiser(model_wrap)
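Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the chunking arithmetic that the extended_mode branch of FrozenCLIPEmbedder.forward performs. The token ids, the 100-token prompt length, and the BOS/EOS ids (49406/49407, the defaults for openai/clip-vit-large-patch14) are illustrative assumptions; only the shapes matter.

```python
# Standalone sketch of the extended_mode chunking scheme (assumed values, see note above).
import numpy as np
import torch

max_length = 77                        # CLIP context length
max_standard_tokens = max_length - 2   # 75 content tokens per chunk; BOS/EOS are re-added per chunk
bos_id, eos_id = 49406, 49407          # CLIP BOS/EOS token ids

ids = list(range(100))                 # a made-up "prompt" of 100 content tokens

# pad up to the next multiple of 75 with EOS, as the patch does
max_len = int(np.ceil(len(ids) / max_standard_tokens)) * max_standard_tokens
ids = [*ids, *np.full(max_len - len(ids), eos_id)]

batch_t = torch.tensor([ids])          # batch of 1, shape [1, 150]
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
for chunk in chunks:
    # re-attach BOS/EOS so each slice is a valid 77-token CLIP input
    chunk = torch.cat((torch.full((chunk.shape[0], 1), bos_id),
                       chunk,
                       torch.full((chunk.shape[0], 1), eos_id)), 1)
    print(chunk.shape)                 # torch.Size([1, 77]) per chunk
```

Each 77-token chunk is then passed through the frozen CLIP text model separately and the per-chunk embeddings are concatenated along the sequence dimension (dim=-2), so this prompt would yield a conditioning tensor 154 tokens wide instead of 77. Because c and uc can now differ in width by multiples of max_length, reshape_c_uc in txt2img_gradio.py pads whichever of the two is shorter with empty-prompt embeddings until the widths match before sampling.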