From 29396ec21b19be53d26c943c7ff012a128916377 Mon Sep 17 00:00:00 2001
From: damian
Date: Tue, 7 Feb 2023 17:32:54 +0100
Subject: [PATCH] update EveryDreamValidator for noprompt's changes

---
 data/data_loader.py            |  18 ----
 data/every_dream.py            |   8 +-
 data/every_dream_validation.py | 164 +++++++++++++++------------------
 train.py                       | 150 ++++++++++++++----------------
 4 files changed, 148 insertions(+), 192 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index 9977104..5fe4ba6 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -169,24 +169,6 @@ class DataLoaderMultiAspect():
 
         return picked_images
 
-    def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
-        item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
-        # sort first, then shuffle, to ensure determinate outcome for the current random state
-        items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
-        random.shuffle(items_copy)
-        split_items = items_copy[:item_count]
-        if remove_from_dataset:
-            self.delete_items(split_items)
-        return split_items
-
-    def delete_items(self, items: list[ImageTrainItem]):
-        for item in items:
-            for i, other_item in enumerate(self.prepared_train_data):
-                if other_item.pathname == item.pathname:
-                    self.prepared_train_data.pop(i)
-                    break
-        self.__update_rating_sums()
-
     def __update_rating_sums(self):
         self.rating_overall_sum: float = 0.0
         self.ratings_summed: list[float] = []
diff --git a/data/every_dream.py b/data/every_dream.py
index 00aeffb..2d93a63 100644
--- a/data/every_dream.py
+++ b/data/every_dream.py
@@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset):
     def __update_image_train_items(self, dropout_fraction: float):
         self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
-def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
+def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
     dataloader = torch.utils.data.DataLoader(
-        items,
+        dataset,
         batch_size=batch_size,
         shuffle=False,
-        num_workers=0,
+        num_workers=4,
         collate_fn=collate_fn
     )
     return dataloader
@@ -173,4 +173,4 @@ def collate_fn(batch):
         "runt_size": runt_size,
     }
     del batch
-    return ret
\ No newline at end of file
+    return ret
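Note: `build_torch_dataloader` above is now the single dataloader factory shared by training and validation; the parameter rename (`items` to `dataset`) reflects that it receives a `Dataset`, and `num_workers` goes from 0 to 4. A quick smoke-test sketch (not part of the patch; the plain list of dicts is only a stand-in for `EveryDreamBatch`, and the repo root is assumed to be on `PYTHONPATH`). Since `num_workers > 0` starts worker processes, the loop belongs under an `if __name__ == "__main__":` guard on spawn-based platforms such as Windows:

    import torch

    from data.every_dream import build_torch_dataloader

    # stand-in dataset: any sequence of dicts with the keys collate_fn reads
    fake_items = [
        {"image": torch.zeros(3, 64, 64),
         "caption": "a photo",
         "tokens": torch.zeros(77, dtype=torch.long),
         "runt_size": 0}
        for _ in range(8)
    ]

    if __name__ == "__main__":  # required once num_workers > 0 on spawn platforms
        dataloader = build_torch_dataloader(fake_items, batch_size=4)
        for batch in dataloader:
            # collate_fn stacks images and tokens into batch tensors
            print(batch["image"].shape, batch["tokens"].shape)
            # -> torch.Size([4, 3, 64, 64]) torch.Size([4, 77])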
diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py
index 5d4c7fe..ddcb52c 100644
--- a/data/every_dream_validation.py
+++ b/data/every_dream_validation.py
@@ -1,4 +1,5 @@
 import json
+import math
 import random
 from typing import Callable, Any, Optional
 from argparse import Namespace
@@ -14,57 +15,67 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch
 from data.data_loader import DataLoaderMultiAspect
 from data import resolver
 from data import aspects
+from data.image_train_item import ImageTrainItem
 from utils.isolate_rng import isolate_rng
 
 
+def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
+        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
+    split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
+    # sort first, then shuffle, to ensure determinate outcome for the current random state
+    items_copy = list(sorted(items, key=lambda i: i.pathname))
+    random.shuffle(items_copy)
+    split_items = list(items_copy[:split_item_count])
+    remaining_items = list(items_copy[split_item_count:])
+    return split_items, remaining_items
+
+
 class EveryDreamValidator:
     def __init__(self,
                  val_config_path: Optional[str],
-                 train_batch: EveryDreamBatch,
+                 default_batch_size: int,
                  log_writer: SummaryWriter):
+        self.val_dataloader = None
+        self.train_overlapping_dataloader = None
+
         self.log_writer = log_writer
 
-        val_config = {}
+        self.config = {}
         if val_config_path is not None:
             with open(val_config_path, 'rt') as f:
-                val_config = json.load(f)
+                self.config = json.load(f)
 
-        do_validation = val_config.get('validate_training', False)
-        val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
-        self.val_data_root = val_config.get('val_data_root', None)
-        val_split_proportion = val_config.get('val_split_proportion', 0.15)
-
-        stabilize_training_loss = val_config.get('stabilize_training_loss', False)
-        stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
-
-        self.every_n_epochs = val_config.get('every_n_epochs', 1)
-        self.seed = val_config.get('seed', 555)
+        self.batch_size = self.config.get('batch_size', default_batch_size)
+        self.every_n_epochs = self.config.get('every_n_epochs', 1)
+        self.seed = self.config.get('seed', 555)
+        self.val_data_root = self.config.get('val_data_root', None)
 
+    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
+        """
+        Build the validation splits as requested by the config passed at init.
+        This may steal some items from `train_items`.
+        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
+        Otherwise, the returned `list` is identical to the passed-in `train_items`.
+        """
         with isolate_rng():
-            self.val_dataloader = self._build_validation_dataloader(val_split_mode,
-                                                                    split_proportion=val_split_proportion,
-                                                                    val_data_root=self.val_data_root,
-                                                                    train_batch=train_batch)
+            self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer)
             # order is important - if we're removing images from train, this needs to happen before making
             # the overlapping dataloader
-            self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
-                                                                                            split_proportion=stabilize_split_proportion,
-                                                                                            name='train-stabilizer',
-                                                                                            enforce_split=False) if stabilize_training_loss else None
-
+            self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer)
+        return remaining_train_items
 
     def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                      get_model_prediction_and_target_callable: Callable[
                                          [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         if (epoch % self.every_n_epochs) == 0:
             if self.train_overlapping_dataloader is not None:
-                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
+                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
+                                    get_model_prediction_and_target_callable)
             if self.val_dataloader is not None:
                 self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
 
-
     def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
-        [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+            [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         with torch.no_grad(), isolate_rng():
             loss_validation_epoch = []
             steps_pbar = tqdm(range(len(dataloader)), position=1)
@@ -75,8 +86,6 @@ class EveryDreamValidator:
                 torch.manual_seed(self.seed + step)
                 model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
-                # del timesteps, encoder_hidden_states, noisy_latents
-                # with autocast(enabled=args.amp):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 del target, model_pred
@@ -91,80 +100,55 @@ class EveryDreamValidator:
             loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
             self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
 
-
-    def _build_validation_dataloader(self,
-                                     validation_split_mode: str,
-                                     split_proportion: float,
-                                     val_data_root: Optional[str],
-                                     train_batch: EveryDreamBatch) -> Optional[DataLoader]:
-        if validation_split_mode == 'none':
-            return None
-        elif validation_split_mode == 'automatic':
-            return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
-        elif validation_split_mode == 'manual':
-            if val_data_root is None:
-                raise ValueError("val_data_root is required for 'manual' validation split mode")
-            return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
+    def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\
+            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
+        val_split_mode = self.config.get('val_split_mode', 'automatic')
+        val_split_proportion = self.config.get('val_split_proportion', 0.15)
+        remaining_train_items = image_train_items
+        if val_split_mode == 'none':
+            return None, image_train_items
+        elif val_split_mode == 'automatic':
+            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
+        elif val_split_mode == 'manual':
+            args = Namespace(
+                aspects=aspects.get_aspect_buckets(512),
+                flip_p=0.0,
+                seed=self.seed,
+            )
+            val_items = resolver.resolve_root(self.val_data_root, args)
         else:
-            raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
+            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
+        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
+        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
+        return val_dataloader, remaining_train_items
 
+    def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
+            -> Optional[torch.utils.data.DataLoader]:
+        stabilize_training_loss = self.config.get('stabilize_training_loss', False)
+        if not stabilize_training_loss:
+            return None
 
-    def _build_dataloader_from_automatic_split(self,
-                                               train_batch: EveryDreamBatch,
-                                               split_proportion: float,
-                                               name: str,
-                                               enforce_split: bool=False) -> DataLoader:
-        """
-        Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
-        the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
-        """
-        with isolate_rng():
-            random.seed(self.seed)
-            val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
-            if enforce_split:
-                print(
-                    f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
-            if len(train_batch) == 0:
-                raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
-            val_batch = self._make_val_batch_with_train_batch_settings(
-                val_items,
-                train_batch,
-                name=name
-            )
-            return build_torch_dataloader(
-                items=val_batch,
-                batch_size=train_batch.batch_size,
-            )
+        stabilize_split_proportion = self.config.get('stabilize_split_proportion', 0.15)
+        stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
+        stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer,
+                                                  name='stabilize-train')
+        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
+        return stabilize_dataloader
 
-
-    def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
-        val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
-        return build_torch_dataloader(
-            items=val_batch,
-            batch_size=reference_train_batch.batch_size
-        )
-
-    def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
-        batch_size = reference_train_batch.batch_size
-        seed = reference_train_batch.seed
-        args = Namespace(
-            aspects=aspects.get_aspect_buckets(512),
-            flip_p=0.0,
-            seed=seed,
-        )
-        image_train_items = resolver.resolve(data_root, args)
+    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
+        batch_size = self.batch_size
+        seed = self.seed
         data_loader = DataLoaderMultiAspect(
-            image_train_items,
+            items,
             batch_size=batch_size,
             seed=seed,
         )
-        val_batch = EveryDreamBatch(
+        ed_batch = EveryDreamBatch(
             data_loader=data_loader,
             debug_level=1,
-            batch_size=batch_size,
             conditional_dropout=0,
-            tokenizer=reference_train_batch.tokenizer,
+            tokenizer=tokenizer,
             seed=seed,
             name=name,
         )
-        return val_batch
\ No newline at end of file
+        return ed_batch
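Note: the module-level `get_random_split` added above replaces the `DataLoaderMultiAspect.get_random_split`/`delete_items` pair deleted in data/data_loader.py: instead of mutating the loader, it partitions a plain item list and hands back both halves. Two properties are worth making explicit, illustrated by the sketch below (not part of the patch; `SimpleNamespace` is a hypothetical stand-in for `ImageTrainItem`, of which only `.pathname` is used here): the sort-before-shuffle makes the split a function of the RNG state alone, independent of input order, and since `//` runs before `math.ceil`, the split size rounds down to a whole number of batches:

    import random
    from types import SimpleNamespace

    from data.every_dream_validation import get_random_split

    items = [SimpleNamespace(pathname=f"img_{i:03}.jpg") for i in range(100)]

    random.seed(555)
    val_a, train_a = get_random_split(items, 0.15, batch_size=4)

    random.seed(555)  # same RNG state, different input order
    val_b, train_b = get_random_split(list(reversed(items)), 0.15, batch_size=4)

    # sort-then-shuffle: identical split regardless of input order
    assert [i.pathname for i in val_a] == [i.pathname for i in val_b]
    # 0.15 * 100 = 15 items requested, floored to 3 full batches of 4
    print(len(val_a), len(train_a))  # -> 12 88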
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.") + device = 'cpu' log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") @@ -606,6 +587,11 @@ def main(args): logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters()) + log_writer = SummaryWriter(log_dir=log_folder, + flush_secs=5, + comment="EveryDream2FineTunes", + ) + betas = (0.9, 0.999) epsilon = 1e-8 if args.amp: @@ -630,9 +616,14 @@ def main(args): ) log_optimizer(optimizer, betas, epsilon) - + + image_train_items = resolve_image_train_items(args, log_folder) + validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size) + # the validation dataset may need to steal some items from image_train_items + image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) + data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, seed=seed, @@ -668,12 +659,7 @@ def main(args): if args.wandb is not None and args.wandb: wandb.init(project=args.project_name, sync_tensorboard=True, ) - - log_writer = SummaryWriter( - log_dir=log_folder, - flush_secs=5, - comment="EveryDream2FineTunes", - ) + def log_args(log_writer, args): arglog = "args:\n" @@ -729,15 +715,7 @@ def main(args): logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") - - train_dataloader = torch.utils.data.DataLoader( - train_batch, - batch_size=args.batch_size, - shuffle=False, - num_workers=4, - collate_fn=collate_fn, - pin_memory=True - ) + train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size) unet.train() if not args.disable_unet_training else unet.eval() text_encoder.train() if not args.disable_textenc_training else text_encoder.eval() @@ -775,7 +753,49 @@ def main(args): loss_log_step = [] assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" - + + # actual prediction function - shared between train and validate + def get_model_prediction_and_target(image, tokens): + with torch.no_grad(): + with autocast(enabled=args.amp): + pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device) + latents = vae.encode(pixel_values, return_dict=False) + del pixel_values + latents = latents[0].sample() * 0.18215 + + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + cuda_caption = tokens.to(text_encoder.device) + + # with autocast(enabled=args.amp): + encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True) + + if args.clip_skip > 0: + encoder_hidden_states = text_encoder.text_model.final_layer_norm( + encoder_hidden_states.hidden_states[-args.clip_skip]) + else: + encoder_hidden_states = encoder_hidden_states.last_hidden_state + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]: + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + del noise, latents, cuda_caption + + with autocast(enabled=args.amp): + 
@@ -668,12 +659,7 @@ def main(args):
 
     if args.wandb is not None and args.wandb:
         wandb.init(project=args.project_name, sync_tensorboard=True, )
-
-    log_writer = SummaryWriter(
-        log_dir=log_folder,
-        flush_secs=5,
-        comment="EveryDream2FineTunes",
-    )
+
     def log_args(log_writer, args):
         arglog = "args:\n"
@@ -729,15 +715,7 @@ def main(args):
     logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
     logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
-
-    train_dataloader = torch.utils.data.DataLoader(
-        train_batch,
-        batch_size=args.batch_size,
-        shuffle=False,
-        num_workers=4,
-        collate_fn=collate_fn,
-        pin_memory=True
-    )
+    train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
 
     unet.train() if not args.disable_unet_training else unet.eval()
     text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
@@ -775,7 +753,49 @@ def main(args):
     loss_log_step = []
 
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
-
+
+    # actual prediction function - shared between train and validate
+    def get_model_prediction_and_target(image, tokens):
+        with torch.no_grad():
+            with autocast(enabled=args.amp):
+                pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
+                latents = vae.encode(pixel_values, return_dict=False)
+            del pixel_values
+            latents = latents[0].sample() * 0.18215
+
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+
+            cuda_caption = tokens.to(text_encoder.device)
+
+            # with autocast(enabled=args.amp):
+            encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
+
+            if args.clip_skip > 0:
+                encoder_hidden_states = text_encoder.text_model.final_layer_norm(
+                    encoder_hidden_states.hidden_states[-args.clip_skip])
+            else:
+                encoder_hidden_states = encoder_hidden_states.last_hidden_state
+
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+            if noise_scheduler.config.prediction_type == "epsilon":
+                target = noise
+            elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
+                target = noise_scheduler.get_velocity(latents, noise, timesteps)
+            else:
+                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+            del noise, latents, cuda_caption
+
+        with autocast(enabled=args.amp):
+            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        return model_pred, target
+
     try:
         # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
         # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
@@ -809,41 +829,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
 
-            with torch.no_grad():
-                with autocast(enabled=args.amp):
-                    pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
-                    latents = vae.encode(pixel_values, return_dict=False)
-                del pixel_values
-                latents = latents[0].sample() * 0.18215
-
-                noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
-
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
-
-                cuda_caption = batch["tokens"].to(text_encoder.device)
-
-                #with autocast(enabled=args.amp):
-                encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
-
-                if args.clip_skip > 0:
-                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
-                else:
-                    encoder_hidden_states = encoder_hidden_states.last_hidden_state
-
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
-                elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                else:
-                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-                del noise, latents, cuda_caption
-
-            with autocast(enabled=args.amp):
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
             #del timesteps, encoder_hidden_states, noisy_latents
             #with autocast(enabled=args.amp):
@@ -952,6 +938,10 @@ def main(args):
 
             loss_local = sum(loss_epoch) / len(loss_epoch)
             log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
+
+            # validate
+            validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+
             gc.collect()
             # end of epoch
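A closing note on reproducibility: both the split construction and `_do_validation` run inside `isolate_rng` from utils/isolate_rng.py, which this patch doesn't touch. Assuming it behaves like the pytorch_lightning helper of the same name, it snapshots the global RNG states on entry and restores them on exit, so the `torch.manual_seed(self.seed + step)` calls during validation cannot disturb the training run's randomness. A minimal sketch of that idea (hypothetical; the repo's actual implementation isn't shown here):

    import random
    from contextlib import contextmanager

    import torch

    @contextmanager
    def isolate_rng():
        # snapshot global RNG states so seeding/shuffling inside the
        # block leaves the surrounding random streams untouched
        py_state = random.getstate()
        torch_state = torch.get_rng_state()
        cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        try:
            yield
        finally:
            random.setstate(py_state)
            torch.set_rng_state(torch_state)
            if cuda_states is not None:
                torch.cuda.set_rng_state_all(cuda_states)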