From 29396ec21b19be53d26c943c7ff012a128916377 Mon Sep 17 00:00:00 2001 From: damian Date: Tue, 7 Feb 2023 17:32:54 +0100 Subject: [PATCH 1/6] update EveryDreamValidator for noprompt's changes --- data/data_loader.py | 18 ---- data/every_dream.py | 8 +- data/every_dream_validation.py | 164 +++++++++++++++------------------ train.py | 150 ++++++++++++++---------------- 4 files changed, 148 insertions(+), 192 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 9977104..5fe4ba6 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -169,24 +169,6 @@ class DataLoaderMultiAspect(): return picked_images - def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]: - item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size - # sort first, then shuffle, to ensure determinate outcome for the current random state - items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname)) - random.shuffle(items_copy) - split_items = items_copy[:item_count] - if remove_from_dataset: - self.delete_items(split_items) - return split_items - - def delete_items(self, items: list[ImageTrainItem]): - for item in items: - for i, other_item in enumerate(self.prepared_train_data): - if other_item.pathname == item.pathname: - self.prepared_train_data.pop(i) - break - self.__update_rating_sums() - def __update_rating_sums(self): self.rating_overall_sum: float = 0.0 self.ratings_summed: list[float] = [] diff --git a/data/every_dream.py b/data/every_dream.py index 00aeffb..2d93a63 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset): def __update_image_train_items(self, dropout_fraction: float): self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction) -def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader: +def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader: dataloader = torch.utils.data.DataLoader( - items, + dataset, batch_size=batch_size, shuffle=False, - num_workers=0, + num_workers=4, collate_fn=collate_fn ) return dataloader @@ -173,4 +173,4 @@ def collate_fn(batch): "runt_size": runt_size, } del batch - return ret \ No newline at end of file + return ret diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 5d4c7fe..ddcb52c 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -1,4 +1,5 @@ import json +import math import random from typing import Callable, Any, Optional from argparse import Namespace @@ -14,57 +15,67 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch from data.data_loader import DataLoaderMultiAspect from data import resolver from data import aspects +from data.image_train_item import ImageTrainItem from utils.isolate_rng import isolate_rng +def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \ + -> tuple[list[ImageTrainItem], list[ImageTrainItem]]: + split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size + # sort first, then shuffle, to ensure determinate outcome for the current random state + items_copy = list(sorted(items, key=lambda i: i.pathname)) + random.shuffle(items_copy) + split_items = list(items_copy[:split_item_count]) + remaining_items = list(items_copy[split_item_count:]) + return split_items, remaining_items + + class EveryDreamValidator: def __init__(self, 
val_config_path: Optional[str], - train_batch: EveryDreamBatch, + default_batch_size: int, log_writer: SummaryWriter): + self.val_dataloader = None + self.train_overlapping_dataloader = None + self.log_writer = log_writer - val_config = {} + self.config = {} if val_config_path is not None: with open(val_config_path, 'rt') as f: - val_config = json.load(f) + self.config = json.load(f) - do_validation = val_config.get('validate_training', False) - val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none' - self.val_data_root = val_config.get('val_data_root', None) - val_split_proportion = val_config.get('val_split_proportion', 0.15) - - stabilize_training_loss = val_config.get('stabilize_training_loss', False) - stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15) - - self.every_n_epochs = val_config.get('every_n_epochs', 1) - self.seed = val_config.get('seed', 555) + self.batch_size = self.config.get('batch_size', default_batch_size) + self.every_n_epochs = self.config.get('every_n_epochs', 1) + self.seed = self.config.get('seed', 555) + self.val_data_root = self.config.get('val_data_root', None) + def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]: + """ + Build the validation splits as requested by the config passed at init. + This may steal some items from `train_items`. + If this happens, the returned `list` contains the remaining items after the required items have been stolen. + Otherwise, the returned `list` is identical to the passed-in `train_items`. + """ with isolate_rng(): - self.val_dataloader = self._build_validation_dataloader(val_split_mode, - split_proportion=val_split_proportion, - val_data_root=self.val_data_root, - train_batch=train_batch) + self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer) # order is important - if we're removing images from train, this needs to happen before making # the overlapping dataloader - self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch, - split_proportion=stabilize_split_proportion, - name='train-stabilizer', - enforce_split=False) if stabilize_training_loss else None - + self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer) + return remaining_train_items def do_validation_if_appropriate(self, epoch: int, global_step: int, get_model_prediction_and_target_callable: Callable[ [Any, Any], tuple[torch.Tensor, torch.Tensor]]): if (epoch % self.every_n_epochs) == 0: if self.train_overlapping_dataloader is not None: - self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable) + self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, + get_model_prediction_and_target_callable) if self.val_dataloader is not None: self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable) - def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[ - [Any, Any], tuple[torch.Tensor, torch.Tensor]]): + [Any, Any], tuple[torch.Tensor, torch.Tensor]]): with torch.no_grad(), isolate_rng(): loss_validation_epoch = [] steps_pbar = tqdm(range(len(dataloader)), position=1) @@ -75,8 +86,6 @@ class EveryDreamValidator: torch.manual_seed(self.seed + step) model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"]) - 
# del timesteps, encoder_hidden_states, noisy_latents - # with autocast(enabled=args.amp): loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") del target, model_pred @@ -91,80 +100,55 @@ class EveryDreamValidator: loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch) self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step) - - def _build_validation_dataloader(self, - validation_split_mode: str, - split_proportion: float, - val_data_root: Optional[str], - train_batch: EveryDreamBatch) -> Optional[DataLoader]: - if validation_split_mode == 'none': - return None - elif validation_split_mode == 'automatic': - return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True) - elif validation_split_mode == 'manual': - if val_data_root is None: - raise ValueError("val_data_root is required for 'manual' validation split mode") - return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch) + def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\ + -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]: + val_split_mode = self.config.get('val_split_mode', 'automatic') + val_split_proportion = self.config.get('val_split_proportion', 0.15) + remaining_train_items = image_train_items + if val_split_mode == 'none': + return None, image_train_items + elif val_split_mode == 'automatic': + val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size) + elif val_split_mode == 'manual': + args = Namespace( + aspects=aspects.get_aspect_buckets(512), + flip_p=0.0, + seed=self.seed, + ) + val_items = resolver.resolve_root(self.val_data_root, args) else: - raise ValueError(f"unhandled validation split mode '{validation_split_mode}'") + raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'") + val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val') + val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size) + return val_dataloader, remaining_train_items + def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \ + -> Optional[torch.utils.data.DataLoader]: + stabilize_training_loss = self.config.get('stabilize_training_loss', False) + if not stabilize_training_loss: + return None - def _build_dataloader_from_automatic_split(self, - train_batch: EveryDreamBatch, - split_proportion: float, - name: str, - enforce_split: bool=False) -> DataLoader: - """ - Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove - the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader. - """ - with isolate_rng(): - random.seed(self.seed) - val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split) - if enforce_split: - print( - f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left") - if len(train_batch) == 0: - raise ValueError(f"Validation split used up all of the training data. 
Try a lower split proportion than {split_proportion}") - val_batch = self._make_val_batch_with_train_batch_settings( - val_items, - train_batch, - name=name - ) - return build_torch_dataloader( - items=val_batch, - batch_size=train_batch.batch_size, - ) + stabilize_split_proportion = self.config.get('stabilize_split_proportion', 0.15) + stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) + stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer, + name='stabilize-train') + stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size) + return stabilize_dataloader - - def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader: - val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch) - return build_torch_dataloader( - items=val_batch, - batch_size=reference_train_batch.batch_size - ) - - def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'): - batch_size = reference_train_batch.batch_size - seed = reference_train_batch.seed - args = Namespace( - aspects=aspects.get_aspect_buckets(512), - flip_p=0.0, - seed=seed, - ) - image_train_items = resolver.resolve(data_root, args) + def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'): + batch_size = self.batch_size + seed = self.seed data_loader = DataLoaderMultiAspect( - image_train_items, + items, batch_size=batch_size, seed=seed, ) - val_batch = EveryDreamBatch( + ed_batch = EveryDreamBatch( data_loader=data_loader, debug_level=1, - batch_size=batch_size, conditional_dropout=0, - tokenizer=reference_train_batch.tokenizer, + tokenizer=tokenizer, seed=seed, name=name, ) - return val_batch \ No newline at end of file + return ed_batch diff --git a/train.py b/train.py index 48a1aaf..737e5b6 100644 --- a/train.py +++ b/train.py @@ -52,7 +52,8 @@ import wandb from torch.utils.tensorboard import SummaryWriter from data.data_loader import DataLoaderMultiAspect -from data.every_dream import EveryDreamBatch +from data.every_dream import EveryDreamBatch, build_torch_dataloader +from data.every_dream_validation import EveryDreamValidator from data.image_train_item import ImageTrainItem from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter @@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str): return sample_prompts - -def collate_fn(batch): - """ - Collates batches - """ - images = [example["image"] for example in batch] - captions = [example["caption"] for example in batch] - tokens = [example["tokens"] for example in batch] - runt_size = batch[0]["runt_size"] - - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - - ret = { - "tokens": torch.stack(tuple(tokens)), - "image": images, - "captions": captions, - "runt_size": runt_size, - } - del batch - return ret - - def main(args): """ Main entry point @@ -387,10 +365,13 @@ def main(args): seed = args.seed if args.seed != -1 else random.randint(0, 2**30) logging.info(f" Seed: {seed}") set_seed(seed) - gpu = GPU() - device = torch.device(f"cuda:{args.gpuid}") - - torch.backends.cudnn.benchmark = True + if torch.cuda.is_available(): + gpu = GPU() + device = torch.device(f"cuda:{args.gpuid}") + torch.backends.cudnn.benchmark = True + else: + 
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.") + device = 'cpu' log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") @@ -606,6 +587,11 @@ def main(args): logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters()) + log_writer = SummaryWriter(log_dir=log_folder, + flush_secs=5, + comment="EveryDream2FineTunes", + ) + betas = (0.9, 0.999) epsilon = 1e-8 if args.amp: @@ -630,9 +616,14 @@ def main(args): ) log_optimizer(optimizer, betas, epsilon) - + + image_train_items = resolve_image_train_items(args, log_folder) + validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size) + # the validation dataset may need to steal some items from image_train_items + image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) + data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, seed=seed, @@ -668,12 +659,7 @@ def main(args): if args.wandb is not None and args.wandb: wandb.init(project=args.project_name, sync_tensorboard=True, ) - - log_writer = SummaryWriter( - log_dir=log_folder, - flush_secs=5, - comment="EveryDream2FineTunes", - ) + def log_args(log_writer, args): arglog = "args:\n" @@ -729,15 +715,7 @@ def main(args): logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") - - train_dataloader = torch.utils.data.DataLoader( - train_batch, - batch_size=args.batch_size, - shuffle=False, - num_workers=4, - collate_fn=collate_fn, - pin_memory=True - ) + train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size) unet.train() if not args.disable_unet_training else unet.eval() text_encoder.train() if not args.disable_textenc_training else text_encoder.eval() @@ -775,7 +753,49 @@ def main(args): loss_log_step = [] assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" - + + # actual prediction function - shared between train and validate + def get_model_prediction_and_target(image, tokens): + with torch.no_grad(): + with autocast(enabled=args.amp): + pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device) + latents = vae.encode(pixel_values, return_dict=False) + del pixel_values + latents = latents[0].sample() * 0.18215 + + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + cuda_caption = tokens.to(text_encoder.device) + + # with autocast(enabled=args.amp): + encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True) + + if args.clip_skip > 0: + encoder_hidden_states = text_encoder.text_model.final_layer_norm( + encoder_hidden_states.hidden_states[-args.clip_skip]) + else: + encoder_hidden_states = encoder_hidden_states.last_hidden_state + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]: + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + del noise, latents, cuda_caption + + with autocast(enabled=args.amp): + 
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + return model_pred, target + + try: # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution])) @@ -809,41 +829,7 @@ def main(args): for step, batch in enumerate(train_dataloader): step_start_time = time.time() - with torch.no_grad(): - with autocast(enabled=args.amp): - pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device) - latents = vae.encode(pixel_values, return_dict=False) - del pixel_values - latents = latents[0].sample() * 0.18215 - - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - - cuda_caption = batch["tokens"].to(text_encoder.device) - - #with autocast(enabled=args.amp): - encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True) - - if args.clip_skip > 0: - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip]) - else: - encoder_hidden_states = encoder_hidden_states.last_hidden_state - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]: - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - del noise, latents, cuda_caption - - with autocast(enabled=args.amp): - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"]) #del timesteps, encoder_hidden_states, noisy_latents #with autocast(enabled=args.amp): @@ -952,6 +938,10 @@ def main(args): loss_local = sum(loss_epoch) / len(loss_epoch) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) + + # validate + validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) + gc.collect() # end of epoch From c3d844a1bc4f352d38604268d549567c01277819 Mon Sep 17 00:00:00 2001 From: damian Date: Tue, 7 Feb 2023 17:52:23 +0100 Subject: [PATCH 2/6] better config handling --- data/every_dream_validation.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index ddcb52c..0cb9fd7 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -40,15 +40,26 @@ class EveryDreamValidator: self.log_writer = log_writer - self.config = {} + self.config = { + 'batch_size': default_batch_size, + 'every_n_epochs': 1, + 'seed': 555 + } if val_config_path is not None: with open(val_config_path, 'rt') as f: - self.config = json.load(f) + self.config.update(json.load(f)) - self.batch_size = self.config.get('batch_size', default_batch_size) - self.every_n_epochs = self.config.get('every_n_epochs', 1) - self.seed = self.config.get('seed', 555) - self.val_data_root = self.config.get('val_data_root', None) + @property + def batch_size(self): + return self.config['batch_size'] + + @property + def every_n_epochs(self): + return self.config['every_n_epochs'] + + @property + def 
seed(self): + return self.config['seed'] def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]: """ @@ -115,7 +126,8 @@ class EveryDreamValidator: flip_p=0.0, seed=self.seed, ) - val_items = resolver.resolve_root(self.val_data_root, args) + val_data_root = self.config['val_data_root'] + val_items = resolver.resolve_root(val_data_root, args) else: raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'") val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val') From f0d7310c12604338884a16c2951e5e4f50ae0be5 Mon Sep 17 00:00:00 2001 From: damian Date: Tue, 7 Feb 2023 17:54:00 +0100 Subject: [PATCH 3/6] clarify init function names --- data/every_dream_validation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 0cb9fd7..ae4e6a2 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -69,10 +69,11 @@ class EveryDreamValidator: Otherwise, the returned `list` is identical to the passed-in `train_items`. """ with isolate_rng(): - self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer) + self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer) # order is important - if we're removing images from train, this needs to happen before making # the overlapping dataloader - self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer) + self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader_if_required( + remaining_train_items, tokenizer) return remaining_train_items def do_validation_if_appropriate(self, epoch: int, global_step: int, @@ -111,7 +112,7 @@ class EveryDreamValidator: loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch) self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step) - def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\ + def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\ -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]: val_split_mode = self.config.get('val_split_mode', 'automatic') val_split_proportion = self.config.get('val_split_proportion', 0.15) @@ -134,7 +135,7 @@ class EveryDreamValidator: val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size) return val_dataloader, remaining_train_items - def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \ + def _build_train_stabiliser_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \ -> Optional[torch.utils.data.DataLoader]: stabilize_training_loss = self.config.get('stabilize_training_loss', False) if not stabilize_training_loss: From dad9e347ffc93e130f5015a4791f94c079b1a3eb Mon Sep 17 00:00:00 2001 From: damian Date: Tue, 7 Feb 2023 18:08:19 +0100 Subject: [PATCH 4/6] log ed batch name on creation --- data/every_dream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/every_dream.py b/data/every_dream.py index 2d93a63..30f10aa 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -63,7 +63,7 @@ class EveryDreamBatch(Dataset): self.name = name num_images = len(self.image_train_items) - logging.info(f" ** Trainer Set: {num_images 
/ self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}") + logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}") def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]: items = self.data_loader.get_random_split(split_proportion, remove_from_dataset) From e2d9600e344ee328033c5757df188c7d049a7ddf Mon Sep 17 00:00:00 2001 From: damian Date: Tue, 7 Feb 2023 18:18:21 +0100 Subject: [PATCH 5/6] cleaner config handling --- data/every_dream_validation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index ae4e6a2..ec6e51e 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -43,7 +43,13 @@ class EveryDreamValidator: self.config = { 'batch_size': default_batch_size, 'every_n_epochs': 1, - 'seed': 555 + 'seed': 555, + + 'val_split_mode': 'automatic', + 'val_split_proportion': 0.15, + + 'stabilize_training_loss': False, + 'stabilize_split_proportion': 0.15 } if val_config_path is not None: with open(val_config_path, 'rt') as f: @@ -114,8 +120,8 @@ class EveryDreamValidator: def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\ -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]: - val_split_mode = self.config.get('val_split_mode', 'automatic') - val_split_proportion = self.config.get('val_split_proportion', 0.15) + val_split_mode = self.config['val_split_mode'] + val_split_proportion = self.config['val_split_proportion'] remaining_train_items = image_train_items if val_split_mode == 'none': return None, image_train_items @@ -137,11 +143,11 @@ class EveryDreamValidator: def _build_train_stabiliser_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \ -> Optional[torch.utils.data.DataLoader]: - stabilize_training_loss = self.config.get('stabilize_training_loss', False) + stabilize_training_loss = self.config['stabilize_training_loss'] if not stabilize_training_loss: return None - stabilize_split_proportion = self.config.get('stabilize_split_proportion', 0.15) + stabilize_split_proportion = self.config['stabilize_split_proportion'] stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer, name='stabilize-train') From bca1e6e59417a9070257ae0c37a86d956033d0de Mon Sep 17 00:00:00 2001 From: damian Date: Tue, 7 Feb 2023 18:21:05 +0100 Subject: [PATCH 6/6] consistent spelling --- data/every_dream_validation.py | 8 ++++---- validation_default.json | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index ec6e51e..1302cc8 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -78,7 +78,7 @@ class EveryDreamValidator: self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer) # order is important - if we're removing images from train, this needs to happen before making # the overlapping dataloader - self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader_if_required( + self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required( remaining_train_items, tokenizer) 
return remaining_train_items @@ -141,15 +141,15 @@ class EveryDreamValidator: val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size) return val_dataloader, remaining_train_items - def _build_train_stabiliser_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \ + def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \ -> Optional[torch.utils.data.DataLoader]: stabilize_training_loss = self.config['stabilize_training_loss'] if not stabilize_training_loss: return None stabilize_split_proportion = self.config['stabilize_split_proportion'] - stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) - stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer, + stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) + stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer, name='stabilize-train') stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size) return stabilize_dataloader diff --git a/validation_default.json b/validation_default.json index 2b626f2..d0579f5 100644 --- a/validation_default.json +++ b/validation_default.json @@ -4,7 +4,7 @@ "val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.", "val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.", "val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.", - "stabilize_training_loss": "If true, stabilise the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.", + "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.", "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.", "every_n_epochs": "How often to run validation (1=every epoch).", "seed": "The seed to use when running validation and stabilization passes." @@ -13,7 +13,7 @@ "val_split_mode": "automatic", "val_data_root": null, "val_split_proportion": 0.15, - "stabilize_training_loss": true, + "stabilize_training_loss": false, "stabilize_split_proportion": 0.15, "every_n_epochs": 1, "seed": 555
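
Note: a minimal validation config exercising the 'manual' split mode documented above might look like the sketch below. The "val_data_root" path is a placeholder; "validate_training" is set to true because the documentation block says the split modes are ignored when it is false; the remaining values simply restate the defaults that PATCH 5 bakes into EveryDreamValidator. Because the validator loads the file (the path train.py passes as args.validation_config) with self.config.update(json.load(f)), any key omitted from the file falls back to those built-in defaults.

{
    "validate_training": true,
    "val_split_mode": "manual",
    "val_data_root": "/path/to/validation_folder",
    "stabilize_training_loss": false,
    "stabilize_split_proportion": 0.15,
    "every_n_epochs": 1,
    "seed": 555
}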