Merge pull request #82 from damian0815/fix_validation_seeding
actually use the validation random seed
This commit is contained in:
commit
a8cc62fb94
|
@ -23,7 +23,7 @@ from utils.isolate_rng import isolate_rng
|
|||
|
||||
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
|
||||
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
|
||||
split_item_count = math.ceil(split_proportion * len(items) / batch_size) * batch_size
|
||||
split_item_count = max(1, math.ceil(split_proportion * len(items)))
|
||||
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
||||
items_copy = list(sorted(items, key=lambda i: i.pathname))
|
||||
random.shuffle(items_copy)
|
||||
|
@ -83,6 +83,7 @@ class EveryDreamValidator:
|
|||
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
||||
"""
|
||||
with isolate_rng():
|
||||
random.seed(self.seed)
|
||||
self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
|
||||
# order is important - if we're removing images from train, this needs to happen before making
|
||||
# the overlapping dataloader
|
||||
|
|
|
@ -31,7 +31,7 @@ class SampleRequest:
|
|||
def __str__(self):
|
||||
rep = self.prompt
|
||||
if len(self.negative_prompt) > 0:
|
||||
rep += "\n negative prompt: {self.negative_prompt}"
|
||||
rep += f"\n negative prompt: {self.negative_prompt}"
|
||||
rep += f"\n seed: {self.seed}"
|
||||
return rep
|
||||
|
||||
|
|
Loading…
Reference in New Issue