diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 19a82c4..8a1a4d1 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -97,10 +97,13 @@ class EveryDreamValidator: remaining_train_items, tokenizer) return remaining_train_items - def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]: + def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]: if self.every_n_epochs >= 1: - # last step only - return [epoch_length_steps-1] + if ((epoch+1) % self.every_n_epochs) == 0: + # last step only + return [epoch_length_steps-1] + else: + return [] # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs))) # validation happens after training: diff --git a/train.py b/train.py index 1df7422..ee467d6 100644 --- a/train.py +++ b/train.py @@ -802,9 +802,8 @@ def main(args): validation_steps = ( [] if validator is None - else validator.get_validation_step_indices(len(train_dataloader)) + else validator.get_validation_step_indices(epoch, len(train_dataloader)) ) - print(f"validation on steps {validation_steps}") for step, batch in enumerate(train_dataloader): step_start_time = time.time()