Merge pull request #161 from damian0815/val_partial_epochs

Feature: enable partial epoch support for validation
This commit is contained in:
Victor Hall 2023-05-08 15:15:11 -04:00 committed by GitHub
commit 6c562d5e78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 18 deletions

View File

@ -149,18 +149,32 @@ class EveryDreamValidator:
return remaining_train_items
def do_validation_if_appropriate(self, epoch: int, global_step: int,
get_model_prediction_and_target_callable: Callable[
def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
if self.every_n_epochs >= 1:
if ((epoch+1) % self.every_n_epochs) == 0:
# last step only
return [epoch_length_steps-1]
else:
return []
else:
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
# validation happens after training:
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
validate_every_n_steps = epoch_length_steps / num_divisions
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
def do_validation(self, global_step: int,
get_model_prediction_and_target_callable: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
if (epoch % self.every_n_epochs) == 0:
for i, dataset in enumerate(self.validation_datasets):
mean_loss = self._calculate_validation_loss(dataset.name,
dataset.dataloader,
get_model_prediction_and_target_callable)
self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
scalar_value=mean_loss,
global_step=global_step)
dataset.track_loss_trend(mean_loss)
for i, dataset in enumerate(self.validation_datasets):
mean_loss = self._calculate_validation_loss(dataset.name,
dataset.dataloader,
get_model_prediction_and_target_callable)
self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
scalar_value=mean_loss,
global_step=global_step)
dataset.track_loss_trend(mean_loss)
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[

View File

@ -104,7 +104,7 @@ The config file has the following options:
#### General settings
* `every_n_epochs`: How often to run validation (1=every epoch).
* `every_n_epochs`: How often to run validation. Specify either whole numbers, eg 1=every epoch (recommended default), 2=every second epoch, etc.; or floating point numbers between 0 and 1, eg 0.5=twice per epoch, 0.33=three times per epoch, etc.
* `seed`: The seed to use when running validation passes, and also for picking subsets of the data to use with `automatic` val_split_mode and/or `stabilize_training_loss`.
#### Extra manual datasets

View File

@ -702,8 +702,8 @@ def main(args):
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation_if_appropriate(epoch=0, global_step=0,
get_model_prediction_and_target_callable=get_model_prediction_and_target)
validator.do_validation(global_step=0,
get_model_prediction_and_target_callable=get_model_prediction_and_target)
# the sample generator might be configured to generate samples before step 0
if sample_generator.generate_pretrain_samples:
@ -722,6 +722,11 @@ def main(args):
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
validation_steps = (
[] if validator is None
else validator.get_validation_step_indices(epoch, len(train_dataloader))
)
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
@ -769,6 +774,9 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if validator and step in validation_steps:
validator.do_validation(global_step, get_model_prediction_and_target)
if (global_step + 1) % sample_generator.sample_steps == 0:
generate_samples(global_step=global_step, batch=batch)
@ -803,9 +811,6 @@ def main(args):
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
if validator:
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch

View File

@ -7,7 +7,7 @@
"extra_manual_datasets": "Dictionary of 'name':'path' pairs defining additional validation datasets to load and log. eg { 'santa_suit': '/path/to/captioned_santa_suit_images', 'flamingo_suit': '/path/to/flamingo_suit_images' }",
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
"every_n_epochs": "How often to run validation (1=every epoch, 2=every second epoch; 0.5=twice per epoch, 0.33=three times per epoch, etc.).",
"seed": "The seed to use when running validation and stabilization passes.",
"use_relative_loss": "logs val/loss as negative relative to first pre-train val/loss value"
},