From 103ab20696a5f881fc1657d80b127f3487b97028 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 9 Mar 2023 20:42:17 +0100 Subject: [PATCH 1/4] allow sub-epoch validation when every_n_epochs <1 --- data/every_dream_validation.py | 64 ++++++++++++++++++++-------------- train.py | 16 ++++++--- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index a708df4..fd4e962 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -97,34 +97,44 @@ class EveryDreamValidator: remaining_train_items, tokenizer) return remaining_train_items - def do_validation_if_appropriate(self, epoch: int, global_step: int, - get_model_prediction_and_target_callable: Callable[ + def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]: + if self.every_n_epochs >= 1: + # last step only + return [epoch_length_steps-1] + # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps + num_divisions = min(epoch_length_steps, math.ceil(1/self.every_n_epochs)) + # validation happens after training: + # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99 + validate_every_n_steps = epoch_length_steps / num_divisions + return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)] + + def do_validation(self, global_step: int, + get_model_prediction_and_target_callable: Callable[ [Any, Any], tuple[torch.Tensor, torch.Tensor]]): - if (epoch % self.every_n_epochs) == 0: - if self.train_overlapping_dataloader is not None: - mean_loss = self._calculate_validation_loss('stabilize-train', - self.train_overlapping_dataloader, - get_model_prediction_and_target_callable) - if self.train_overlapping_dataloader_loss_offset is None: - self.train_overlapping_dataloader_loss_offset = -mean_loss - self.log_writer.add_scalar(tag=f"loss/stabilize-train", - scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss, - global_step=global_step) - if self.val_dataloader is not None: - mean_loss = self._calculate_validation_loss('val', - self.val_dataloader, - get_model_prediction_and_target_callable) - if self.val_loss_offset is None: - self.val_loss_offset = -mean_loss - self.log_writer.add_scalar(tag=f"loss/val", - scalar_value=self.val_loss_offset + mean_loss, - global_step=global_step) - self.loss_val_history.append(mean_loss) - if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1): - dy = np.diff(self.loss_val_history[-self.val_loss_window_size:]) - if np.average(dy) > 0: - logging.warning(f"Validation loss shows diverging") - # todo: signal stop? + if self.train_overlapping_dataloader is not None: + mean_loss = self._calculate_validation_loss('stabilize-train', + self.train_overlapping_dataloader, + get_model_prediction_and_target_callable) + if self.train_overlapping_dataloader_loss_offset is None: + self.train_overlapping_dataloader_loss_offset = -mean_loss + self.log_writer.add_scalar(tag=f"loss/stabilize-train", + scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss, + global_step=global_step) + if self.val_dataloader is not None: + mean_loss = self._calculate_validation_loss('val', + self.val_dataloader, + get_model_prediction_and_target_callable) + if self.val_loss_offset is None: + self.val_loss_offset = -mean_loss + self.log_writer.add_scalar(tag=f"loss/val", + scalar_value=self.val_loss_offset + mean_loss, + global_step=global_step) + self.loss_val_history.append(mean_loss) + if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1): + dy = np.diff(self.loss_val_history[-self.val_loss_window_size:]) + if np.average(dy) > 0: + logging.warning(f"Validation loss shows diverging") + # todo: signal stop? def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[ [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float: diff --git a/train.py b/train.py index 4d06ce1..1df7422 100644 --- a/train.py +++ b/train.py @@ -780,8 +780,8 @@ def main(args): # Pre-train validation to establish a starting point on the loss graph if validator: - validator.do_validation_if_appropriate(epoch=0, global_step=0, - get_model_prediction_and_target_callable=get_model_prediction_and_target) + validator.do_validation(global_step=0, + get_model_prediction_and_target_callable=get_model_prediction_and_target) # the sample generator might be configured to generate samples before step 0 if sample_generator.generate_pretrain_samples: @@ -800,6 +800,12 @@ def main(args): steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True) steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}") + validation_steps = ( + [] if validator is None + else validator.get_validation_step_indices(len(train_dataloader)) + ) + print(f"validation on steps {validation_steps}") + for step, batch in enumerate(train_dataloader): step_start_time = time.time() @@ -860,6 +866,9 @@ def main(args): append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) torch.cuda.empty_cache() + if validator and step in validation_steps: + validator.do_validation(global_step, get_model_prediction_and_target) + if (global_step + 1) % sample_generator.sample_steps == 0: generate_samples(global_step=global_step, batch=batch) @@ -895,9 +904,6 @@ def main(args): loss_local = sum(loss_epoch) / len(loss_epoch) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) - if validator: - validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target) - gc.collect() # end of epoch From 47e90ad865051851b32b297c4ce7bb3c98ac4a48 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Thu, 9 Mar 2023 20:47:37 +0100 Subject: [PATCH 2/4] round rather than ceil --- data/every_dream_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index fd4e962..19a82c4 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -102,7 +102,7 @@ class EveryDreamValidator: # last step only return [epoch_length_steps-1] # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps - num_divisions = min(epoch_length_steps, math.ceil(1/self.every_n_epochs)) + num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs))) # validation happens after training: # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99 validate_every_n_steps = epoch_length_steps / num_divisions From c913824979af88f159f59147dfa73d896d58e0a7 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 10 Mar 2023 23:13:53 +0100 Subject: [PATCH 3/4] fix every_n_epoch>1 logic and remove unnecessary log --- data/every_dream_validation.py | 9 ++++++--- train.py | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 19a82c4..8a1a4d1 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -97,10 +97,13 @@ class EveryDreamValidator: remaining_train_items, tokenizer) return remaining_train_items - def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]: + def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]: if self.every_n_epochs >= 1: - # last step only - return [epoch_length_steps-1] + if ((epoch+1) % self.every_n_epochs) == 0: + # last step only + return [epoch_length_steps-1] + else: + return [] # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs))) # validation happens after training: diff --git a/train.py b/train.py index 1df7422..ee467d6 100644 --- a/train.py +++ b/train.py @@ -802,9 +802,8 @@ def main(args): validation_steps = ( [] if validator is None - else validator.get_validation_step_indices(len(train_dataloader)) + else validator.get_validation_step_indices(epoch, len(train_dataloader)) ) - print(f"validation on steps {validation_steps}") for step, batch in enumerate(train_dataloader): step_start_time = time.time() From 8c05b7e1d55c00f7c4c9e018606c0ecba31ea0ab Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 7 May 2023 12:05:49 +0200 Subject: [PATCH 4/4] update docs for every_n_epochs --- doc/VALIDATION.md | 2 +- validation_default.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/VALIDATION.md b/doc/VALIDATION.md index 3a9860f..f2a2fae 100644 --- a/doc/VALIDATION.md +++ b/doc/VALIDATION.md @@ -104,7 +104,7 @@ The config file has the following options: #### General settings -* `every_n_epochs`: How often to run validation (1=every epoch). +* `every_n_epochs`: How often to run validation. Specify either whole numbers, eg 1=every epoch (recommended default), 2=every second epoch, etc.; or floating point numbers between 0 and 1, eg 0.5=twice per epoch, 0.33=three times per epoch, etc. * `seed`: The seed to use when running validation passes, and also for picking subsets of the data to use with `automatic` val_split_mode and/or `stabilize_training_loss`. #### Extra manual datasets diff --git a/validation_default.json b/validation_default.json index a351a75..ce95b80 100644 --- a/validation_default.json +++ b/validation_default.json @@ -7,7 +7,7 @@ "extra_manual_datasets": "Dictionary of 'name':'path' pairs defining additional validation datasets to load and log. eg { 'santa_suit': '/path/to/captioned_santa_suit_images', 'flamingo_suit': '/path/to/flamingo_suit_images' }", "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.", "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.", - "every_n_epochs": "How often to run validation (1=every epoch).", + "every_n_epochs": "How often to run validation (1=every epoch, 2=every second epoch; 0.5=twice per epoch, 0.33=three times per epoch, etc.).", "seed": "The seed to use when running validation and stabilization passes.", "use_relative_loss": "logs val/loss as negative relative to first pre-train val/loss value" },