fix every_n_epoch>1 logic and remove unnecessary log
This commit is contained in:
parent
47e90ad865
commit
c913824979
|
@ -97,10 +97,13 @@ class EveryDreamValidator:
|
||||||
remaining_train_items, tokenizer)
|
remaining_train_items, tokenizer)
|
||||||
return remaining_train_items
|
return remaining_train_items
|
||||||
|
|
||||||
def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]:
|
def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
||||||
if self.every_n_epochs >= 1:
|
if self.every_n_epochs >= 1:
|
||||||
# last step only
|
if ((epoch+1) % self.every_n_epochs) == 0:
|
||||||
return [epoch_length_steps-1]
|
# last step only
|
||||||
|
return [epoch_length_steps-1]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||||
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
||||||
# validation happens after training:
|
# validation happens after training:
|
||||||
|
|
3
train.py
3
train.py
|
@ -802,9 +802,8 @@ def main(args):
|
||||||
|
|
||||||
validation_steps = (
|
validation_steps = (
|
||||||
[] if validator is None
|
[] if validator is None
|
||||||
else validator.get_validation_step_indices(len(train_dataloader))
|
else validator.get_validation_step_indices(epoch, len(train_dataloader))
|
||||||
)
|
)
|
||||||
print(f"validation on steps {validation_steps}")
|
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
Loading…
Reference in New Issue