EveryDream2trainer/plugins/textual_inversion.py


import json
import logging
import os.path
import torch
from colorama import Fore
from plugins.plugins import BasePlugin
from train import EveryDreamTrainingState
from utils.sample_generator import clean_filename
"""
This plugin adds custom tokens to the tokenizer and trains just these tokens, with the rest of the text encoder
disabled/frozen.
token/initialization config is in textual_inversion.json, same folder as this .py file.
For pure Textual Inversion training:
"disable_textenc_training": false,
"disable_unet_training": true
(Or you could unet training on too if you want, I didn't test this.)
In optimizer.json, the following "text_encoder_freezing" section is *required*:
"text_encoder_freezing": {
"unfreeze_last_n_layers": 0,
"freeze_embeddings": false,
"freeze_final_layer_norm": true
}
In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method.
"""
class TextualInversionPlugin(BasePlugin):
def __init__(self):
path = os.path.join(os.path.dirname(__file__), "textual_inversion.json")
logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
with open(path, 'rt') as f:
self.config = json.load(f)
self.this_batch_tokens = None
self.training_tokens = None
self.training_token_ids = None
self.original_text_embeddings = None
def on_model_load(self, **kwargs):
ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
optimizer_config: dict = kwargs.get('optimizer_config')
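        # helper: tokenizer.tokenize() does not add BOS/EOS, so a string that already exists
        # as a single entry in the vocabulary maps to exactly one id here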
def get_token_ids(t: str):
return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t))
# check for correctly configured text encoder training
num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
if (optimizer_config is None or
'text_encoder_freezing' not in optimizer_config or
optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or
optimizer_config['text_encoder_freezing'].get('freeze_final_layer_norm') != True or
optimizer_config['text_encoder_freezing'].get('unfreeze_last_n_layers', num_te_layers) > 0
):
required_js_fragment = {"text_encoder_freezing": {"freeze_embeddings": False, "unfreeze_last_n_layers": 0, "freeze_final_layer_norm": True}}
logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES the following json fragment in your optimizer config:{Fore.RESET}")
logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}")
raise RuntimeError("Misconfigured optimizer config")
        tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token'])) > 1]
        logging.info(f" * Textual inversion training adding the following tokens: {tokens_to_add}")
tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add]
if any(tokens_to_overwrite):
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")
num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add)
if num_added_tokens != len(tokens_to_add):
raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer))
added_token_ids = []
input_embeddings = ed_state.text_encoder.get_input_embeddings()
for token_info in self.config['tokens']:
# get newly added token id
t = token_info['token']
token_ids = get_token_ids(t)
if len(token_ids) != 1:
raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}")
token_id = token_ids[0]
added_token_ids.append(token_id)
# copy initializer embedding
initializer_word = token_info['initializer_word']
initializer_word_token_ids = get_token_ids(initializer_word)
if len(initializer_word_token_ids) != 1:
raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs "
f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.")
initializer_word_token_id = initializer_word_token_ids[0]
initializer_embedding = input_embeddings.weight.data[initializer_word_token_id]
input_embeddings.weight.data[token_id] = initializer_embedding
        # every token named in the config is trained, whether newly added or pre-existing;
        # added_token_ids was built in config order, so it stays aligned with the token strings
        self.training_tokens = [t['token'] for t in self.config['tokens']]
        self.training_token_ids = added_token_ids
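        # snapshot the full embedding matrix so on_step_end can restore any rows the optimizer
        # touched that we are not deliberately training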
self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone()
def on_step_start(self, **kwargs):
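        # record which token ids appear in this batch; on_step_end uses this to know which
        # embedding rows may have been updated and need to be restored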
batch = kwargs['batch']
        tokens = batch['tokens'] # stacked tensor of token ids for this batch
self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist()
def on_step_end(self, **kwargs):
ed_state: EveryDreamTrainingState = kwargs['ed_state']
# reset the embeddings that have been touched this step, except the ones we're training, to their original state
        with torch.no_grad():
embeddings = ed_state.text_encoder.get_input_embeddings()
embeddings_to_restore = [t for t in self.this_batch_tokens if t not in self.training_token_ids]
for t in embeddings_to_restore:
embeddings.weight[t] = self.original_text_embeddings[t]
def on_model_save(self, **kwargs):
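        # write each trained token's embedding to its own .bin file under <save_folder>/textual_inversions/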
ed_state: EveryDreamTrainingState = kwargs['ed_state']
embeddings = ed_state.text_encoder.get_input_embeddings()
save_folder = kwargs['save_folder']
for token_id, token in zip(self.training_token_ids, self.training_tokens):
_save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder)
def _save_embedding(token, embedding, save_folder):
dict_to_save = {token: embedding}
token_name_safe = clean_filename(token)
ti_folder = os.path.join(save_folder, 'textual_inversions')
os.makedirs(ti_folder, exist_ok=True)
save_path = os.path.join(ti_folder, token_name_safe + '.bin')
logging.info(f"Saving textual inversion for '{token}' to {save_path}")
torch.save(dict_to_save, save_path)
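
# Loading one of the saved .bin files elsewhere is a plain torch.load (illustrative sketch;
# the path below is a placeholder):
#
#   saved = torch.load('textual_inversions/my_token.bin')
#   (token, embedding), = saved.items()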