make validation more comparable across runs

This commit is contained in:
Damian Stewart 2023-03-09 08:21:59 +01:00
parent 4c996cb6b5
commit f0111a6e2b
1 changed files with 43 additions and 20 deletions

View File

@ -1,4 +1,3 @@
import copy
import json
import logging
import math
@ -64,6 +63,9 @@ class EveryDreamValidator:
with open(val_config_path, 'rt') as f:
self.config.update(json.load(f))
self.train_overlapping_dataloader_loss_offset = None
self.val_loss_offset = None
self.loss_val_history = []
self.val_loss_window_size = 4 # todo: arg for this?
@ -100,28 +102,42 @@ class EveryDreamValidator:
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
if (epoch % self.every_n_epochs) == 0:
if self.train_overlapping_dataloader is not None:
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
get_model_prediction_and_target_callable)
mean_loss = self._calculate_validation_loss('stabilize-train',
self.train_overlapping_dataloader,
get_model_prediction_and_target_callable)
if self.train_overlapping_dataloader_loss_offset is None:
self.train_overlapping_dataloader_loss_offset = -mean_loss
self.log_writer.add_scalar(tag=f"loss/stabilize-train",
scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
global_step=global_step)
if self.val_dataloader is not None:
val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
self.loss_val_history.append(val_loss)
mean_loss = self._calculate_validation_loss('val',
self.val_dataloader,
get_model_prediction_and_target_callable)
if self.val_loss_offset is None:
self.val_loss_offset = -mean_loss
self.log_writer.add_scalar(tag=f"loss/val",
scalar_value=self.val_loss_offset + mean_loss,
global_step=global_step)
self.loss_val_history.append(mean_loss)
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
if np.average(dy) > 0:
logging.warning(f"Validation loss shows diverging")
# todo: signal stop?
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
with torch.no_grad(), isolate_rng():
# ok to override seed here because we are in a `with isolate_rng():` block
random.seed(self.seed)
torch.manual_seed(self.seed)
loss_validation_epoch = []
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
for step, batch in enumerate(dataloader):
# ok to override seed here because we are in a `with isolate_rng():` block
torch.manual_seed(self.seed + step)
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@ -136,10 +152,8 @@ class EveryDreamValidator:
steps_pbar.close()
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
return loss_validation_local
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
@ -153,13 +167,10 @@ class EveryDreamValidator:
val_items = list(disable_multiplier_and_flip(val_items))
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
elif val_split_mode == 'manual':
args = Namespace(
aspects=aspects.get_aspect_buckets(self.resolution),
flip_p=0.0,
seed=self.seed,
)
val_data_root = self.config['val_data_root']
val_items = resolver.resolve_root(val_data_root, args)
val_data_root = self.config.get('val_data_root', None)
if val_data_root is None:
raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
val_items = self._load_manual_val_split(val_data_root)
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
else:
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
@ -181,6 +192,17 @@ class EveryDreamValidator:
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
return stabilize_dataloader
def _load_manual_val_split(self, val_data_root: str):
args = Namespace(
aspects=aspects.get_aspect_buckets(self.resolution),
flip_p=0.0,
seed=self.seed,
)
val_items = resolver.resolve_root(val_data_root, args)
val_items.sort(key=lambda i: i.pathname)
random.shuffle(val_items)
return val_items
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
batch_size = self.batch_size
seed = self.seed
@ -196,5 +218,6 @@ class EveryDreamValidator:
tokenizer=tokenizer,
seed=seed,
name=name,
crop_jitter=0
)
return ed_batch