Merge pull request #161 from damian0815/val_partial_epochs
Feature: enable partial epoch support for validation
commit 6c562d5e78
@@ -149,18 +149,32 @@ class EveryDreamValidator:
         return remaining_train_items

-    def do_validation_if_appropriate(self, epoch: int, global_step: int,
-                                     get_model_prediction_and_target_callable: Callable[
-                                         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
-        if (epoch % self.every_n_epochs) == 0:
-            for i, dataset in enumerate(self.validation_datasets):
-                mean_loss = self._calculate_validation_loss(dataset.name,
-                                                            dataset.dataloader,
-                                                            get_model_prediction_and_target_callable)
-                self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
-                                           scalar_value=mean_loss,
-                                           global_step=global_step)
-                dataset.track_loss_trend(mean_loss)
+    def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
+        if self.every_n_epochs >= 1:
+            if ((epoch+1) % self.every_n_epochs) == 0:
+                # last step only
+                return [epoch_length_steps-1]
+            else:
+                return []
+        else:
+            # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
+            num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
+            # validation happens after training:
+            # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
+            validate_every_n_steps = epoch_length_steps / num_divisions
+            return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
+
+    def do_validation(self, global_step: int,
+                      get_model_prediction_and_target_callable: Callable[
+                          [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+        for i, dataset in enumerate(self.validation_datasets):
+            mean_loss = self._calculate_validation_loss(dataset.name,
+                                                        dataset.dataloader,
+                                                        get_model_prediction_and_target_callable)
+            self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
+                                       scalar_value=mean_loss,
+                                       global_step=global_step)
+            dataset.track_loss_trend(mean_loss)


     def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
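For intuition, the scheduling math above can be exercised on its own. The following is a minimal sketch that copies `get_validation_step_indices` out of the class, passing `every_n_epochs` explicitly instead of reading it from `self`; the settings and step counts are illustrative values, not project defaults:

import math

def get_validation_step_indices(every_n_epochs: float, epoch: int,
                                epoch_length_steps: int) -> list[int]:
    if every_n_epochs >= 1:
        # whole-number setting: validate only on the last step of every Nth epoch
        return [epoch_length_steps - 1] if ((epoch + 1) % every_n_epochs) == 0 else []
    # fractional setting: round 1/every_n_epochs to a clean number of divisions,
    # capped at one validation per step
    num_divisions = max(1, min(epoch_length_steps, round(1 / every_n_epochs)))
    validate_every_n_steps = epoch_length_steps / num_divisions
    # indices of the steps *after which* validation runs
    return [math.ceil((i + 1) * validate_every_n_steps) - 1 for i in range(num_divisions)]

print(get_validation_step_indices(1.0, epoch=0, epoch_length_steps=100))   # [99]
print(get_validation_step_indices(2.0, epoch=0, epoch_length_steps=100))   # []  (fires on epoch 1)
print(get_validation_step_indices(0.5, epoch=0, epoch_length_steps=100))   # [49, 99]
print(get_validation_step_indices(0.33, epoch=0, epoch_length_steps=100))  # [33, 66, 99]
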
@@ -104,7 +104,7 @@ The config file has the following options:

 #### General settings

-* `every_n_epochs`: How often to run validation (1=every epoch).
+* `every_n_epochs`: How often to run validation. Specify either whole numbers, eg 1=every epoch (recommended default), 2=every second epoch, etc.; or floating point numbers between 0 and 1, eg 0.5=twice per epoch, 0.33=three times per epoch, etc.
 * `seed`: The seed to use when running validation passes, and also for picking subsets of the data to use with `automatic` val_split_mode and/or `stabilize_training_loss`.

 #### Extra manual datasets
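Concretely, enabling twice-per-epoch validation in the validation config would look something like this (a hypothetical fragment; only the relevant keys are shown, and the seed value is arbitrary):

{
    "every_n_epochs": 0.5,
    "seed": 555
}
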
train.py (15 changed lines)
@@ -702,8 +702,8 @@ def main(args):

     # Pre-train validation to establish a starting point on the loss graph
     if validator:
-        validator.do_validation_if_appropriate(epoch=0, global_step=0,
-                                               get_model_prediction_and_target_callable=get_model_prediction_and_target)
+        validator.do_validation(global_step=0,
+                                get_model_prediction_and_target_callable=get_model_prediction_and_target)

     # the sample generator might be configured to generate samples before step 0
     if sample_generator.generate_pretrain_samples:
@@ -722,6 +722,11 @@ def main(args):
         steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
         steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")

+        validation_steps = (
+            [] if validator is None
+            else validator.get_validation_step_indices(epoch, len(train_dataloader))
+        )
+
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
@@ -769,6 +774,9 @@ def main(args):
             append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
             torch.cuda.empty_cache()

+            if validator and step in validation_steps:
+                validator.do_validation(global_step, get_model_prediction_and_target)
+
             if (global_step + 1) % sample_generator.sample_steps == 0:
                 generate_samples(global_step=global_step, batch=batch)
@@ -803,9 +811,6 @@ def main(args):
         loss_local = sum(loss_epoch) / len(loss_epoch)
         log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)

-        if validator:
-            validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
-
         gc.collect()
         # end of epoch
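Taken together, the train.py changes move validation from an epoch-boundary hook into the step loop: the index list is computed once per epoch, and the old end-of-epoch call becomes redundant because a whole-number `every_n_epochs` yields `[epoch_length_steps - 1]`, i.e. the final step. A toy trace, using a hypothetical stand-in validator and 10-step epochs, illustrates the resulting schedule:

import math

class TinyValidator:
    # hypothetical stand-in implementing the same scheduling rule as EveryDreamValidator
    def __init__(self, every_n_epochs: float):
        self.every_n_epochs = every_n_epochs

    def get_validation_step_indices(self, epoch: int, epoch_length_steps: int) -> list[int]:
        if self.every_n_epochs >= 1:
            return [epoch_length_steps - 1] if ((epoch + 1) % self.every_n_epochs) == 0 else []
        num_divisions = max(1, min(epoch_length_steps, round(1 / self.every_n_epochs)))
        validate_every_n_steps = epoch_length_steps / num_divisions
        return [math.ceil((i + 1) * validate_every_n_steps) - 1 for i in range(num_divisions)]

validator = TinyValidator(every_n_epochs=0.5)
global_step = 0
for epoch in range(2):
    # computed once per epoch, as in the new train.py hunk above
    validation_steps = validator.get_validation_step_indices(epoch, epoch_length_steps=10)
    for step in range(10):
        if step in validation_steps:
            print(f"validate: epoch {epoch}, step {step}, global_step {global_step}")
        global_step += 1
# prints after steps 4 and 9 of each epoch
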
@@ -7,7 +7,7 @@
         "extra_manual_datasets": "Dictionary of 'name':'path' pairs defining additional validation datasets to load and log. eg { 'santa_suit': '/path/to/captioned_santa_suit_images', 'flamingo_suit': '/path/to/flamingo_suit_images' }",
         "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
         "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
-        "every_n_epochs": "How often to run validation (1=every epoch).",
+        "every_n_epochs": "How often to run validation (1=every epoch, 2=every second epoch; 0.5=twice per epoch, 0.33=three times per epoch, etc.).",
         "seed": "The seed to use when running validation and stabilization passes.",
         "use_relative_loss": "logs val/loss as negative relative to first pre-train val/loss value"
     },