clarify init function names
This commit is contained in:
parent
c3d844a1bc
commit
f0d7310c12
|
@ -69,10 +69,11 @@ class EveryDreamValidator:
|
|||
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
||||
"""
|
||||
with isolate_rng():
|
||||
self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer)
|
||||
self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
|
||||
# order is important - if we're removing images from train, this needs to happen before making
|
||||
# the overlapping dataloader
|
||||
self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer)
|
||||
self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader_if_required(
|
||||
remaining_train_items, tokenizer)
|
||||
return remaining_train_items
|
||||
|
||||
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
||||
|
@ -111,7 +112,7 @@ class EveryDreamValidator:
|
|||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||
|
||||
def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||
val_split_mode = self.config.get('val_split_mode', 'automatic')
|
||||
val_split_proportion = self.config.get('val_split_proportion', 0.15)
|
||||
|
@ -134,7 +135,7 @@ class EveryDreamValidator:
|
|||
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
||||
return val_dataloader, remaining_train_items
|
||||
|
||||
def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||
def _build_train_stabiliser_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||
-> Optional[torch.utils.data.DataLoader]:
|
||||
stabilize_training_loss = self.config.get('stabilize_training_loss', False)
|
||||
if not stabilize_training_loss:
|
||||
|
|
Loading…
Reference in New Issue