diff --git a/data/every_dream.py b/data/every_dream.py
index 133b15a..07463d9 100644
--- a/data/every_dream.py
+++ b/data/every_dream.py
@@ -58,7 +58,7 @@ class EveryDreamBatch(Dataset):
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target
         # First epoch always trains on all images
-        self.image_train_items = [] 
+        self.image_train_items = []
         self.__update_image_train_items(1.0)
         self.name = name
 
diff --git a/plugins/accumulnator.py b/plugins/accumulnator.py
new file mode 100644
index 0000000..35b7229
--- /dev/null
+++ b/plugins/accumulnator.py
@@ -0,0 +1,53 @@
+import json
+import logging
+import math
+import os
+
+from plugins.plugins import BasePlugin
+
+class Accumulnator(BasePlugin):
+
+    def __init__(self):
+        path = os.path.join(os.path.dirname(__file__), "accumulnator.json")
+        logging.info(f" * Accumulnator plugin instantiated, loading config from {path}")
+        with open(path, 'rt') as f:
+            config = json.load(f)
+        begin_epoch = config['begin_epoch']
+        begin_grad_accum = config['begin_grad_accum']
+        end_epoch = config['end_epoch']
+        end_grad_accum = config['end_grad_accum']
+        accums_per_epoch = {}
+        for i in range(begin_epoch):
+            accums_per_epoch[i] = begin_grad_accum
+        # ramp grad_accum linearly from begin_grad_accum at begin_epoch to end_grad_accum at end_epoch
+        grad_accum_step = (end_grad_accum-begin_grad_accum)/(end_epoch-begin_epoch)
+        for i in range(end_epoch-begin_epoch+1):
+            grad_accum = round(begin_grad_accum + grad_accum_step * i)
+            accums_per_epoch[begin_epoch+i] = grad_accum
+        self.per_epoch_grad_accum = accums_per_epoch
+
+
+    def on_epoch_end(self, **kwargs):
+        just_finished_epoch = kwargs['epoch']
+        epoch = just_finished_epoch + 1
+        grad_accum = self.per_epoch_grad_accum.get(epoch)
+        if grad_accum is None:
+            logging.warning(f" * Accumulnator has no grad_accum setting for epoch {epoch} - leaving as-is")
+        else:
+            logging.info(f" * Accumulnator setting grad_accum for epoch {epoch} to {grad_accum}")
+            arg_update_callback = kwargs['arg_update_callback']
+            arg_update_callback('grad_accum', grad_accum)
+
+
+    def _get_update_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
+        if self.every_n_epochs >= 1:
+            if ((epoch+1) % self.every_n_epochs) == 0:
+                # last step only
+                return [epoch_length_steps-1]
+            else:
+                return []
+        else:
+            # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
+            num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
+            # validation happens after training:
+            # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
+            validate_every_n_steps = epoch_length_steps / num_divisions
+            return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
diff --git a/plugins/textual_inversion.py b/plugins/textual_inversion.py
new file mode 100644
index 0000000..147750f
--- /dev/null
+++ b/plugins/textual_inversion.py
@@ -0,0 +1,134 @@
+import json
+import logging
+import os.path
+
+import torch
+from colorama import Fore
+
+from plugins.plugins import BasePlugin
+from train import EveryDreamTrainingState
+from utils.sample_generator import clean_filename
+
+"""
+This plugin adds custom tokens to the tokenizer and trains just these tokens, with the rest of the text encoder
+disabled/frozen.
+
+token/initialization config is in textual_inversion.json, same folder as this .py file.
+
+For pure Textual Inversion training:
+    "disable_textenc_training": false,
+    "disable_unet_training": true
+(Or you could turn unet training on too if you want; I didn't test this.)
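+
+For example (token names here are illustrative only, not from the original config), textual_inversion.json holds a
+"tokens" list where each entry gives a "token" to train and a single-token "initializer_word" whose embedding is
+copied as the new token's starting value:
+    {
+        "tokens": [
+            { "token": "hat*", "initializer_word": "hat" }
+        ]
+    }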
+
+
+In optimizer.json, the following "text_encoder_freezing" section is *required*:
+    "text_encoder_freezing": {
+        "unfreeze_last_n_layers": 0,
+        "freeze_embeddings": false,
+        "freeze_final_layer_norm": true
+    }
+In addition, you'll need a very high LR on the TE - maybe even as high as 1e-3. I recommend using the LR finder method.
+
+"""
+
+class TextualInversionPlugin(BasePlugin):
+
+    def __init__(self):
+        path = os.path.join(os.path.dirname(__file__), "textual_inversion.json")
+        logging.info(f" * Textual Inversion plugin instantiated, loading config from {path}")
+        with open(path, 'rt') as f:
+            self.config = json.load(f)
+        self.this_batch_tokens = None
+        self.training_tokens = None
+        self.training_token_ids = None
+        self.original_text_embeddings = None
+
+    def on_model_load(self, **kwargs):
+        ed_state: EveryDreamTrainingState = kwargs.get('ed_state')
+        optimizer_config: dict = kwargs.get('optimizer_config')
+        def get_token_ids(t: str):
+            return ed_state.tokenizer.convert_tokens_to_ids(ed_state.tokenizer.tokenize(t))
+
+        # check for correctly configured text encoder training
+        num_te_layers = len(ed_state.text_encoder.text_model.encoder.layers)
+        if (optimizer_config is None or
+                'text_encoder_freezing' not in optimizer_config or
+                optimizer_config['text_encoder_freezing'].get('freeze_embeddings') != False or
+                optimizer_config['text_encoder_freezing'].get('freeze_final_layer_norm') != True or
+                optimizer_config['text_encoder_freezing'].get('unfreeze_last_n_layers', num_te_layers) > 0
+        ):
+            required_js_fragment = {"text_encoder_freezing": {"freeze_embeddings": False, "unfreeze_last_n_layers": 0, "freeze_final_layer_norm": True}}
+            logging.error(f" * {Fore.LIGHTRED_EX}Textual Inversion plugin REQUIRES the following json fragment in your optimizer config:{Fore.RESET}")
+            logging.error(f" * {Fore.LIGHTRED_EX}  {json.dumps(required_js_fragment)}{Fore.RESET}")
+            raise RuntimeError("Misconfigured optimizer config")
+
+        tokens_to_add = [t['token'] for t in self.config['tokens'] if len(get_token_ids(t['token']))>1]
+        logging.info(
+            f" * Textual inversion training adding the following tokens: {tokens_to_add}")
+        tokens_to_overwrite = [t['token'] for t in self.config['tokens'] if t['token'] not in tokens_to_add]
+        if any(tokens_to_overwrite):
+            logging.warning(f" * {Fore.LIGHTYELLOW_EX}Textual inversion training overwriting the following tokens: {tokens_to_overwrite}{Fore.RESET}")
+
+        num_added_tokens = ed_state.tokenizer.add_tokens(tokens_to_add)
+        if num_added_tokens != len(tokens_to_add):
+            raise RuntimeError(f"Tokens not added successfully - tried to add {len(tokens_to_add)} but only added {num_added_tokens}")
+        ed_state.text_encoder.resize_token_embeddings(len(ed_state.tokenizer))
+
+        added_token_ids = []
+        input_embeddings = ed_state.text_encoder.get_input_embeddings()
+        for token_info in self.config['tokens']:
+            # get newly added token id
+            t = token_info['token']
+            token_ids = get_token_ids(t)
+            if len(token_ids) != 1:
+                raise RuntimeError(f"Tokens not added successfully - expected 1 token id for {t}, found {len(token_ids)}")
+            token_id = token_ids[0]
+            added_token_ids.append(token_id)
+
+            # copy initializer embedding
+            initializer_word = token_info['initializer_word']
+            initializer_word_token_ids = get_token_ids(initializer_word)
+            if len(initializer_word_token_ids) != 1:
+                raise RuntimeError(f"Tokens not added successfully - initializer word '{initializer_word}' needs "
+                                   f"{len(initializer_word_token_ids)} tokens, but only single tokens are supported.")
+            initializer_word_token_id = initializer_word_token_ids[0]
+            initializer_embedding = input_embeddings.weight.data[initializer_word_token_id]
+            input_embeddings.weight.data[token_id] = initializer_embedding
+
+        overwriting_token_ids = [get_token_ids(t)[0] for t in tokens_to_overwrite]
+        self.training_tokens = tokens_to_add + tokens_to_overwrite
+        self.training_token_ids = added_token_ids + overwriting_token_ids
+        self.original_text_embeddings = ed_state.text_encoder.get_input_embeddings().weight.data.detach().clone()
+
+
+    def on_step_start(self, **kwargs):
+        batch = kwargs['batch']
+        tokens = batch['tokens']  # a torch.stack
+        self.this_batch_tokens = torch.unique(torch.flatten(tokens)).tolist()
+
+    def on_step_end(self, **kwargs):
+        ed_state: EveryDreamTrainingState = kwargs['ed_state']
+
+        # reset the embeddings that have been touched this step, except the ones we're training, to their original state
+        with torch.no_grad():
+            embeddings = ed_state.text_encoder.get_input_embeddings()
+            embeddings_to_restore = [t for t in self.this_batch_tokens if t not in self.training_token_ids]
+            for t in embeddings_to_restore:
+                embeddings.weight[t] = self.original_text_embeddings[t]
+
+    def on_model_save(self, **kwargs):
+        ed_state: EveryDreamTrainingState = kwargs['ed_state']
+        embeddings = ed_state.text_encoder.get_input_embeddings()
+        save_folder = kwargs['save_folder']
+        for token_id, token in zip(self.training_token_ids, self.training_tokens):
+            _save_embedding(token=token, embedding=embeddings.weight[token_id], save_folder=save_folder)
+
+def _save_embedding(token, embedding, save_folder):
+    dict_to_save = {token: embedding}
+    token_name_safe = clean_filename(token)
+    ti_folder = os.path.join(save_folder, 'textual_inversions')
+    os.makedirs(ti_folder, exist_ok=True)
+    save_path = os.path.join(ti_folder, token_name_safe + '.bin')
+    logging.info(f"Saving textual inversion for '{token}' to {save_path}")
+    torch.save(dict_to_save, save_path)
+
diff --git a/train.py b/train.py
index 58a2e3f..af7a891 100644
--- a/train.py
+++ b/train.py
@@ -1096,13 +1096,21 @@ def main(args):
 
         epoch_len = math.ceil(len(train_batch) / args.batch_size)
 
+        def update_arg(arg: str, newValue):
+            if arg == "grad_accum":
+                args.grad_accum = newValue
+                data_loader.grad_accum = newValue
+            else:
+                raise ValueError("Unrecognized arg: " + arg)
+
         plugin_runner.run_on_epoch_start(
             epoch=epoch,
             global_step=global_step,
             epoch_length=epoch_len,
             project_name=args.project_name,
             log_folder=log_folder,
-            data_root=args.data_root
+            data_root=args.data_root,
+            arg_update_callback=update_arg
         )
 
 
@@ -1229,6 +1237,13 @@ def main(args):
         epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
         log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
 
+        plugin_runner.run_on_epoch_end(epoch=epoch,
+                                       global_step=global_step,
+                                       project_name=args.project_name,
+                                       log_folder=log_folder,
+                                       data_root=args.data_root,
+                                       arg_update_callback=update_arg)
+
         epoch_pbar.update(1)
         if epoch < args.max_epochs - 1:
             train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
@@ -1238,12 +1253,6 @@ def main(args):
         loss_epoch = sum(loss_epoch) / len(loss_epoch)
         log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
 
-        plugin_runner.run_on_epoch_end(epoch=epoch,
-                                       global_step=global_step,
-                                       project_name=args.project_name,
-                                       log_folder=log_folder,
-                                       data_root=args.data_root)
-
         gc.collect()
         # end of epoch
 
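
Note: the Accumulnator plugin reads "begin_epoch", "begin_grad_accum", "end_epoch" and "end_grad_accum" from an accumulnator.json file in the same folder as the plugin. As an illustrative sketch only (values invented for this example, not taken from the PR), a config that ramps grad_accum from 1 at epoch 10 up to 8 at epoch 30 would look like:

    {
        "begin_epoch": 10,
        "begin_grad_accum": 1,
        "end_epoch": 30,
        "end_grad_accum": 8
    }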