fix every_n_epoch>1 logic and remove unnecessary log
This commit is contained in:
parent
47e90ad865
commit
c913824979
|
@ -97,10 +97,13 @@ class EveryDreamValidator:
|
|||
remaining_train_items, tokenizer)
|
||||
return remaining_train_items
|
||||
|
||||
def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]:
|
||||
def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
||||
if self.every_n_epochs >= 1:
|
||||
# last step only
|
||||
return [epoch_length_steps-1]
|
||||
if ((epoch+1) % self.every_n_epochs) == 0:
|
||||
# last step only
|
||||
return [epoch_length_steps-1]
|
||||
else:
|
||||
return []
|
||||
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
||||
# validation happens after training:
|
||||
|
|
3
train.py
3
train.py
|
@ -802,9 +802,8 @@ def main(args):
|
|||
|
||||
validation_steps = (
|
||||
[] if validator is None
|
||||
else validator.get_validation_step_indices(len(train_dataloader))
|
||||
else validator.get_validation_step_indices(epoch, len(train_dataloader))
|
||||
)
|
||||
print(f"validation on steps {validation_steps}")
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
step_start_time = time.time()
|
||||
|
|
Loading…
Reference in New Issue