EveryDream2trainer/plugins/textual_inversion.py

135 lines
6.6 KiB
Python

import json
import logging
import os.path
import torch
from colorama import Fore
from plugins.plugins import BasePlugin
from train import EveryDreamTrainingState
from utils.sample_generator import clean_filename
"""
This plugin adds custom tokens to the tokenizer and trains just these tokens, with the rest of the text encoder
disabled/frozen.
The token and initializer-word configuration is read from textual_inversion.json, located in the same folder as this .py file.
For pure Textual Inversion training, use these training settings:
"disable_textenc_training": false,
"disable_unet_training": true
(You could also turn UNet training on if you want, but this has not been tested.)
In optimizer.json, the following "text_encoder_freezing" section is *required*:
"text_encoder_freezing": {
"unfreeze_last_n_layers": 0,
"freeze_embeddings": false,
"freeze_final_layer_norm": true
}
In addition, you will need a very high learning rate (LR) on the text encoder (TE), possibly as high as 1e-3. Using the LR finder method to choose it is recommended.
"""
class TextualInversionPlugin(BasePlugin):
    """Train only the embeddings of a configured set of tokens (Textual Inversion).

    Tokens listed in ``textual_inversion.json`` (next to this file) that currently split
    into several sub-tokens are added to the tokenizer. Tokens that already map to a
    single id have that id's embedding overwritten. Every configured token's embedding is
    initialised from its ``initializer_word``. After each training step, embedding rows
    for the other tokens that appeared in the batch are reset to the snapshot taken at
    the end of ``on_model_load``.
    """

    def __init__(self):
        path = os.path.join(os.path.dirname(__file__), "textual_inversion.json")
        logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(path, 'rt', encoding='utf-8') as f:
            self.config = json.load(f)
        # Unique token ids seen in the current batch (set by on_step_start).
        self.this_batch_tokens = None
        # Token strings being trained; index-aligned with training_token_ids.
        self.training_tokens = None
        self.training_token_ids = None
        # Copy of the full input-embedding matrix taken at the end of on_model_load.
        self.original_text_embeddings = None

    def on_model_load(self, **kwargs):
        """Validate the optimizer config, register the tokens and initialise their embeddings.

        Reads ``ed_state`` and ``optimizer_config`` from kwargs.

        Raises:
            RuntimeError: if the optimizer config lacks the required ``text_encoder_freezing``
                settings, if the tokenizer did not add every new token, if a configured token
                does not map to exactly one id afterwards, or if an initializer word is not a
                single token.
        """
        ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
        optimizer_config: dict = kwargs.get('optimizer_config')

        def get_token_ids(t: str):
            return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t))

        # check for correctly configured text encoder training
        num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
        if (optimizer_config is None or
                'text_encoder_freezing' not in optimizer_config or
                optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or
                optimizer_config['text_encoder_freezing'].get('freeze_final_layer_norm') != True or
                optimizer_config['text_encoder_freezing'].get('unfreeze_last_n_layers', num_te_layers) > 0
        ):
            required_js_fragment = {"text_encoder_freezing": {"freeze_embeddings": False, "unfreeze_last_n_layers": 0, "freeze_final_layer_norm": True}}
            logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES the following json fragment in your optimizer config:{Fore.RESET}")
            logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}")
            raise RuntimeError("Misconfigured optimizer config")

        # Multi-sub-token tokens get a new vocabulary entry; single-id tokens are reused
        # (their existing embedding is overwritten below).
        tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token'])) > 1]
        logging.info(
            f" * Textual inversion training adding the following tokens: {tokens_to_add}")
        tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add]
        if any(tokens_to_overwrite):
            logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")

        num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add)
        if num_added_tokens != len(tokens_to_add):
            raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
        ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer))

        token_id_by_token = {}
        input_embeddings = ed_state.text_encoder.get_input_embeddings()
        for token_info in self.config['tokens']:
            # after add_tokens every configured token must resolve to a single id
            t = token_info['token']
            token_ids = get_token_ids(t)
            if len(token_ids) != 1:
                raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}")
            token_id = token_ids[0]
            token_id_by_token[t] = token_id

            # copy initializer embedding
            initializer_word = token_info['initializer_word']
            initializer_word_token_ids = get_token_ids(initializer_word)
            if len(initializer_word_token_ids) != 1:
                raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs "
                                   f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.")
            initializer_word_token_id = initializer_word_token_ids[0]
            initializer_embedding = input_embeddings.weight.data[initializer_word_token_id]
            input_embeddings.weight.data[token_id] = initializer_embedding

        self.training_tokens = tokens_to_add + tokens_to_overwrite
        # Derive ids from the token list itself so the two lists stay index-aligned
        # (on_model_save zips them) regardless of the order tokens appear in the config.
        self.training_token_ids = [token_id_by_token[t] for t in self.training_tokens]
        self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone()

    def on_step_start(self, **kwargs):
        """Record the unique token ids present in ``kwargs['batch']['tokens']``."""
        batch = kwargs['batch']
        tokens = batch['tokens']  # a torch.stack
        self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist()

    def on_step_end(self, **kwargs):
        """Reset embeddings touched this step, except the trained ones, to their snapshot.

        Only rows for tokens recorded by on_step_start are restored.
        NOTE(review): optimizer state (momentum / weight decay) may also move rows of
        tokens not in the current batch; those are not restored here. Confirm this is acceptable.
        """
        ed_state: EveryDreamTrainingState = kwargs['ed_state']
        if not self.this_batch_tokens:
            return
        training_ids = set(self.training_token_ids)
        ids_to_restore = [t for t in self.this_batch_tokens if t not in training_ids]
        if not ids_to_restore:
            return
        with torch.no_grad():
            weight = ed_state.text_encoder.get_input_embeddings().weight
            original = self.original_text_embeddings
            restore_idx = torch.tensor(ids_to_restore, dtype=torch.long)
            # One vectorised copy instead of one small copy per row. The snapshot may live on a
            # different device than the weight (e.g. the model was moved after loading), so
            # gather on the snapshot's device, then match the weight's device and dtype.
            weight[restore_idx.to(weight.device)] = original[restore_idx.to(original.device)].to(weight)

    def on_model_save(self, **kwargs):
        """Save each trained token's embedding via _save_embedding into ``kwargs['save_folder']``."""
        ed_state: EveryDreamTrainingState = kwargs['ed_state']
        embeddings = ed_state.text_encoder.get_input_embeddings()
        save_folder = kwargs['save_folder']
        for token_id, token in zip(self.training_token_ids, self.training_tokens):
            # An indexed row is a view sharing the whole matrix's storage, and torch.save would
            # serialise all of it. Detach and copy the single row into its own CPU tensor.
            embedding = embeddings.weight[token_id].detach().to('cpu', copy=True)
            _save_embedding(token=token, embedding=embedding, save_folder=save_folder)
def _save_embedding(token, embedding, save_folder):
    """Save one learned token embedding as ``<save_folder>/textual_inversions/<name>.bin``.

    The file contains ``{token: embedding}`` written with ``torch.save``. ``<name>`` is
    ``token`` passed through ``clean_filename``.

    Args:
        token: the token string, used as the dict key and (sanitised) as the file name.
        embedding: the tensor holding ``token``'s embedding.
        save_folder: directory in which a ``textual_inversions`` subfolder is created.
    """
    # torch.save writes a view's entire underlying storage, so a row indexed out of the
    # embedding matrix would pull the whole matrix into the file. Detach from autograd
    # and copy into a compact, standalone CPU tensor before saving.
    embedding = embedding.detach().to('cpu', copy=True)
    dict_to_save = {token: embedding}
    token_name_safe = clean_filename(token)
    ti_folder = os.path.join(save_folder, 'textual_inversions')
    os.makedirs(ti_folder, exist_ok=True)
    save_path = os.path.join(ti_folder, token_name_safe + '.bin')
    logging.info(f"Saving textual inversion for '{token}' to {save_path}")
    torch.save(dict_to_save, save_path)