diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 2f92275..95f5afd 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -23,7 +23,7 @@ from utils.isolate_rng import isolate_rng def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \ -> tuple[list[ImageTrainItem], list[ImageTrainItem]]: - split_item_count = math.ceil(split_proportion * len(items) / batch_size) * batch_size + split_item_count = max(1, math.ceil(split_proportion * len(items))) # sort first, then shuffle, to ensure determinate outcome for the current random state items_copy = list(sorted(items, key=lambda i: i.pathname)) random.shuffle(items_copy) @@ -83,6 +83,7 @@ class EveryDreamValidator: Otherwise, the returned `list` is identical to the passed-in `train_items`. """ with isolate_rng(): + random.seed(self.seed) self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer) # order is important - if we're removing images from train, this needs to happen before making # the overlapping dataloader diff --git a/utils/sample_generator.py b/utils/sample_generator.py index 609a8e6..ffded1b 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -31,7 +31,7 @@ class SampleRequest: def __str__(self): rep = self.prompt if len(self.negative_prompt) > 0: - rep += "\n negative prompt: {self.negative_prompt}" + rep += f"\n negative prompt: {self.negative_prompt}" rep += f"\n seed: {self.seed}" return rep