diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index ddcb52c..0cb9fd7 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -40,15 +40,26 @@ class EveryDreamValidator: self.log_writer = log_writer - self.config = {} + self.config = { + 'batch_size': default_batch_size, + 'every_n_epochs': 1, + 'seed': 555 + } if val_config_path is not None: with open(val_config_path, 'rt') as f: - self.config = json.load(f) + self.config.update(json.load(f)) - self.batch_size = self.config.get('batch_size', default_batch_size) - self.every_n_epochs = self.config.get('every_n_epochs', 1) - self.seed = self.config.get('seed', 555) - self.val_data_root = self.config.get('val_data_root', None) + @property + def batch_size(self): + return self.config['batch_size'] + + @property + def every_n_epochs(self): + return self.config['every_n_epochs'] + + @property + def seed(self): + return self.config['seed'] def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]: """ @@ -115,7 +126,8 @@ class EveryDreamValidator: flip_p=0.0, seed=self.seed, ) - val_items = resolver.resolve_root(self.val_data_root, args) + val_data_root = self.config['val_data_root'] + val_items = resolver.resolve_root(val_data_root, args) else: raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'") val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')