clarify init function names
This commit is contained in:
parent
c3d844a1bc
commit
f0d7310c12
|
@ -69,10 +69,11 @@ class EveryDreamValidator:
|
||||||
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
||||||
"""
|
"""
|
||||||
with isolate_rng():
|
with isolate_rng():
|
||||||
self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer)
|
self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
|
||||||
# order is important - if we're removing images from train, this needs to happen before making
|
# order is important - if we're removing images from train, this needs to happen before making
|
||||||
# the overlapping dataloader
|
# the overlapping dataloader
|
||||||
self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer)
|
self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader_if_required(
|
||||||
|
remaining_train_items, tokenizer)
|
||||||
return remaining_train_items
|
return remaining_train_items
|
||||||
|
|
||||||
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
||||||
|
@ -111,7 +112,7 @@ class EveryDreamValidator:
|
||||||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||||
|
|
||||||
def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||||
val_split_mode = self.config.get('val_split_mode', 'automatic')
|
val_split_mode = self.config.get('val_split_mode', 'automatic')
|
||||||
val_split_proportion = self.config.get('val_split_proportion', 0.15)
|
val_split_proportion = self.config.get('val_split_proportion', 0.15)
|
||||||
|
@ -134,7 +135,7 @@ class EveryDreamValidator:
|
||||||
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
||||||
return val_dataloader, remaining_train_items
|
return val_dataloader, remaining_train_items
|
||||||
|
|
||||||
def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
def _build_train_stabiliser_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||||
-> Optional[torch.utils.data.DataLoader]:
|
-> Optional[torch.utils.data.DataLoader]:
|
||||||
stabilize_training_loss = self.config.get('stabilize_training_loss', False)
|
stabilize_training_loss = self.config.get('stabilize_training_loss', False)
|
||||||
if not stabilize_training_loss:
|
if not stabilize_training_loss:
|
||||||
|
|
Loading…
Reference in New Issue