Fix the every_n_epochs > 1 logic and remove an unnecessary log line

This commit is contained in:
Damian Stewart 2023-03-10 23:13:53 +01:00
parent 47e90ad865
commit c913824979
2 changed files with 7 additions and 5 deletions

View File

@@ -97,10 +97,13 @@ class EveryDreamValidator:
                                              remaining_train_items, tokenizer)
         return remaining_train_items

-    def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]:
+    def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
         if self.every_n_epochs >= 1:
-            # last step only
-            return [epoch_length_steps-1]
+            if ((epoch+1) % self.every_n_epochs) == 0:
+                # last step only
+                return [epoch_length_steps-1]
+            else:
+                return []
         # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
         num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
         # validation happens after training:

View File

@@ -802,9 +802,8 @@ def main(args):
             validation_steps = (
                 [] if validator is None
-                else validator.get_validation_step_indices(len(train_dataloader))
+                else validator.get_validation_step_indices(epoch, len(train_dataloader))
             )
-            print(f"validation on steps {validation_steps}")
             for step, batch in enumerate(train_dataloader):
                 step_start_time = time.time()