make validation more comparable across runs

parent 4c996cb6b5
commit f0111a6e2b
```diff
@@ -1,4 +1,3 @@
-import copy
 import json
 import logging
 import math
```
```diff
@@ -64,6 +63,9 @@ class EveryDreamValidator:
         with open(val_config_path, 'rt') as f:
             self.config.update(json.load(f))
 
+        self.train_overlapping_dataloader_loss_offset = None
+        self.val_loss_offset = None
+
         self.loss_val_history = []
         self.val_loss_window_size = 4 # todo: arg for this?
 
```
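The two offset fields initialized here drive the run-to-run normalization applied in the next hunk, while `loss_val_history` and `val_loss_window_size` feed a simple divergence heuristic: once enough validation results have accumulated, the average of the first differences over the most recent entries is checked, and a positive slope triggers a warning. A standalone reproduction of that check, with illustrative numbers:

```python
import numpy as np

# Standalone reproduction of the divergence heuristic; the history values
# here are illustrative, not from a real run.
loss_val_history = [0.30, 0.28, 0.27, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.33]
val_loss_window_size = 4

# wait until more than (2 * window + 1) results exist before judging
if len(loss_val_history) > (val_loss_window_size * 2 + 1):
    # first differences over the most recent window
    dy = np.diff(loss_val_history[-val_loss_window_size:])
    if np.average(dy) > 0:
        # average slope is positive: validation loss is trending upward
        print("validation loss appears to be diverging")
```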
```diff
@@ -100,28 +102,42 @@ class EveryDreamValidator:
         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         if (epoch % self.every_n_epochs) == 0:
             if self.train_overlapping_dataloader is not None:
-                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
-                                    get_model_prediction_and_target_callable)
+                mean_loss = self._calculate_validation_loss('stabilize-train',
+                                                            self.train_overlapping_dataloader,
+                                                            get_model_prediction_and_target_callable)
+                if self.train_overlapping_dataloader_loss_offset is None:
+                    self.train_overlapping_dataloader_loss_offset = -mean_loss
+                self.log_writer.add_scalar(tag=f"loss/stabilize-train",
+                                           scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
+                                           global_step=global_step)
             if self.val_dataloader is not None:
-                val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
-                self.loss_val_history.append(val_loss)
+                mean_loss = self._calculate_validation_loss('val',
+                                                            self.val_dataloader,
+                                                            get_model_prediction_and_target_callable)
+                if self.val_loss_offset is None:
+                    self.val_loss_offset = -mean_loss
+                self.log_writer.add_scalar(tag=f"loss/val",
+                                           scalar_value=self.val_loss_offset + mean_loss,
+                                           global_step=global_step)
+                self.loss_val_history.append(mean_loss)
                 if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
                     dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
                     if np.average(dy) > 0:
                         logging.warning(f"Validation loss shows diverging")
                         # todo: signal stop?
 
-    def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
-        [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+    def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
+        [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
         with torch.no_grad(), isolate_rng():
+            # ok to override seed here because we are in a `with isolate_rng():` block
+            random.seed(self.seed)
+            torch.manual_seed(self.seed)
+
             loss_validation_epoch = []
             steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
             steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
 
             for step, batch in enumerate(dataloader):
-                # ok to override seed here because we are in a `with isolate_rng():` block
-                torch.manual_seed(self.seed + step)
                 model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
```
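The offset bookkeeping above is the heart of the change: the first validation pass fixes the baseline (`offset = -mean_loss`), and every subsequent value is logged as `offset + mean_loss`, so each run's curve starts at zero and trends can be compared across runs regardless of their absolute loss scale. A minimal standalone sketch of the same arithmetic (`RelativeLossTracker` is a hypothetical name, not part of the codebase):

```python
# Hypothetical helper illustrating the offset normalization used above.
class RelativeLossTracker:
    def __init__(self):
        self.offset = None  # fixed by the first observation

    def relative(self, mean_loss: float) -> float:
        if self.offset is None:
            self.offset = -mean_loss  # first measurement becomes the zero point
        return self.offset + mean_loss

tracker = RelativeLossTracker()
print(tracker.relative(0.123))  #  0.0    -- baseline
print(tracker.relative(0.120))  # ~-0.003 -- better than the start
print(tracker.relative(0.130))  # ~ 0.007 -- worse than the start
```

Note that the raw (un-offset) `mean_loss` is still what gets appended to `loss_val_history`, so the divergence check operates on absolute values.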
```diff
@@ -136,10 +152,8 @@ class EveryDreamValidator:
             steps_pbar.close()
 
             loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
-            self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
-
             return loss_validation_local
 
 
     def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
             -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
```
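With the duplicate `add_scalar` call removed, `_calculate_validation_loss` now only computes and returns the mean loss; all logging happens in the caller. The other determinism change is that seeding now happens once per validation pass, inside `with torch.no_grad(), isolate_rng():`, rather than per step via `self.seed + step`: given a fixed-order dataloader, the entire stream of random draws during validation (e.g. sampled noise and timesteps) is identical across epochs and runs, and `isolate_rng()` restores the training RNG state on exit. A sketch of the pattern, assuming `isolate_rng` is importable from `pytorch_lightning.utilities.seed` (the module path can vary across versions):

```python
import random
import torch
# assumed import path; it differs in some pytorch_lightning versions
from pytorch_lightning.utilities.seed import isolate_rng

def deterministic_draws(seed: int = 555) -> torch.Tensor:
    with torch.no_grad(), isolate_rng():
        # safe to reseed here: isolate_rng() snapshots the global RNG
        # state on entry and restores it when the block exits
        random.seed(seed)
        torch.manual_seed(seed)
        return torch.randn(3)

print(deterministic_draws())  # same tensor on every call...
print(deterministic_draws())  # ...and training-side RNG streams are untouched
```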
```diff
@@ -153,13 +167,10 @@ class EveryDreamValidator:
             val_items = list(disable_multiplier_and_flip(val_items))
             logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
         elif val_split_mode == 'manual':
-            args = Namespace(
-                aspects=aspects.get_aspect_buckets(self.resolution),
-                flip_p=0.0,
-                seed=self.seed,
-            )
-            val_data_root = self.config['val_data_root']
-            val_items = resolver.resolve_root(val_data_root, args)
+            val_data_root = self.config.get('val_data_root', None)
+            if val_data_root is None:
+                raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
+            val_items = self._load_manual_val_split(val_data_root)
             logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
         else:
             raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
```
```diff
@@ -181,6 +192,17 @@ class EveryDreamValidator:
         stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
         return stabilize_dataloader
 
+    def _load_manual_val_split(self, val_data_root: str):
+        args = Namespace(
+            aspects=aspects.get_aspect_buckets(self.resolution),
+            flip_p=0.0,
+            seed=self.seed,
+        )
+        val_items = resolver.resolve_root(val_data_root, args)
+        val_items.sort(key=lambda i: i.pathname)
+        random.shuffle(val_items)
+        return val_items
+
     def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
         batch_size = self.batch_size
         seed = self.seed
```
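`_load_manual_val_split` also pins the item ordering: `resolver.resolve_root` can return files in filesystem-dependent order, so the items are first sorted by pathname into a canonical order and then shuffled. For the shuffle to be reproducible the global `random` state must be seeded with the validator's fixed seed; the sort-then-shuffle idiom in isolation (hypothetical file list, seeding shown explicitly):

```python
import random

# hypothetical inputs; the real items come from resolver.resolve_root()
paths = ["imgs/b.jpg", "imgs/c.jpg", "imgs/a.jpg"]

random.seed(555)       # the validator's fixed seed
paths.sort()           # canonical order, independent of filesystem enumeration
random.shuffle(paths)  # deterministic given the seed
print(paths)           # identical list on every run
```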
```diff
@@ -196,5 +218,6 @@ class EveryDreamValidator:
             tokenizer=tokenizer,
             seed=seed,
             name=name,
+            crop_jitter=0
         )
         return ed_batch
```
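Passing `crop_jitter=0` when building the validation batch presumably disables random crop jitter for these items, in the same spirit as `flip_p=0.0` in `_load_manual_val_split`: the fewer stochastic augmentations in the validation pipeline, the more directly loss curves from different runs can be compared.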