diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index fd4e962..19a82c4 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -102,7 +102,7 @@ class EveryDreamValidator: # last step only return [epoch_length_steps-1] # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps - num_divisions = min(epoch_length_steps, math.ceil(1/self.every_n_epochs)) + num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs))) # validation happens after training: # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99 validate_every_n_steps = epoch_length_steps / num_divisions