135 lines
6.6 KiB
Python
135 lines
6.6 KiB
Python
import json
|
|
import logging
|
|
import os.path
|
|
|
|
import torch
|
|
from colorama import Fore
|
|
|
|
from plugins.plugins import BasePlugin
|
|
from train import EveryDreamTrainingState
|
|
from utils.sample_generator import clean_filename
|
|
|
|
"""
|
|
This plugin adds custom tokens to the tokenizer and trains just these tokens, with the rest of the text encoder
|
|
disabled/frozen.
|
|
|
|
token/initialization config is in textual_inversion.json, same folder as this .py file.
|
|
|
|
For pure Textual Inversion training:
|
|
"disable_textenc_training": false,
|
|
"disable_unet_training": true
|
|
(Or you could turn unet training on too if you want; I didn't test this.)
|
|
|
|
|
|
In optimizer.json, the following "text_encoder_freezing" section is *required*:
|
|
"text_encoder_freezing": {
|
|
"unfreeze_last_n_layers": 0,
|
|
"freeze_embeddings": false,
|
|
"freeze_final_layer_norm": true
|
|
}
|
|
In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method.
|
|
|
|
"""
|
|
|
|
class TextualInversionPlugin(BasePlugin):
    """Plugin that trains only the embeddings of the tokens listed in textual_inversion.json.

    On model load the configured tokens are added to the tokenizer (or, when they already
    map to a single token id, marked for overwriting) and initialised from their
    `initializer_word`. After every step, the embeddings of all other tokens seen in the
    batch are reset to the values they had at model load. On model save, each trained
    embedding is written to its own file.
    """

    def __init__(self):
        """Load the token configuration from textual_inversion.json next to this file."""
        path = os.path.join(os.path.dirname(__file__), "textual_inversion.json")
        logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
        with open(path, 'rt', encoding='utf-8') as f:
            self.config = json.load(f)
        # unique token ids present in the current batch (set by on_step_start)
        self.this_batch_tokens = None
        # token strings being trained, and their ids in the same order (set by on_model_load)
        self.training_tokens = None
        self.training_token_ids = None
        # snapshot of the input embedding matrix taken at the end of on_model_load
        self.original_text_embeddings = None

    @staticmethod
    def _check_text_encoder_freezing(optimizer_config, num_te_layers):
        """Raise RuntimeError unless the optimizer config freezes everything but the embeddings.

        A missing `unfreeze_last_n_layers` entry is treated as `num_te_layers`, i.e. as
        misconfigured.
        """
        freezing = (optimizer_config or {}).get('text_encoder_freezing')
        # Equality (not identity) comparisons are kept from the original check.
        misconfigured = (
            not isinstance(freezing, dict)
            or freezing.get('freeze_embeddings') != False  # noqa: E712
            or freezing.get('freeze_final_layer_norm') != True  # noqa: E712
            or freezing.get('unfreeze_last_n_layers', num_te_layers) > 0
        )
        if misconfigured:
            required_js_fragment = {"text_encoder_freezing": {"freeze_embeddings": False, "unfreeze_last_n_layers": 0, "freeze_final_layer_norm": True}}
            logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES the following json fragment in your optimizer config:{Fore.RESET}")
            logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}")
            raise RuntimeError("Misconfigured optimizer config")

    def on_model_load(self, **kwargs):
        """Register the configured tokens and initialise their embeddings.

        Reads `ed_state` and `optimizer_config` from kwargs.

        Raises:
            RuntimeError: if the optimizer config is not set up for embedding-only
                training, if the tokenizer did not add every new token, if a configured
                token does not resolve to exactly one id afterwards, or if an
                initializer word is not a single token.
        """
        ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
        optimizer_config: dict = kwargs.get('optimizer_config')
        tokenizer = ed_state.tokenizer

        def get_token_ids(t: str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(t))

        # check for correctly configured text encoder training
        num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
        self._check_text_encoder_freezing(optimizer_config, num_te_layers)

        # Tokens that currently split into several ids are new and must be added;
        # the rest already exist in the vocabulary and will be overwritten.
        configured_tokens = [t['token'] for t in self.config['tokens']]
        tokens_to_add = [t for t in configured_tokens if len(get_token_ids(t)) > 1]
        logging.info(
            f" * Textual inversion training adding the following tokens: {tokens_to_add}")
        tokens_to_overwrite = [t for t in configured_tokens if t not in tokens_to_add]
        if tokens_to_overwrite:
            logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")

        num_added_tokens = tokenizer.add_tokens(tokens_to_add)
        if num_added_tokens != len(tokens_to_add):
            raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
        ed_state.text_encoder.resize_token_embeddings(len(tokenizer))

        token_id_by_token = {}
        input_embeddings = ed_state.text_encoder.get_input_embeddings()
        for token_info in self.config['tokens']:
            # every configured token (added or overwritten) must now be a single id
            t = token_info['token']
            token_ids = get_token_ids(t)
            if len(token_ids) != 1:
                raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}")
            token_id = token_ids[0]
            token_id_by_token[t] = token_id

            # copy initializer embedding
            initializer_word = token_info['initializer_word']
            initializer_word_token_ids = get_token_ids(initializer_word)
            if len(initializer_word_token_ids) != 1:
                raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs "
                                   f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.")
            initializer_word_token_id = initializer_word_token_ids[0]
            initializer_embedding = input_embeddings.weight.data[initializer_word_token_id]
            input_embeddings.weight.data[token_id] = initializer_embedding

        # Keep tokens and ids strictly parallel: on_model_save zips them together, so
        # the ids are looked up per token instead of concatenating separately built lists.
        self.training_tokens = tokens_to_add + tokens_to_overwrite
        self.training_token_ids = [token_id_by_token[t] for t in self.training_tokens]
        self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone()

    def on_step_start(self, **kwargs):
        """Remember the unique token ids that occur in this step's batch."""
        batch = kwargs['batch']
        tokens = batch['tokens']  # a torch.stack
        self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist()

    def on_step_end(self, **kwargs):
        """Restore every embedding touched this step that is not one of the training tokens."""
        ed_state: EveryDreamTrainingState = kwargs['ed_state']
        if not self.this_batch_tokens:
            # no batch recorded by on_step_start, so there is nothing to restore
            return

        trained_ids = set(self.training_token_ids)
        # reset the embeddings that have been touched this step, except the ones we're training, to their original state
        with torch.no_grad():
            embeddings = ed_state.text_encoder.get_input_embeddings()
            for t in self.this_batch_tokens:
                if t not in trained_ids:
                    embeddings.weight[t] = self.original_text_embeddings[t]

    def on_model_save(self, **kwargs):
        """Write each trained token's current embedding under kwargs['save_folder']."""
        ed_state: EveryDreamTrainingState = kwargs['ed_state']
        embeddings = ed_state.text_encoder.get_input_embeddings()
        save_folder = kwargs['save_folder']
        for token_id, token in zip(self.training_token_ids, self.training_tokens):
            _save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder)
|
|
|
|
def _save_embedding(token, embedding, save_folder):
    """Save a single token embedding to `<save_folder>/textual_inversions/<token>.bin`.

    The file holds a one-entry dict mapping `token` to the embedding tensor, and the
    file name is derived from the token via `clean_filename`.

    Args:
        token: the token string the embedding belongs to.
        embedding: the embedding tensor for that token.
        save_folder: folder under which the `textual_inversions` subfolder is created.
    """
    # The caller passes a row indexed out of the full embedding matrix. torch.save
    # serialises the whole storage behind a view, so copy the row into its own
    # compact, graph-free tensor before saving.
    dict_to_save = {token: embedding.detach().clone()}
    token_name_safe = clean_filename(token)
    ti_folder = os.path.join(save_folder, 'textual_inversions')
    os.makedirs(ti_folder, exist_ok=True)
    save_path = os.path.join(ti_folder, token_name_safe + '.bin')
    logging.info(f"Saving textual inversion for '{token}' to {save_path}")
    torch.save(dict_to_save, save_path)
|
|
|