better config handling

This commit is contained in:
damian 2023-02-07 17:52:23 +01:00
parent 29396ec21b
commit c3d844a1bc
1 changed files with 19 additions and 7 deletions

View File

@ -40,15 +40,26 @@ class EveryDreamValidator:
self.log_writer = log_writer
self.config = {}
self.config = {
'batch_size': default_batch_size,
'every_n_epochs': 1,
'seed': 555
}
if val_config_path is not None:
with open(val_config_path, 'rt') as f:
self.config = json.load(f)
self.config.update(json.load(f))
self.batch_size = self.config.get('batch_size', default_batch_size)
self.every_n_epochs = self.config.get('every_n_epochs', 1)
self.seed = self.config.get('seed', 555)
self.val_data_root = self.config.get('val_data_root', None)
@property
def batch_size(self):
return self.config['batch_size']
@property
def every_n_epochs(self):
return self.config['every_n_epochs']
@property
def seed(self):
return self.config['seed']
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
"""
@ -115,7 +126,8 @@ class EveryDreamValidator:
flip_p=0.0,
seed=self.seed,
)
val_items = resolver.resolve_root(self.val_data_root, args)
val_data_root = self.config['val_data_root']
val_items = resolver.resolve_root(val_data_root, args)
else:
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')