Merge pull request #24 from hasuwoof/feature/extended_clip

Implement CLIP extensions (training + inference)
This commit is contained in:
harubaru 2022-10-14 16:27:21 -07:00 committed by GitHub
commit 9581fbc226
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 8 deletions

View File

@ -67,6 +67,10 @@ model:
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
penultimate: True
extended_mode: True
max_chunks: 3
data:
target: main.DataModuleFromConfig

View File

@ -5,6 +5,7 @@ import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
import numpy as np
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
@ -136,13 +137,15 @@ class SpatialRescaler(nn.Module):
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
self.extended_mode = extended_mode
self.max_chunks = max_chunks
self.freeze()
def freeze(self):
@ -150,20 +153,59 @@ class FrozenCLIPEmbedder(AbstractEncoder):
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
def transform(self, tokens):
outputs = self.transformer(input_ids=tokens, output_hidden_states=True)
if self.penultimate:
z = outputs.hidden_states[-2] # simple enough
z = outputs.hidden_states[-2] # simple enough
z = self.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
return z
def forward(self, text):
if self.extended_mode:
max_standard_tokens = self.max_length - 2
batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
add_special_tokens=False)
# get the max length aligned to chunk size.
max_len = np.ceil(max([len(x) for x in batch_encoding["input_ids"]]) / max_standard_tokens).astype(int).item() * max_standard_tokens
if max_len > max_standard_tokens:
z = None
for index, x in enumerate(batch_encoding["input_ids"]):
if len(x) < max_len:
# pad all tokens to the longest sentence/sequence, maybe find a torch method that can do this?
batch_encoding["input_ids"][index] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
batch_t = torch.tensor(batch_encoding["input_ids"])
# process the tensors in vertically sliced chunks
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
for chunk in chunks:
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
if z is None:
z = self.transform(chunk.to(self.device))
else:
z = torch.cat((z, self.transform(chunk.to(self.device))), dim=-2)
return z
else:
chunk = batch_encoding['input_ids']
for i, x in enumerate(chunk):
chunk[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
return self.transform(torch.asarray(chunk).to(self.device))
else:
# default behavior
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
return self.transform(tokens)
def encode(self, text):
return self(text)

View File

@ -145,6 +145,17 @@ model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckp
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.half().to(device)
def reshape_c_uc(c, uc):
# I have no idea how to generate an empty tensor that's valid for the model,
# so I'm gonna just pass in an empty prompt and hope it works!
padding = model.get_learned_conditioning(["" for _ in range(c.shape[0])])
while c.shape[1] != uc.shape[1]:
if c.shape[1] > uc.shape[1]:
uc = torch.cat([uc, padding], dim=1)
else:
c = torch.cat([c, padding], dim=1)
return c, uc
def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache()
@ -202,6 +213,8 @@ def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if uc is not None:
c, uc = reshape_c_uc(c, uc)
if sampler == 'k_lms':
sigmas = model_wrap.get_sigmas(ddim_steps)
model_wrap_cfg = CFGDenoiser(model_wrap)