fix every_n_epoch>1 logic and remove unnecessary log

This commit is contained in:
Damian Stewart 2023-03-10 23:13:53 +01:00
parent 47e90ad865
commit c913824979
2 changed files with 7 additions and 5 deletions

View File

@ -97,10 +97,13 @@ class EveryDreamValidator:
remaining_train_items, tokenizer)
return remaining_train_items
def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]:
def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
if self.every_n_epochs >= 1:
# last step only
return [epoch_length_steps-1]
if ((epoch+1) % self.every_n_epochs) == 0:
# last step only
return [epoch_length_steps-1]
else:
return []
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
# validation happens after training:

View File

@ -802,9 +802,8 @@ def main(args):
validation_steps = (
[] if validator is None
else validator.get_validation_step_indices(len(train_dataloader))
else validator.get_validation_step_indices(epoch, len(train_dataloader))
)
print(f"validation on steps {validation_steps}")
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()