initial implementation of the_acculmunator

This commit is contained in:
Damian Stewart 2023-10-22 19:26:35 +02:00
parent ada6037463
commit 26a1475f0c
4 changed files with 204 additions and 8 deletions

53
plugins/accumulnator.py Normal file
View File

@ -0,0 +1,53 @@
import json
import logging
import os
from plugins.plugins import BasePlugin
class Accumulnator(BasePlugin):
def __init__(self):
path = os.path.join(os.path.dirname(__file__), "accumulnator.json")
logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
with open(path, 'rt') as f:
config = json.load(f)
begin_epoch = config['begin_epoch']
begin_grad_accum = config['begin_grad_accum']
end_epoch = config['end_epoch']
end_grad_accum = config['end_grad_accum']
accums_per_epoch = {}
for i in range(begin_epoch):
accums_per_epoch[i] = begin_grad_accum
grad_accum_step = (end_grad_accum-begin_grad_accum)/(end_epoch-begin_epoch)
for i in range(end_grad_accum-begin_grad_accum):
grad_accum = round(grad_accum_step * i)
accums_per_epoch[i+begin_epoch] = grad_accum
self.per_epoch_grad_accum = accums_per_epoch
def on_epoch_end(self, **kwargs):
just_finished_epoch = kwargs['epoch']
epoch = just_finished_epoch + 1
grad_accum = self.per_epoch_grad_accum.get(epoch)
if grad_accum is None:
logging.warning(f" * Acculmunator has no grad_accum setting for epoch {epoch} - leaving as-is")
else:
logging.info(f" * Acculmunator setting grad_accum for epoch {epoch} to {grad_accum}")
arg_update_callback = kwargs['arg_update_callback']
arg_update_callback('grad_accum', grad_accum)
def _get_update_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
if self.every_n_epochs >= 1:
if ((epoch+1) % self.every_n_epochs) == 0:
# last step only
return [epoch_length_steps-1]
else:
return []
else:
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
# validation happens after training:
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
validate_every_n_steps = epoch_length_steps / num_divisions
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]

View File

@ -0,0 +1,134 @@
import json
import logging
import os.path
import torch
from colorama import Fore
from plugins.plugins import BasePlugin
from train import EveryDreamTrainingState
from utils.sample_generator import clean_filename
"""
This plugin adds custom tokens to the tokenizer and trains just these tokens, with the rest of the text encoder
disabled/frozen.
token/initialization config is in textual_inversion.json, same folder as this .py file.
For pure Textual Inversion training:
"disable_textenc_training": false,
"disable_unet_training": true
(Or you could unet training on too if you want, I didn't test this.)
In optimizer.json, the following "text_encoder_freezing" section is *required*:
"text_encoder_freezing": {
"unfreeze_last_n_layers": 0,
"freeze_embeddings": false,
"freeze_final_layer_norm": true
}
In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method.
"""
class TextualInversionPlugin(BasePlugin):
def __init__(self):
path = os.path.join(os.path.dirname(__file__), "textual_inversion.json")
logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
with open(path, 'rt') as f:
self.config = json.load(f)
self.this_batch_tokens = None
self.training_tokens = None
self.training_token_ids = None
self.original_text_embeddings = None
def on_model_load(self, **kwargs):
ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
optimizer_config: dict = kwargs.get('optimizer_config')
def get_token_ids(t: str):
return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t))
# check for correctly configured text encoder training
num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
if (optimizer_config is None or
'text_encoder_freezing' not in optimizer_config or
optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or
optimizer_config['text_encoder_freezing'].get('freeze_final_layer_norm') != True or
optimizer_config['text_encoder_freezing'].get('unfreeze_last_n_layers', num_te_layers) > 0
):
required_js_fragment = {"text_encoder_freezing": {"freeze_embeddings": False, "unfreeze_last_n_layers": 0, "freeze_final_layer_norm": True}}
logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES the following json fragment in your optimizer config:{Fore.RESET}")
logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}")
raise RuntimeError("Misconfigured optimizer config")
tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token']))>1]
logging.info(
f" * Textual inversion training adding the following tokens: {tokens_to_add}")
tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add]
if any(tokens_to_overwrite):
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")
num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add)
if num_added_tokens != len(tokens_to_add):
raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer))
added_token_ids = []
input_embeddings = ed_state.text_encoder.get_input_embeddings()
for token_info in self.config['tokens']:
# get newly added token id
t = token_info['token']
token_ids = get_token_ids(t)
if len(token_ids) != 1:
raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}")
token_id = token_ids[0]
added_token_ids.append(token_id)
# copy initializer embedding
initializer_word = token_info['initializer_word']
initializer_word_token_ids = get_token_ids(initializer_word)
if len(initializer_word_token_ids) != 1:
raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs "
f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.")
initializer_word_token_id = initializer_word_token_ids[0]
initializer_embedding = input_embeddings.weight.data[initializer_word_token_id]
input_embeddings.weight.data[token_id] = initializer_embedding
overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite]
self.training_tokens = tokens_to_add + tokens_to_overwrite
self.training_token_ids = added_token_ids + overwriting_token_ids
self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone()
def on_step_start(self, **kwargs):
batch = kwargs['batch']
tokens = batch['tokens'] # a torch.stack
self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist()
def on_step_end(self, **kwargs):
ed_state: EveryDreamTrainingState = kwargs['ed_state']
# reset the embeddings that have been touched this step, except the ones we're training, to their original state
with (torch.no_grad()):
embeddings = ed_state.text_encoder.get_input_embeddings()
embeddings_to_restore = [t for t in self.this_batch_tokens if t not in self.training_token_ids]
for t in embeddings_to_restore:
embeddings.weight[t] = self.original_text_embeddings[t]
def on_model_save(self, **kwargs):
ed_state: EveryDreamTrainingState = kwargs['ed_state']
embeddings = ed_state.text_encoder.get_input_embeddings()
save_folder = kwargs['save_folder']
for token_id, token in zip(self.training_token_ids, self.training_tokens):
_save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder)
def _save_embedding(token, embedding, save_folder):
dict_to_save = {token: embedding}
token_name_safe = clean_filename(token)
ti_folder = os.path.join(save_folder, 'textual_inversions')
os.makedirs(ti_folder, exist_ok=True)
save_path = os.path.join(ti_folder, token_name_safe + '.bin')
logging.info(f"Saving textual inversion for '{token}' to {save_path}")
torch.save(dict_to_save, save_path)

View File

@ -1096,13 +1096,21 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
def update_arg(arg: str, newValue):
if arg == "grad_accum":
args.grad_accum = newValue
data_loader.grad_accum = newValue
else:
raise("Unrecognized arg: " + arg)
plugin_runner.run_on_epoch_start(
epoch=epoch,
global_step=global_step,
epoch_length=epoch_len,
project_name=args.project_name,
log_folder=log_folder,
data_root=args.data_root
data_root=args.data_root,
arg_update_callback=update_arg
)
@ -1229,6 +1237,13 @@ def main(args):
epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
plugin_runner.run_on_epoch_end(epoch=epoch,
global_step=global_step,
project_name=args.project_name,
log_folder=log_folder,
data_root=args.data_root,
update_arg_callback=update_arg)
epoch_pbar.update(1)
if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
@ -1238,12 +1253,6 @@ def main(args):
loss_epoch = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
plugin_runner.run_on_epoch_end(epoch=epoch,
global_step=global_step,
project_name=args.project_name,
log_folder=log_folder,
data_root=args.data_root)
gc.collect()
# end of epoch