make validation more comparable across runs
This commit is contained in:
parent
4c996cb6b5
commit
f0111a6e2b
|
@ -1,4 +1,3 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
|
@ -64,6 +63,9 @@ class EveryDreamValidator:
|
|||
with open(val_config_path, 'rt') as f:
|
||||
self.config.update(json.load(f))
|
||||
|
||||
self.train_overlapping_dataloader_loss_offset = None
|
||||
self.val_loss_offset = None
|
||||
|
||||
self.loss_val_history = []
|
||||
self.val_loss_window_size = 4 # todo: arg for this?
|
||||
|
||||
|
@ -100,28 +102,42 @@ class EveryDreamValidator:
|
|||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
if (epoch % self.every_n_epochs) == 0:
|
||||
if self.train_overlapping_dataloader is not None:
|
||||
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
|
||||
mean_loss = self._calculate_validation_loss('stabilize-train',
|
||||
self.train_overlapping_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.train_overlapping_dataloader_loss_offset is None:
|
||||
self.train_overlapping_dataloader_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/stabilize-train",
|
||||
scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
if self.val_dataloader is not None:
|
||||
val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
||||
|
||||
self.loss_val_history.append(val_loss)
|
||||
mean_loss = self._calculate_validation_loss('val',
|
||||
self.val_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.val_loss_offset is None:
|
||||
self.val_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/val",
|
||||
scalar_value=self.val_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
self.loss_val_history.append(mean_loss)
|
||||
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
|
||||
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss shows diverging")
|
||||
# todo: signal stop?
|
||||
|
||||
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
|
||||
with torch.no_grad(), isolate_rng():
|
||||
# ok to override seed here because we are in a `with isolate_rng():` block
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
loss_validation_epoch = []
|
||||
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(dataloader):
|
||||
# ok to override seed here because we are in a `with isolate_rng():` block
|
||||
torch.manual_seed(self.seed + step)
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
@ -136,8 +152,6 @@ class EveryDreamValidator:
|
|||
steps_pbar.close()
|
||||
|
||||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||
|
||||
return loss_validation_local
|
||||
|
||||
|
||||
|
@ -153,13 +167,10 @@ class EveryDreamValidator:
|
|||
val_items = list(disable_multiplier_and_flip(val_items))
|
||||
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
|
||||
elif val_split_mode == 'manual':
|
||||
args = Namespace(
|
||||
aspects=aspects.get_aspect_buckets(self.resolution),
|
||||
flip_p=0.0,
|
||||
seed=self.seed,
|
||||
)
|
||||
val_data_root = self.config['val_data_root']
|
||||
val_items = resolver.resolve_root(val_data_root, args)
|
||||
val_data_root = self.config.get('val_data_root', None)
|
||||
if val_data_root is None:
|
||||
raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
|
||||
val_items = self._load_manual_val_split(val_data_root)
|
||||
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
|
||||
else:
|
||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||
|
@ -181,6 +192,17 @@ class EveryDreamValidator:
|
|||
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||
return stabilize_dataloader
|
||||
|
||||
def _load_manual_val_split(self, val_data_root: str):
|
||||
args = Namespace(
|
||||
aspects=aspects.get_aspect_buckets(self.resolution),
|
||||
flip_p=0.0,
|
||||
seed=self.seed,
|
||||
)
|
||||
val_items = resolver.resolve_root(val_data_root, args)
|
||||
val_items.sort(key=lambda i: i.pathname)
|
||||
random.shuffle(val_items)
|
||||
return val_items
|
||||
|
||||
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
|
||||
batch_size = self.batch_size
|
||||
seed = self.seed
|
||||
|
@ -196,5 +218,6 @@ class EveryDreamValidator:
|
|||
tokenizer=tokenizer,
|
||||
seed=seed,
|
||||
name=name,
|
||||
crop_jitter=0
|
||||
)
|
||||
return ed_batch
|
||||
|
|
Loading…
Reference in New Issue