clarify init function names

This commit is contained in:
damian 2023-02-07 17:54:00 +01:00
parent c3d844a1bc
commit f0d7310c12
1 changed file with 5 additions and 4 deletions

View File

@ -69,10 +69,11 @@ class EveryDreamValidator:
Otherwise, the returned `list` is identical to the passed-in `train_items`.
"""
with isolate_rng():
self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer) self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
# order is important - if we're removing images from train, this needs to happen before making
# the overlapping dataloader
self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer) self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader_if_required(
remaining_train_items, tokenizer)
return remaining_train_items
def do_validation_if_appropriate(self, epoch: int, global_step: int,
@ -111,7 +112,7 @@ class EveryDreamValidator:
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\ def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
val_split_mode = self.config.get('val_split_mode', 'automatic')
val_split_proportion = self.config.get('val_split_proportion', 0.15)
@ -134,7 +135,7 @@ class EveryDreamValidator:
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
return val_dataloader, remaining_train_items
def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \ def _build_train_stabiliser_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
-> Optional[torch.utils.data.DataLoader]:
stabilize_training_loss = self.config.get('stabilize_training_loss', False)
if not stabilize_training_loss: