better config handling
This commit is contained in:
parent
29396ec21b
commit
c3d844a1bc
|
@ -40,15 +40,26 @@ class EveryDreamValidator:
|
|||
|
||||
self.log_writer = log_writer
|
||||
|
||||
self.config = {}
|
||||
self.config = {
|
||||
'batch_size': default_batch_size,
|
||||
'every_n_epochs': 1,
|
||||
'seed': 555
|
||||
}
|
||||
if val_config_path is not None:
|
||||
with open(val_config_path, 'rt') as f:
|
||||
self.config = json.load(f)
|
||||
self.config.update(json.load(f))
|
||||
|
||||
self.batch_size = self.config.get('batch_size', default_batch_size)
|
||||
self.every_n_epochs = self.config.get('every_n_epochs', 1)
|
||||
self.seed = self.config.get('seed', 555)
|
||||
self.val_data_root = self.config.get('val_data_root', None)
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self.config['batch_size']
|
||||
|
||||
@property
|
||||
def every_n_epochs(self):
|
||||
return self.config['every_n_epochs']
|
||||
|
||||
@property
|
||||
def seed(self):
|
||||
return self.config['seed']
|
||||
|
||||
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -115,7 +126,8 @@ class EveryDreamValidator:
|
|||
flip_p=0.0,
|
||||
seed=self.seed,
|
||||
)
|
||||
val_items = resolver.resolve_root(self.val_data_root, args)
|
||||
val_data_root = self.config['val_data_root']
|
||||
val_items = resolver.resolve_root(val_data_root, args)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||
|
|
Loading…
Reference in New Issue