Merge pull request #24 from hasuwoof/feature/extended_clip

Implement CLIP extensions (training + inference)
harubaru 2022-10-14 16:27:21 -07:00 committed by GitHub
commit 9581fbc226
3 changed files with 67 additions and 8 deletions


@@ -67,6 +67,10 @@ model:
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        penultimate: True
+        extended_mode: True
+        max_chunks: 3
 data:
   target: main.DataModuleFromConfig
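For reference (not part of the diff): extended_mode lets the embedder accept prompts longer than CLIP's 77-token context window by encoding up to max_chunks windows and concatenating the results, while penultimate keeps the existing behaviour of reading the second-to-last hidden layer. A rough sketch of the token budget these values imply:

# Illustrative arithmetic only; max_length and max_chunks mirror the config above.
max_length = 77                      # CLIP context window, including BOS and EOS
max_chunks = 3
content_per_chunk = max_length - 2   # 75 prompt tokens; BOS/EOS are re-added per chunk
total_budget = max_length * max_chunks - 2 * max_chunks
print(content_per_chunk, total_budget)   # 75 225 -- matches the tokenizer max_length used in forward below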


@@ -5,6 +5,7 @@ import clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
+import numpy as np
 from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
@@ -136,13 +137,15 @@ class SpatialRescaler(nn.Module):
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length
         self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
+        self.extended_mode = extended_mode
+        self.max_chunks = max_chunks
         self.freeze()

     def freeze(self):
@@ -150,10 +153,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         for param in self.parameters():
             param.requires_grad = False

-    def forward(self, text):
-        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
-                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
-        tokens = batch_encoding["input_ids"].to(self.device)
+    def transform(self, tokens):
         outputs = self.transformer(input_ids=tokens, output_hidden_states=True)

         if self.penultimate:
@@ -164,6 +164,48 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         return z

+    def forward(self, text):
+        if self.extended_mode:
+            max_standard_tokens = self.max_length - 2
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
+                                            add_special_tokens=False)
+            # get the max length aligned to chunk size.
+            max_len = np.ceil(max([len(x) for x in batch_encoding["input_ids"]]) / max_standard_tokens).astype(int).item() * max_standard_tokens
+            if max_len > max_standard_tokens:
+                z = None
+                for index, x in enumerate(batch_encoding["input_ids"]):
+                    if len(x) < max_len:
+                        # pad all tokens to the longest sentence/sequence, maybe find a torch method that can do this?
+                        batch_encoding["input_ids"][index] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
+                batch_t = torch.tensor(batch_encoding["input_ids"])
+                # process the tensors in vertically sliced chunks
+                chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
+                for chunk in chunks:
+                    chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
+                    if z is None:
+                        z = self.transform(chunk.to(self.device))
+                    else:
+                        z = torch.cat((z, self.transform(chunk.to(self.device))), dim=-2)
+                return z
+            else:
+                chunk = batch_encoding['input_ids']
+                for i, x in enumerate(chunk):
+                    chunk[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
+                return self.transform(torch.asarray(chunk).to(self.device))
+        else:
+            # default behavior
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+            tokens = batch_encoding["input_ids"].to(self.device)
+            return self.transform(tokens)
+
     def encode(self, text):
         return self(text)
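In short, transform now encodes a single 77-token window and the new forward decides how many windows to run: prompts within 75 content tokens behave exactly as before, while longer prompts are padded up to a multiple of 75 tokens, split into chunks, wrapped in BOS/EOS per chunk, and the per-chunk embeddings concatenated along the sequence dimension. A small usage sketch (illustrative only, not part of the diff; the expected shapes assume the clip-vit-large-patch14 defaults with hidden size 768):

# Illustrative usage of the extended embedder; downloads the HF checkpoint on first use.
embedder = FrozenCLIPEmbedder(device="cpu", extended_mode=True, max_chunks=3)
z_short = embedder(["a watercolor painting of a fox"])
print(z_short.shape)   # expected: torch.Size([1, 77, 768])  -- one chunk, same as the old behaviour
z_long = embedder(["a watercolor painting of a fox, " * 12])   # roughly 100 tokens once tokenized
print(z_long.shape)    # expected: torch.Size([1, 154, 768]) -- two chunks concatenated on dim=-2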


@@ -145,6 +145,17 @@ model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckp
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 model = model.half().to(device)

+def reshape_c_uc(c, uc):
+    # I have no idea how to generate an empty tensor that's valid for the model,
+    # so I'm gonna just pass in an empty prompt and hope it works!
+    padding = model.get_learned_conditioning(["" for _ in range(c.shape[0])])
+    while c.shape[1] != uc.shape[1]:
+        if c.shape[1] > uc.shape[1]:
+            uc = torch.cat([uc, padding], dim=1)
+        else:
+            c = torch.cat([c, padding], dim=1)
+    return c, uc
+
 def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
     torch.cuda.empty_cache()
@@ -202,6 +213,8 @@ def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta
                 prompts = list(prompts)
             c = model.get_learned_conditioning(prompts)
             shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+            if uc is not None:
+                c, uc = reshape_c_uc(c, uc)
             if sampler == 'k_lms':
                 sigmas = model_wrap.get_sigmas(ddim_steps)
                 model_wrap_cfg = CFGDenoiser(model_wrap)
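For context (not part of the diff): with extended_mode the prompt conditioning c can span several 77-token chunks, e.g. (batch, 154, 768), while the unconditional embedding uc built from empty prompts stays at (batch, 77, 768). Classifier-free guidance needs both tensors to agree along the token dimension, so reshape_c_uc pads the shorter one with additional empty-prompt embeddings in 77-token steps. A self-contained, shape-level sketch of that padding loop, using random tensors in place of real model conditionings:

import torch

# Hypothetical stand-ins for model.get_learned_conditioning(...) outputs.
c = torch.randn(2, 154, 768)        # long prompt: two 77-token chunks
uc = torch.randn(2, 77, 768)        # empty prompt: one chunk
padding = torch.randn(2, 77, 768)   # stands in for another empty-prompt embedding

# Same padding loop as reshape_c_uc above.
while c.shape[1] != uc.shape[1]:
    if c.shape[1] > uc.shape[1]:
        uc = torch.cat([uc, padding], dim=1)
    else:
        c = torch.cat([c, padding], dim=1)

print(c.shape, uc.shape)   # torch.Size([2, 154, 768]) torch.Size([2, 154, 768])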