diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py
index 45882e0..a708df4 100644
--- a/data/every_dream_validation.py
+++ b/data/every_dream_validation.py
@@ -1,4 +1,3 @@
-import copy
 import json
 import logging
 import math
@@ -64,6 +63,9 @@ class EveryDreamValidator:
             with open(val_config_path, 'rt') as f:
                 self.config.update(json.load(f))
 
+        self.train_overlapping_dataloader_loss_offset = None
+        self.val_loss_offset = None
+
         self.loss_val_history = []
         self.val_loss_window_size = 4 # todo: arg for this?
 
@@ -100,28 +102,42 @@ class EveryDreamValidator:
                                       [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         if (epoch % self.every_n_epochs) == 0:
             if self.train_overlapping_dataloader is not None:
-                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
-                                    get_model_prediction_and_target_callable)
+                mean_loss = self._calculate_validation_loss('stabilize-train',
+                                                            self.train_overlapping_dataloader,
+                                                            get_model_prediction_and_target_callable)
+                if self.train_overlapping_dataloader_loss_offset is None:
+                    self.train_overlapping_dataloader_loss_offset = -mean_loss
+                self.log_writer.add_scalar(tag=f"loss/stabilize-train",
+                                           scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
+                                           global_step=global_step)
             if self.val_dataloader is not None:
-                val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
-
-                self.loss_val_history.append(val_loss)
+                mean_loss = self._calculate_validation_loss('val',
+                                                            self.val_dataloader,
+                                                            get_model_prediction_and_target_callable)
+                if self.val_loss_offset is None:
+                    self.val_loss_offset = -mean_loss
+                self.log_writer.add_scalar(tag=f"loss/val",
+                                           scalar_value=self.val_loss_offset + mean_loss,
+                                           global_step=global_step)
+                self.loss_val_history.append(mean_loss)
                 if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
                     dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
                     if np.average(dy) > 0:
                         logging.warning(f"Validation loss shows diverging")
                         # todo: signal stop?
 
-    def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
-        [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+    def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
+        [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
         with torch.no_grad(), isolate_rng():
+            # ok to override seed here because we are in a `with isolate_rng():` block
+            random.seed(self.seed)
+            torch.manual_seed(self.seed)
+
             loss_validation_epoch = []
             steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
             steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
 
             for step, batch in enumerate(dataloader):
-                # ok to override seed here because we are in a `with isolate_rng():` block
-                torch.manual_seed(self.seed + step)
                 model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -136,10 +152,8 @@ class EveryDreamValidator:
 
             steps_pbar.close()
 
         loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
-        self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
-
         return loss_validation_local
-
+
     def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
             -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
@@ -153,13 +167,10 @@ class EveryDreamValidator:
             val_items = list(disable_multiplier_and_flip(val_items))
             logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
         elif val_split_mode == 'manual':
-            args = Namespace(
-                aspects=aspects.get_aspect_buckets(self.resolution),
-                flip_p=0.0,
-                seed=self.seed,
-            )
-            val_data_root = self.config['val_data_root']
-            val_items = resolver.resolve_root(val_data_root, args)
+            val_data_root = self.config.get('val_data_root', None)
+            if val_data_root is None:
+                raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
+            val_items = self._load_manual_val_split(val_data_root)
             logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
         else:
             raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
@@ -181,6 +192,17 @@ class EveryDreamValidator:
             stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
             return stabilize_dataloader
 
+    def _load_manual_val_split(self, val_data_root: str):
+        args = Namespace(
+            aspects=aspects.get_aspect_buckets(self.resolution),
+            flip_p=0.0,
+            seed=self.seed,
+        )
+        val_items = resolver.resolve_root(val_data_root, args)
+        val_items.sort(key=lambda i: i.pathname)
+        random.shuffle(val_items)
+        return val_items
+
     def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
         batch_size = self.batch_size
         seed = self.seed
@@ -196,5 +218,6 @@ class EveryDreamValidator:
             tokenizer=tokenizer,
             seed=seed,
             name=name,
+            crop_jitter=0
         )
         return ed_batch
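
A note on the loss-offset logging added above: the first validation loss defines an offset (its negation), and every later value is logged as offset + mean_loss, so each curve starts at exactly zero and subsequent points read directly as drift from the starting loss. A minimal sketch of that idea, assuming a torch.utils.tensorboard SummaryWriter; the class name RelativeLossCurve and the demo values are illustrative, not part of the patch:

from typing import Optional

from torch.utils.tensorboard import SummaryWriter


class RelativeLossCurve:
    """Logs each loss relative to the first value observed, so the curve
    starts at 0 and later points show drift from the starting loss."""

    def __init__(self, writer: SummaryWriter, tag: str):
        self.writer = writer
        self.tag = tag
        self.offset: Optional[float] = None

    def log(self, mean_loss: float, global_step: int) -> None:
        if self.offset is None:
            # the first observation anchors the curve: offset + mean_loss == 0
            self.offset = -mean_loss
        self.writer.add_scalar(tag=self.tag,
                               scalar_value=self.offset + mean_loss,
                               global_step=global_step)


writer = SummaryWriter()
curve = RelativeLossCurve(writer, "loss/val")
for step, loss in enumerate([0.172, 0.168, 0.171]):
    curve.log(loss, global_step=step)  # logged as 0.0, -0.004, -0.001

This makes the 'stabilize-train' and 'val' curves directly comparable on one chart even when their absolute losses differ, which is the point of giving each curve its own offset.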
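Similarly, _load_manual_val_split sorts by pathname before shuffling so the validation ordering does not depend on filesystem enumeration order: the sort establishes a canonical base order, and the seeded shuffle then yields the same permutation on every run. A small sketch of that technique; the Item dataclass is a stand-in for the patch's ImageTrainItem, and unlike the patch (which shuffles with the module-global RNG) this uses a local random.Random so the example is self-contained:

import random
from dataclasses import dataclass


@dataclass
class Item:
    pathname: str


def deterministic_shuffle(items: list[Item], seed: int) -> list[Item]:
    items = sorted(items, key=lambda i: i.pathname)  # canonical base order
    random.Random(seed).shuffle(items)               # reproducible permutation
    return items


# the same seed yields the same order regardless of input order
a = deterministic_shuffle([Item("b.jpg"), Item("a.jpg"), Item("c.jpg")], seed=555)
b = deterministic_shuffle([Item("c.jpg"), Item("b.jpg"), Item("a.jpg")], seed=555)
assert [i.pathname for i in a] == [i.pathname for i in b]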