Merge pull request #24 from hasuwoof/feature/extended_clip
Implement CLIP extensions (training + inference)
This commit is contained in:
commit
9581fbc226
|
@ -67,6 +67,10 @@ model:
|
|||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
penultimate: True
|
||||
extended_mode: True
|
||||
max_chunks: 3
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
|
|
|
@ -5,6 +5,7 @@ import clip
|
|||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
import numpy as np
|
||||
|
||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
|
||||
|
@ -136,13 +137,15 @@ class SpatialRescaler(nn.Module):
|
|||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
|
||||
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
|
||||
self.extended_mode = extended_mode
|
||||
self.max_chunks = max_chunks
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
|
@ -150,20 +153,59 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
def transform(self, tokens):
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=True)
|
||||
|
||||
if self.penultimate:
|
||||
z = outputs.hidden_states[-2] # simple enough
|
||||
z = outputs.hidden_states[-2] # simple enough
|
||||
z = self.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
return z
|
||||
|
||||
def forward(self, text):
|
||||
if self.extended_mode:
|
||||
max_standard_tokens = self.max_length - 2
|
||||
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
|
||||
add_special_tokens=False)
|
||||
|
||||
# get the max length aligned to chunk size.
|
||||
max_len = np.ceil(max([len(x) for x in batch_encoding["input_ids"]]) / max_standard_tokens).astype(int).item() * max_standard_tokens
|
||||
if max_len > max_standard_tokens:
|
||||
z = None
|
||||
|
||||
for index, x in enumerate(batch_encoding["input_ids"]):
|
||||
if len(x) < max_len:
|
||||
# pad all tokens to the longest sentence/sequence, maybe find a torch method that can do this?
|
||||
batch_encoding["input_ids"][index] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
|
||||
|
||||
batch_t = torch.tensor(batch_encoding["input_ids"])
|
||||
# process the tensors in vertically sliced chunks
|
||||
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
|
||||
for chunk in chunks:
|
||||
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
|
||||
|
||||
if z is None:
|
||||
z = self.transform(chunk.to(self.device))
|
||||
else:
|
||||
z = torch.cat((z, self.transform(chunk.to(self.device))), dim=-2)
|
||||
|
||||
return z
|
||||
else:
|
||||
chunk = batch_encoding['input_ids']
|
||||
for i, x in enumerate(chunk):
|
||||
chunk[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
|
||||
return self.transform(torch.asarray(chunk).to(self.device))
|
||||
|
||||
else:
|
||||
# default behavior
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
|
||||
return self.transform(tokens)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
|
|
@ -145,6 +145,17 @@ model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckp
|
|||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.half().to(device)
|
||||
|
||||
def reshape_c_uc(c, uc):
|
||||
# I have no idea how to generate an empty tensor that's valid for the model,
|
||||
# so I'm gonna just pass in an empty prompt and hope it works!
|
||||
padding = model.get_learned_conditioning(["" for _ in range(c.shape[0])])
|
||||
while c.shape[1] != uc.shape[1]:
|
||||
if c.shape[1] > uc.shape[1]:
|
||||
uc = torch.cat([uc, padding], dim=1)
|
||||
else:
|
||||
c = torch.cat([c, padding], dim=1)
|
||||
return c, uc
|
||||
|
||||
def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -202,6 +213,8 @@ def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta
|
|||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
if uc is not None:
|
||||
c, uc = reshape_c_uc(c, uc)
|
||||
if sampler == 'k_lms':
|
||||
sigmas = model_wrap.get_sigmas(ddim_steps)
|
||||
model_wrap_cfg = CFGDenoiser(model_wrap)
|
||||
|
|
Loading…
Reference in New Issue