Merge pull request #24 from hasuwoof/feature/extended_clip
Implement CLIP extensions (training + inference)
This commit is contained in:
commit
9581fbc226
|
@ -67,6 +67,10 @@ model:
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
penultimate: True
|
||||||
|
extended_mode: True
|
||||||
|
max_chunks: 3
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
|
|
|
@ -5,6 +5,7 @@ import clip
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
import kornia
|
import kornia
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||||
|
|
||||||
|
@ -136,13 +137,15 @@ class SpatialRescaler(nn.Module):
|
||||||
|
|
||||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
|
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
|
||||||
|
self.extended_mode = extended_mode
|
||||||
|
self.max_chunks = max_chunks
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
|
@ -150,20 +153,59 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def forward(self, text):
|
def transform(self, tokens):
|
||||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
|
||||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
|
||||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=True)
|
outputs = self.transformer(input_ids=tokens, output_hidden_states=True)
|
||||||
|
|
||||||
if self.penultimate:
|
if self.penultimate:
|
||||||
z = outputs.hidden_states[-2] # simple enough
|
z = outputs.hidden_states[-2] # simple enough
|
||||||
z = self.transformer.text_model.final_layer_norm(z)
|
z = self.transformer.text_model.final_layer_norm(z)
|
||||||
else:
|
else:
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
if self.extended_mode:
|
||||||
|
max_standard_tokens = self.max_length - 2
|
||||||
|
|
||||||
|
batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
|
||||||
|
add_special_tokens=False)
|
||||||
|
|
||||||
|
# get the max length aligned to chunk size.
|
||||||
|
max_len = np.ceil(max([len(x) for x in batch_encoding["input_ids"]]) / max_standard_tokens).astype(int).item() * max_standard_tokens
|
||||||
|
if max_len > max_standard_tokens:
|
||||||
|
z = None
|
||||||
|
|
||||||
|
for index, x in enumerate(batch_encoding["input_ids"]):
|
||||||
|
if len(x) < max_len:
|
||||||
|
# pad all tokens to the longest sentence/sequence, maybe find a torch method that can do this?
|
||||||
|
batch_encoding["input_ids"][index] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
|
||||||
|
|
||||||
|
batch_t = torch.tensor(batch_encoding["input_ids"])
|
||||||
|
# process the tensors in vertically sliced chunks
|
||||||
|
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
|
||||||
|
|
||||||
|
if z is None:
|
||||||
|
z = self.transform(chunk.to(self.device))
|
||||||
|
else:
|
||||||
|
z = torch.cat((z, self.transform(chunk.to(self.device))), dim=-2)
|
||||||
|
|
||||||
|
return z
|
||||||
|
else:
|
||||||
|
chunk = batch_encoding['input_ids']
|
||||||
|
for i, x in enumerate(chunk):
|
||||||
|
chunk[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
|
||||||
|
return self.transform(torch.asarray(chunk).to(self.device))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# default behavior
|
||||||
|
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||||
|
tokens = batch_encoding["input_ids"].to(self.device)
|
||||||
|
|
||||||
|
return self.transform(tokens)
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
return self(text)
|
return self(text)
|
||||||
|
|
||||||
|
|
|
@ -145,6 +145,17 @@ model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckp
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = model.half().to(device)
|
model = model.half().to(device)
|
||||||
|
|
||||||
|
def reshape_c_uc(c, uc):
|
||||||
|
# I have no idea how to generate an empty tensor that's valid for the model,
|
||||||
|
# so I'm gonna just pass in an empty prompt and hope it works!
|
||||||
|
padding = model.get_learned_conditioning(["" for _ in range(c.shape[0])])
|
||||||
|
while c.shape[1] != uc.shape[1]:
|
||||||
|
if c.shape[1] > uc.shape[1]:
|
||||||
|
uc = torch.cat([uc, padding], dim=1)
|
||||||
|
else:
|
||||||
|
c = torch.cat([c, padding], dim=1)
|
||||||
|
return c, uc
|
||||||
|
|
||||||
def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@ -202,6 +213,8 @@ def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||||
|
if uc is not None:
|
||||||
|
c, uc = reshape_c_uc(c, uc)
|
||||||
if sampler == 'k_lms':
|
if sampler == 'k_lms':
|
||||||
sigmas = model_wrap.get_sigmas(ddim_steps)
|
sigmas = model_wrap.get_sigmas(ddim_steps)
|
||||||
model_wrap_cfg = CFGDenoiser(model_wrap)
|
model_wrap_cfg = CFGDenoiser(model_wrap)
|
||||||
|
|
Loading…
Reference in New Issue