add TextualInversionLoaderPlugin

Damian Stewart 2024-01-24 21:45:46 +13:00
parent 072c2a695a
commit 1a4ac2d339
3 changed files with 154 additions and 53 deletions

View File

@@ -6,6 +6,8 @@ import torch
from colorama import Fore
import re
from transformers import CLIPTextModel, CLIPTokenizer
from plugins.plugins import BasePlugin
from train import EveryDreamTrainingState
from utils.sample_generator import clean_filename
@@ -33,6 +35,67 @@ In addition, you'll need a very high LR on the TE - maybe even as high as 5e-2.
"""

class TextualInversionLoaderPlugin(BasePlugin):
    def __init__(self):
        path = os.path.join(os.path.dirname(__file__), "textual_inversion_loader.json")
        logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
        with open(path, 'rt') as f:
            self.config = json.load(f)
        self.padding_tokens = {}

    def on_model_load(self, **kwargs):
        ed_state: EveryDreamTrainingState = kwargs['ed_state']
        resume_ckpt: str = kwargs['resume_ckpt']
        #self.original_tokens_length = len(ed_state.tokenizer)
        tokenizer = ed_state.tokenizer
        token_config = self.config['tokens']
        embeddings = {}
        for token_info in self.config['tokens']:
            token = token_info["token"]
            path = token_info.get("path", None) or _get_embedding_path(resume_ckpt, token)
            with open(path, "rb") as f:
                embedding_dict = torch.load(f)
            embedding = list(embedding_dict.values())[0]
            embeddings[token] = embedding
            token_info["vector_length"] = embedding.shape[0]
            print(f" * Textual Inversion Loader loaded embedding with vector length {token_info['vector_length']} for token '{token}' from {path}")
        training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite = (
            _setup_tokens(text_encoder=ed_state.text_encoder, tokenizer=ed_state.tokenizer, token_infos=token_config))
        self.padding_tokens = padding_tokens
        input_embeddings = ed_state.text_encoder.get_input_embeddings()
        for token_info in self.config['tokens']:
            token = token_info["token"]
            vector_length = token_info["vector_length"]
            trigger_and_padding_tokens = [token] + padding_tokens[token]
            embedding = embeddings[token]
            for i in range(vector_length):
                token_ids = _get_token_ids(tokenizer, trigger_and_padding_tokens[i])
                token_id = token_ids[0]
                input_embeddings.weight.data[token_id] = embedding[i]

    def transform_caption(self, caption: str) -> str:
        return self.expand_trigger_tokens(caption)

    def modify_sample_prompt(self, prompt: str) -> str:
        return self.expand_trigger_tokens(prompt)

    def expand_trigger_tokens(self, caption: str) -> str:
        tokens = self.config['tokens']
        # for multi-vector tokens, replace the trigger token with a padded sequence of the correct length,
        # e.g. "hat*" with vector length 3 -> "hat* hat*_pad!!!_1 hat*_pad!!!_2"
        for t in tokens:
            trigger = t['token']
            replacement = " ".join([trigger] + self.padding_tokens[trigger])
            # escape the trigger so regex metacharacters (e.g. '*') are matched literally
            caption = re.sub(re.escape(trigger), replacement, caption)
        return caption

class TextualInversionPlugin(BasePlugin):
@@ -50,15 +113,14 @@ class TextualInversionPlugin(BasePlugin):
    def on_model_load(self, **kwargs):
        ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
        def get_token_ids(t: str):
            return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t))
        tokenizer = ed_state.tokenizer
        # check for correctly configured text encoder training
        disable_unet_training: bool = kwargs.get('disable_unet_training')
        disable_textenc_training: bool = kwargs.get('disable_textenc_training')
        #if not disable_unet_training or disable_textenc_training:
        #    logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}")
        #    raise RuntimeError("Unet training must be disabled and text encoder training enabled")
        if not disable_unet_training or disable_textenc_training:
            logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES {Fore.RESET}\"disable_unet_training\": true{Fore.LIGHTRED_EX} and {Fore.RESET}\"disable_textenc_training\": false{Fore.LIGHTRED_EX} in your train.json{Fore.RESET}")
            raise RuntimeError("Unet training must be disabled and text encoder training enabled")
        num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
        optimizer_config: dict = kwargs.get('optimizer_config')
        if (optimizer_config is None or
@@ -72,69 +134,57 @@ class TextualInversionPlugin(BasePlugin):
            logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}")
            raise RuntimeError("Misconfigured optimizer config")
        training_tokens = set()
        for token_info in self.config['tokens']:
            start_token = token_info['token']
            vector_length = token_info.get('vector_length', 1)
        for token_info in self.config["tokens"]:
            start_token = token_info["token"]
            vector_length = token_info.get("vector_length", 1)
            print(f" * Textual Inversion training on '{start_token}' with vector length {vector_length}")
            this_padding_tokens = [f"{start_token}_pad!!!_{n+1}" for n in range(vector_length-1)]
            self.padding_tokens[start_token] = this_padding_tokens
            training_tokens.update([start_token] + this_padding_tokens)
        tokens_to_add = [t for t in training_tokens if len(get_token_ids(t))>1]
        training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite = (
            _setup_tokens(text_encoder=ed_state.text_encoder, tokenizer=ed_state.tokenizer, token_infos=self.config['tokens']))
        self.padding_tokens = padding_tokens
        for trigger_token, padding_tokens in padding_tokens.items():
            this_padding_token_ids = [_get_token_ids(tokenizer, t)[0] for t in padding_tokens]
            self.padding_token_ids[trigger_token] = this_padding_token_ids
        logging.info(
            f" * Textual inversion training adding the following tokens: {sorted(tokens_to_add)}")
        tokens_to_overwrite = [t for t in training_tokens if t not in tokens_to_add]
        if any(tokens_to_overwrite):
            logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")
        num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add)
        if num_added_tokens != len(tokens_to_add):
            raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
        ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer))
        added_token_ids = []
        for token in tokens_to_add:
            token_ids = get_token_ids(token)
            if len(token_ids) != 1:
raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {token}, found {len(token_ids)}")
            token_id = token_ids[0]
            added_token_ids.append(token_id)
        for trigger_token, padding_tokens in self.padding_tokens.items():
            this_padding_token_ids = [get_token_ids(t)[0] for t in padding_tokens]
            self.padding_token_ids[trigger_token] = this_padding_token_ids
        # copy initializer embedding
        input_embeddings = ed_state.text_encoder.get_input_embeddings()
        for token_info in self.config['tokens']:
            vector_length = token_info.get('vector_length', 1)
            # make sure it's very long
            initializer_text = " ".join([token_info['initializer']] * vector_length)
            with torch.no_grad():
                initializer_token_ids_full = ed_state.tokenizer(initializer_text,
                                                                truncation=True,
                                                                padding="max_length",
                                                                max_length=ed_state.tokenizer.model_max_length,
                                                                ).input_ids
                initializer_embedding_full = ed_state.text_encoder(
                    torch.tensor(initializer_token_ids_full, device=ed_state.text_encoder.device).unsqueeze(0), output_hidden_states=True
                ).last_hidden_state
                initializer_embedding = initializer_embedding_full[0][1:vector_length+1]
            initializer_text = None if token_info.get('random_initializer', True) else " ".join([token_info['initializer']] * vector_length)
            if initializer_text is None:
                reference = input_embeddings.weight[0]
                embedding_length = reference.shape[0]
                initializer_embedding = torch.rand([vector_length, embedding_length],
                                                   dtype=reference.dtype,
                                                   device=reference.device) * 0.1 - 0.05
            else:
                with torch.no_grad():
                    initializer_token_ids_full = ed_state.tokenizer(initializer_text,
                                                                    truncation=True,
                                                                    padding="max_length",
                                                                    max_length=ed_state.tokenizer.model_max_length,
                                                                    ).input_ids
                    initializer_embedding_full = ed_state.text_encoder(
                        torch.tensor(initializer_token_ids_full, device=ed_state.text_encoder.device).unsqueeze(0), output_hidden_states=True
                    ).last_hidden_state
                    initializer_embedding = initializer_embedding_full[0][1:vector_length+1]
            trigger_token = token_info['token']
            trigger_and_padding_tokens = [trigger_token] + self.padding_tokens[trigger_token]
            for i in range(vector_length):
                token_ids = get_token_ids(trigger_and_padding_tokens[i])
                token_ids = _get_token_ids(tokenizer, trigger_and_padding_tokens[i])
                token_id = token_ids[0]
                # don't clobber trained embeddings when resuming
                if token_id in tokens_to_add:
                    input_embeddings.weight.data[token_id] = initializer_embedding[i]
                input_embeddings.weight.data[token_id] = initializer_embedding[i]
        overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite]
        self.training_tokens = tokens_to_add + tokens_to_overwrite
        self.training_token_ids = added_token_ids + overwriting_token_ids
        self.training_token_ids = [_get_token_ids(tokenizer, t)[0] for t in self.training_tokens]
        # get indices of non-training tokens (ie tokens whose grads should be reset to 0 every step)
        total_len = len(ed_state.text_encoder.get_input_embeddings().weight)
@@ -188,10 +238,44 @@ class TextualInversionPlugin(BasePlugin):
def _save_embedding(token, embedding, save_folder):
    dict_to_save = {token: embedding}
    token_name_safe = clean_filename(token)
    ti_folder = os.path.join(save_folder, 'textual_inversions')
    os.makedirs(ti_folder, exist_ok=True)
    save_path = os.path.join(ti_folder, token_name_safe + '.bin')
    save_path = _get_embedding_path(save_folder, token)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    logging.info(f"Saving textual inversion for '{token}' to {save_path}")
    torch.save(dict_to_save, save_path)

def _get_embedding_path(save_folder: str, token: str) -> str:
    token_name_safe = clean_filename(token)
    ti_folder = os.path.join(save_folder, 'textual_inversions')
    return os.path.join(ti_folder, token_name_safe + '.bin')

def _setup_tokens(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, token_infos: list[dict]) -> tuple[set, dict, list, list]:
    training_tokens = set()
    padding_tokens = {}
    for token_info in token_infos:
        start_token = token_info['token']
        vector_length = token_info.get('vector_length', 1)
        this_padding_tokens = [f"{start_token}_pad!!!_{n + 1}" for n in range(vector_length - 1)]
        padding_tokens[start_token] = this_padding_tokens
        training_tokens.update([start_token] + this_padding_tokens)
    tokens_to_add = [t for t in training_tokens if len(_get_token_ids(tokenizer, t)) > 1]
    tokens_to_overwrite = [t for t in training_tokens if t not in tokens_to_add]
    num_added_tokens = tokenizer.add_tokens(tokens_to_add)
    if num_added_tokens != len(tokens_to_add):
        raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
    text_encoder.resize_token_embeddings(len(tokenizer))
    added_token_ids = []
    for token in tokens_to_add:
        token_ids = _get_token_ids(tokenizer, token)
        if len(token_ids) != 1:
raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {token}, found {len(token_ids)}")
        token_id = token_ids[0]
        added_token_ids.append(token_id)
    return training_tokens, padding_tokens, tokens_to_add, tokens_to_overwrite

def _get_token_ids(tokenizer, t: str):
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(t))
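
As a quick sanity check of the on-disk format written by _save_embedding and read back by TextualInversionLoaderPlugin, the following standalone sketch (not part of this commit) loads a saved .bin and prints its vector length; the path reuses the hypothetical example path from the loader config below.

import torch

# _save_embedding writes a single {token: embedding} pair via torch.save, so the
# first value is the embedding tensor and its first dimension is the vector length
# that the loader will assign to the token.
with open("/workspace/embeddings/my_cane_embedding_ep30.bin", "rb") as f:
    embedding_dict = torch.load(f)
token, embedding = next(iter(embedding_dict.items()))
print(f"token '{token}' has vector length {embedding.shape[0]}")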

View File

@@ -0,0 +1,16 @@
{
  "documentation": {
    "tokens": {
      "token": "the trigger token (word or phrase). whenever this word or phrase appears in image captions, the embedding will be trained.",
      "path": "(optional) /path/to/embedding.bin. If omitted, tries to load an embedding from the resume_ckpt diffusers folder, textual_inversions/<token>.bin where <token> is the token"
    },
"example": "the example below tries to load textual_inversions/hat*.bin from inside the resume ckpt's textual_inversion folder and textual_inversions/dancing shoes.bin from inside the model folder, and cane from the path specified."
  },
  "tokens": [
    { "token": "hat*" },
    { "token": "dancing shoes" },
    { "token": "cane", "path": "/workspace/embeddings/my_cane_embedding_ep30.bin" }
  ]
}
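
To illustrate how TextualInversionLoaderPlugin.expand_trigger_tokens rewrites captions for a config like the one above, here is a minimal sketch assuming "hat*" was loaded with vector length 3; the padding-token names follow the "{token}_pad!!!_{n}" scheme used by _setup_tokens.

import re

padding_tokens = {"hat*": ["hat*_pad!!!_1", "hat*_pad!!!_2"]}  # vector length 3
caption = "a photo of hat* on a table"
for trigger, pads in padding_tokens.items():
    # same expansion as expand_trigger_tokens: trigger followed by its padding tokens
    replacement = " ".join([trigger] + pads)
    caption = re.sub(re.escape(trigger), replacement, caption)
print(caption)  # -> "a photo of hat* hat*_pad!!!_1 hat*_pad!!!_2 on a table"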

View File

@@ -792,6 +792,7 @@ def main(args):
    plugin_runner.run_on_model_load(
        ed_state=EveryDreamTrainingState(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae,
                                         optimizer=None, train_batch=None, scheduler=noise_scheduler, unet_ema=None, text_encoder_ema=None),
        resume_ckpt=args.resume_ckpt,
        optimizer_config=optimizer_config,
        disable_unet_training=args.disable_unet_training,
        disable_textenc_training=args.disable_textenc_training
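
For reference, when a token entry in the loader config omits "path", the loader resolves the embedding location from the resume_ckpt value passed here. A small sketch of that resolution, using a hypothetical resume_ckpt and assuming clean_filename leaves this token name unchanged (as the config's example text suggests):

import os

resume_ckpt = "/workspace/my_model"  # hypothetical value of args.resume_ckpt
token = "dancing shoes"
# mirrors _get_embedding_path: <resume_ckpt>/textual_inversions/<clean_filename(token)>.bin
default_path = os.path.join(resume_ckpt, "textual_inversions", token + ".bin")
print(default_path)  # /workspace/my_model/textual_inversions/dancing shoes.bin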