initial implementation of the_acculmunator
This commit is contained in:
parent
ada6037463
commit
26a1475f0c
|
@ -0,0 +1,53 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from plugins.plugins import BasePlugin
|
||||
|
||||
class Accumulnator(BasePlugin):
|
||||
|
||||
def __init__(self):
|
||||
path = os.path.join(os.path.dirname(__file__), "accumulnator.json")
|
||||
logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
|
||||
with open(path, 'rt') as f:
|
||||
config = json.load(f)
|
||||
begin_epoch = config['begin_epoch']
|
||||
begin_grad_accum = config['begin_grad_accum']
|
||||
end_epoch = config['end_epoch']
|
||||
end_grad_accum = config['end_grad_accum']
|
||||
accums_per_epoch = {}
|
||||
for i in range(begin_epoch):
|
||||
accums_per_epoch[i] = begin_grad_accum
|
||||
grad_accum_step = (end_grad_accum-begin_grad_accum)/(end_epoch-begin_epoch)
|
||||
for i in range(end_grad_accum-begin_grad_accum):
|
||||
grad_accum = round(grad_accum_step * i)
|
||||
accums_per_epoch[i+begin_epoch] = grad_accum
|
||||
self.per_epoch_grad_accum = accums_per_epoch
|
||||
|
||||
|
||||
def on_epoch_end(self, **kwargs):
|
||||
just_finished_epoch = kwargs['epoch']
|
||||
epoch = just_finished_epoch + 1
|
||||
grad_accum = self.per_epoch_grad_accum.get(epoch)
|
||||
if grad_accum is None:
|
||||
logging.warning(f" * Acculmunator has no grad_accum setting for epoch {epoch} - leaving as-is")
|
||||
else:
|
||||
logging.info(f" * Acculmunator setting grad_accum for epoch {epoch} to {grad_accum}")
|
||||
arg_update_callback = kwargs['arg_update_callback']
|
||||
arg_update_callback('grad_accum', grad_accum)
|
||||
|
||||
|
||||
def _get_update_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
||||
if self.every_n_epochs >= 1:
|
||||
if ((epoch+1) % self.every_n_epochs) == 0:
|
||||
# last step only
|
||||
return [epoch_length_steps-1]
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
||||
# validation happens after training:
|
||||
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
|
||||
validate_every_n_steps = epoch_length_steps / num_divisions
|
||||
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
|
|
@ -0,0 +1,134 @@
|
|||
import json
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
import torch
|
||||
from colorama import Fore
|
||||
|
||||
from plugins.plugins import BasePlugin
|
||||
from train import EveryDreamTrainingState
|
||||
from utils.sample_generator import clean_filename
|
||||
|
||||
"""
|
||||
This plugin adds custom tokens to the tokenizer and trains just these tokens, with the rest of the text encoder
|
||||
disabled/frozen.
|
||||
|
||||
token/initialization config is in textual_inversion.json, same folder as this .py file.
|
||||
|
||||
For pure Textual Inversion training:
|
||||
"disable_textenc_training": false,
|
||||
"disable_unet_training": true
|
||||
(Or you could unet training on too if you want, I didn't test this.)
|
||||
|
||||
|
||||
In optimizer.json, the following "text_encoder_freezing" section is *required*:
|
||||
"text_encoder_freezing": {
|
||||
"unfreeze_last_n_layers": 0,
|
||||
"freeze_embeddings": false,
|
||||
"freeze_final_layer_norm": true
|
||||
}
|
||||
In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method.
|
||||
|
||||
"""
|
||||
|
||||
class TextualInversionPlugin(BasePlugin):
|
||||
|
||||
def __init__(self):
|
||||
path = os.path.join(os.path.dirname(__file__), "textual_inversion.json")
|
||||
logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
|
||||
with open(path, 'rt') as f:
|
||||
self.config = json.load(f)
|
||||
self.this_batch_tokens = None
|
||||
self.training_tokens = None
|
||||
self.training_token_ids = None
|
||||
self.original_text_embeddings = None
|
||||
|
||||
def on_model_load(self, **kwargs):
|
||||
ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
|
||||
optimizer_config: dict = kwargs.get('optimizer_config')
|
||||
def get_token_ids(t: str):
|
||||
return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t))
|
||||
|
||||
# check for correctly configured text encoder training
|
||||
num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
|
||||
if (optimizer_config is None or
|
||||
'text_encoder_freezing' not in optimizer_config or
|
||||
optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or
|
||||
optimizer_config['text_encoder_freezing'].get('freeze_final_layer_norm') != True or
|
||||
optimizer_config['text_encoder_freezing'].get('unfreeze_last_n_layers', num_te_layers) > 0
|
||||
):
|
||||
required_js_fragment = {"text_encoder_freezing": {"freeze_embeddings": False, "unfreeze_last_n_layers": 0, "freeze_final_layer_norm": True}}
|
||||
logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES the following json fragment in your optimizer config:{Fore.RESET}")
|
||||
logging.error(f" * {Fore.LIGHTRED_EX} {json.dumps(required_js_fragment)}{Fore.RESET}")
|
||||
raise RuntimeError("Misconfigured optimizer config")
|
||||
|
||||
tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token']))>1]
|
||||
logging.info(
|
||||
f" * Textual inversion training adding the following tokens: {tokens_to_add}")
|
||||
tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add]
|
||||
if any(tokens_to_overwrite):
|
||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")
|
||||
|
||||
num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add)
|
||||
if num_added_tokens != len(tokens_to_add):
|
||||
raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
|
||||
ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer))
|
||||
|
||||
added_token_ids = []
|
||||
input_embeddings = ed_state.text_encoder.get_input_embeddings()
|
||||
for token_info in self.config['tokens']:
|
||||
# get newly added token id
|
||||
t = token_info['token']
|
||||
token_ids = get_token_ids(t)
|
||||
if len(token_ids) != 1:
|
||||
raise RuntimeError(f"Tokens not added succesfully - expected 1 token id for {t}, found {len(token_ids)}")
|
||||
token_id = token_ids[0]
|
||||
added_token_ids.append(token_id)
|
||||
|
||||
# copy initializer embedding
|
||||
initializer_word = token_info['initializer_word']
|
||||
initializer_word_token_ids = get_token_ids(initializer_word)
|
||||
if len(initializer_word_token_ids) != 1:
|
||||
raise RuntimeError(f"Tokens not added succesfully - initializer word '{initializer_word}' needs "
|
||||
f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.")
|
||||
initializer_word_token_id = initializer_word_token_ids[0]
|
||||
initializer_embedding = input_embeddings.weight.data[initializer_word_token_id]
|
||||
input_embeddings.weight.data[token_id] = initializer_embedding
|
||||
|
||||
overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite]
|
||||
self.training_tokens = tokens_to_add + tokens_to_overwrite
|
||||
self.training_token_ids = added_token_ids + overwriting_token_ids
|
||||
self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
|
||||
def on_step_start(self, **kwargs):
|
||||
batch = kwargs['batch']
|
||||
tokens = batch['tokens'] # a torch.stack
|
||||
self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist()
|
||||
|
||||
def on_step_end(self, **kwargs):
|
||||
ed_state: EveryDreamTrainingState = kwargs['ed_state']
|
||||
|
||||
# reset the embeddings that have been touched this step, except the ones we're training, to their original state
|
||||
with (torch.no_grad()):
|
||||
embeddings = ed_state.text_encoder.get_input_embeddings()
|
||||
embeddings_to_restore = [t for t in self.this_batch_tokens if t not in self.training_token_ids]
|
||||
for t in embeddings_to_restore:
|
||||
embeddings.weight[t] = self.original_text_embeddings[t]
|
||||
|
||||
def on_model_save(self, **kwargs):
|
||||
ed_state: EveryDreamTrainingState = kwargs['ed_state']
|
||||
embeddings = ed_state.text_encoder.get_input_embeddings()
|
||||
save_folder = kwargs['save_folder']
|
||||
for token_id, token in zip(self.training_token_ids, self.training_tokens):
|
||||
_save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder)
|
||||
|
||||
def _save_embedding(token, embedding, save_folder):
|
||||
dict_to_save = {token: embedding}
|
||||
token_name_safe = clean_filename(token)
|
||||
ti_folder = os.path.join(save_folder, 'textual_inversions')
|
||||
os.makedirs(ti_folder, exist_ok=True)
|
||||
save_path = os.path.join(ti_folder, token_name_safe + '.bin')
|
||||
logging.info(f"Saving textual inversion for '{token}' to {save_path}")
|
||||
torch.save(dict_to_save, save_path)
|
||||
|
23
train.py
23
train.py
|
@ -1096,13 +1096,21 @@ def main(args):
|
|||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
def update_arg(arg: str, newValue):
|
||||
if arg == "grad_accum":
|
||||
args.grad_accum = newValue
|
||||
data_loader.grad_accum = newValue
|
||||
else:
|
||||
raise("Unrecognized arg: " + arg)
|
||||
|
||||
plugin_runner.run_on_epoch_start(
|
||||
epoch=epoch,
|
||||
global_step=global_step,
|
||||
epoch_length=epoch_len,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
data_root=args.data_root
|
||||
data_root=args.data_root,
|
||||
arg_update_callback=update_arg
|
||||
)
|
||||
|
||||
|
||||
|
@ -1229,6 +1237,13 @@ def main(args):
|
|||
epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
|
||||
log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
|
||||
|
||||
plugin_runner.run_on_epoch_end(epoch=epoch,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
data_root=args.data_root,
|
||||
update_arg_callback=update_arg)
|
||||
|
||||
epoch_pbar.update(1)
|
||||
if epoch < args.max_epochs - 1:
|
||||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||
|
@ -1238,12 +1253,6 @@ def main(args):
|
|||
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
|
||||
|
||||
plugin_runner.run_on_epoch_end(epoch=epoch,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
data_root=args.data_root)
|
||||
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
||||
|
|
Loading…
Reference in New Issue