Merge pull request #1 from damian0815/add-support-for-val-split

update & cleanup EveryDreamValidator
commit fc79413224
Joel Holdbrooks, 2023-02-07 10:53:13 -08:00 (committed by GitHub)
5 changed files with 168 additions and 193 deletions

data/data_loader.py

@@ -169,24 +169,6 @@ class DataLoaderMultiAspect():
         return picked_images
 
-    def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
-        item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
-        # sort first, then shuffle, to ensure determinate outcome for the current random state
-        items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
-        random.shuffle(items_copy)
-        split_items = items_copy[:item_count]
-        if remove_from_dataset:
-            self.delete_items(split_items)
-        return split_items
-
-    def delete_items(self, items: list[ImageTrainItem]):
-        for item in items:
-            for i, other_item in enumerate(self.prepared_train_data):
-                if other_item.pathname == item.pathname:
-                    self.prepared_train_data.pop(i)
-                    break
-        self.__update_rating_sums()
-
     def __update_rating_sums(self):
         self.rating_overall_sum: float = 0.0
         self.ratings_summed: list[float] = []

data/every_dream.py

@@ -63,7 +63,7 @@ class EveryDreamBatch(Dataset):
         self.name = name
 
         num_images = len(self.image_train_items)
-        logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
+        logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
 
     def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
         items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
@@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset):
     def __update_image_train_items(self, dropout_fraction: float):
         self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
-def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
+def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
     dataloader = torch.utils.data.DataLoader(
-        items,
+        dataset,
         batch_size=batch_size,
         shuffle=False,
-        num_workers=0,
+        num_workers=4,
         collate_fn=collate_fn
     )
    return dataloader
@@ -173,4 +173,4 @@ def collate_fn(batch):
         "runt_size": runt_size,
     }
     del batch
     return ret
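For reference, collate_fn stacks the per-item tensors into batch tensors and passes captions through as a plain list. A rough self-contained sketch with dummy items (the dict keys match the code above; the tensor shapes here are illustrative only, not the trainer's real resolutions):

    import torch

    # hypothetical stand-ins for real dataset items; only the keys matter
    batch = [{"image": torch.zeros(3, 64, 64),
              "caption": f"caption {i}",
              "tokens": torch.zeros(77, dtype=torch.long),
              "runt_size": 0} for i in range(4)]

    # mirrors the collate logic above: stack images/tokens, keep captions as a list
    images = torch.stack([e["image"] for e in batch]).to(memory_format=torch.contiguous_format).float()
    collated = {"tokens": torch.stack([e["tokens"] for e in batch]),
                "image": images,
                "captions": [e["caption"] for e in batch],
                "runt_size": batch[0]["runt_size"]}
    print(collated["image"].shape)   # torch.Size([4, 3, 64, 64])
    print(collated["tokens"].shape)  # torch.Size([4, 77])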

data/every_dream_validation.py

@@ -1,4 +1,5 @@
 import json
+import math
 import random
 from typing import Callable, Any, Optional
 from argparse import Namespace
@@ -14,57 +15,85 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch
 from data.data_loader import DataLoaderMultiAspect
 from data import resolver
 from data import aspects
+from data.image_train_item import ImageTrainItem
 from utils.isolate_rng import isolate_rng
 
+def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
+        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
+    split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
+    # sort first, then shuffle, to ensure determinate outcome for the current random state
+    items_copy = list(sorted(items, key=lambda i: i.pathname))
+    random.shuffle(items_copy)
+    split_items = list(items_copy[:split_item_count])
+    remaining_items = list(items_copy[split_item_count:])
+    return split_items, remaining_items
+
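A note on the arithmetic: `split_item_count` rounds the requested proportion down to a whole number of batches, and sorting before shuffling makes the split reproducible for a given random state. A self-contained sketch of the same logic, using a hypothetical stand-in for ImageTrainItem:

    import math
    import random
    from dataclasses import dataclass

    @dataclass
    class FakeItem:  # hypothetical stand-in for ImageTrainItem; only pathname matters here
        pathname: str

    items = [FakeItem(f"img_{i:03}.jpg") for i in range(100)]
    split_proportion, batch_size = 0.15, 4

    # 0.15 * 100 = 15 requested items; floor-dividing by 4 and re-multiplying
    # rounds down to 12 items, i.e. exactly 3 full batches
    split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size

    random.seed(555)
    items_copy = list(sorted(items, key=lambda i: i.pathname))  # sort first...
    random.shuffle(items_copy)                                  # ...then shuffle: deterministic per seed
    split_items, remaining_items = items_copy[:split_item_count], items_copy[split_item_count:]
    print(len(split_items), len(remaining_items))  # 12 88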
 class EveryDreamValidator:
     def __init__(self,
                  val_config_path: Optional[str],
-                 train_batch: EveryDreamBatch,
+                 default_batch_size: int,
                  log_writer: SummaryWriter):
+        self.val_dataloader = None
+        self.train_overlapping_dataloader = None
         self.log_writer = log_writer
-        val_config = {}
+        self.config = {
+            'batch_size': default_batch_size,
+            'every_n_epochs': 1,
+            'seed': 555,
+            'val_split_mode': 'automatic',
+            'val_split_proportion': 0.15,
+            'stabilize_training_loss': False,
+            'stabilize_split_proportion': 0.15
+        }
         if val_config_path is not None:
             with open(val_config_path, 'rt') as f:
-                val_config = json.load(f)
+                self.config.update(json.load(f))
 
-        do_validation = val_config.get('validate_training', False)
-        val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
-        self.val_data_root = val_config.get('val_data_root', None)
-        val_split_proportion = val_config.get('val_split_proportion', 0.15)
-
-        stabilize_training_loss = val_config.get('stabilize_training_loss', False)
-        stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
-
-        self.every_n_epochs = val_config.get('every_n_epochs', 1)
-        self.seed = val_config.get('seed', 555)
+    @property
+    def batch_size(self):
+        return self.config['batch_size']
+
+    @property
+    def every_n_epochs(self):
+        return self.config['every_n_epochs']
+
+    @property
+    def seed(self):
+        return self.config['seed']
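Since `self.config.update()` merges the user's file over these defaults, a config passed as val_config_path only needs the keys it overrides. For example (illustrative values and path):

    {
        "val_split_mode": "manual",
        "val_data_root": "/path/to/my_validation_images",
        "every_n_epochs": 2
    }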
+    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
+        """
+        Build the validation splits as requested by the config passed at init.
+        This may steal some items from `train_items`.
+        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
+        Otherwise, the returned `list` is identical to the passed-in `train_items`.
+        """
         with isolate_rng():
-            self.val_dataloader = self._build_validation_dataloader(val_split_mode,
-                                                                    split_proportion=val_split_proportion,
-                                                                    val_data_root=self.val_data_root,
-                                                                    train_batch=train_batch)
+            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
             # order is important - if we're removing images from train, this needs to happen before making
             # the overlapping dataloader
-            self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
-                                                                    split_proportion=stabilize_split_proportion,
-                                                                    name='train-stabilizer',
-                                                                    enforce_split=False) if stabilize_training_loss else None
+            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
+                remaining_train_items, tokenizer)
+            return remaining_train_items
 
     def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                      get_model_prediction_and_target_callable: Callable[
                                          [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         if (epoch % self.every_n_epochs) == 0:
             if self.train_overlapping_dataloader is not None:
-                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
+                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
+                                    get_model_prediction_and_target_callable)
             if self.val_dataloader is not None:
                 self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
 
     def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         with torch.no_grad(), isolate_rng():
             loss_validation_epoch = []
             steps_pbar = tqdm(range(len(dataloader)), position=1)
@@ -75,8 +104,6 @@ class EveryDreamValidator:
                 torch.manual_seed(self.seed + step)
                 model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
-                # del timesteps, encoder_hidden_states, noisy_latents
-                # with autocast(enabled=args.amp):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 del target, model_pred
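The `torch.manual_seed(self.seed + step)` call above pins the noise draw for each validation step, so successive validation passes differ only by model weights, not by sampling noise. A minimal demonstration of why:

    import torch

    seed, step = 555, 0
    torch.manual_seed(seed + step)
    a = torch.randn(4)
    torch.manual_seed(seed + step)
    b = torch.randn(4)
    print(torch.equal(a, b))  # True - re-seeding per step makes every pass see identical noise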
@@ -91,80 +118,56 @@ class EveryDreamValidator:
             loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
             self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
 
-    def _build_validation_dataloader(self,
-                                     validation_split_mode: str,
-                                     split_proportion: float,
-                                     val_data_root: Optional[str],
-                                     train_batch: EveryDreamBatch) -> Optional[DataLoader]:
-        if validation_split_mode == 'none':
-            return None
-        elif validation_split_mode == 'automatic':
-            return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
-        elif validation_split_mode == 'manual':
-            if val_data_root is None:
-                raise ValueError("val_data_root is required for 'manual' validation split mode")
-            return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
+    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
+            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
+        val_split_mode = self.config['val_split_mode']
+        val_split_proportion = self.config['val_split_proportion']
+        remaining_train_items = image_train_items
+        if val_split_mode == 'none':
+            return None, image_train_items
+        elif val_split_mode == 'automatic':
+            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
+        elif val_split_mode == 'manual':
+            args = Namespace(
+                aspects=aspects.get_aspect_buckets(512),
+                flip_p=0.0,
+                seed=self.seed,
+            )
+            val_data_root = self.config['val_data_root']
+            val_items = resolver.resolve_root(val_data_root, args)
         else:
-            raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
+            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
+        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
+        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
+        return val_dataloader, remaining_train_items
 
-    def _build_dataloader_from_automatic_split(self,
-                                               train_batch: EveryDreamBatch,
-                                               split_proportion: float,
-                                               name: str,
-                                               enforce_split: bool=False) -> DataLoader:
-        """
-        Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
-        the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
-        """
-        with isolate_rng():
-            random.seed(self.seed)
-            val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
-            if enforce_split:
-                print(
-                    f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
-            if len(train_batch) == 0:
-                raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
-            val_batch = self._make_val_batch_with_train_batch_settings(
-                val_items,
-                train_batch,
-                name=name
-            )
-            return build_torch_dataloader(
-                items=val_batch,
-                batch_size=train_batch.batch_size,
-            )
+    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
+            -> Optional[torch.utils.data.DataLoader]:
+        stabilize_training_loss = self.config['stabilize_training_loss']
+        if not stabilize_training_loss:
+            return None
 
-    def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
-        val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
-        return build_torch_dataloader(
-            items=val_batch,
-            batch_size=reference_train_batch.batch_size
-        )
+        stabilize_split_proportion = self.config['stabilize_split_proportion']
+        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
+        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
+                                                  name='stabilize-train')
+        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
+        return stabilize_dataloader
 
-    def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
-        batch_size = reference_train_batch.batch_size
-        seed = reference_train_batch.seed
-        args = Namespace(
-            aspects=aspects.get_aspect_buckets(512),
-            flip_p=0.0,
-            seed=seed,
-        )
-        image_train_items = resolver.resolve(data_root, args)
+    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
+        batch_size = self.batch_size
+        seed = self.seed
         data_loader = DataLoaderMultiAspect(
-            image_train_items,
+            items,
             batch_size=batch_size,
             seed=seed,
         )
-        val_batch = EveryDreamBatch(
+        ed_batch = EveryDreamBatch(
             data_loader=data_loader,
             debug_level=1,
+            batch_size=batch_size,
             conditional_dropout=0,
-            tokenizer=reference_train_batch.tokenizer,
+            tokenizer=tokenizer,
             seed=seed,
             name=name,
         )
-        return val_batch
+        return ed_batch

train.py

@@ -52,7 +52,8 @@ import wandb
 from torch.utils.tensorboard import SummaryWriter
 
 from data.data_loader import DataLoaderMultiAspect
-from data.every_dream import EveryDreamBatch
+from data.every_dream import EveryDreamBatch, build_torch_dataloader
+from data.every_dream_validation import EveryDreamValidator
 from data.image_train_item import ImageTrainItem
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
@@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
     return sample_prompts
 
-def collate_fn(batch):
-    """
-    Collates batches
-    """
-    images = [example["image"] for example in batch]
-    captions = [example["caption"] for example in batch]
-    tokens = [example["tokens"] for example in batch]
-    runt_size = batch[0]["runt_size"]
-
-    images = torch.stack(images)
-    images = images.to(memory_format=torch.contiguous_format).float()
-
-    ret = {
-        "tokens": torch.stack(tuple(tokens)),
-        "image": images,
-        "captions": captions,
-        "runt_size": runt_size,
-    }
-    del batch
-    return ret
-
 def main(args):
     """
     Main entry point
@@ -387,10 +365,13 @@ def main(args):
     seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
     logging.info(f" Seed: {seed}")
     set_seed(seed)
-    gpu = GPU()
-    device = torch.device(f"cuda:{args.gpuid}")
-    torch.backends.cudnn.benchmark = True
+    if torch.cuda.is_available():
+        gpu = GPU()
+        device = torch.device(f"cuda:{args.gpuid}")
+        torch.backends.cudnn.benchmark = True
+    else:
+        logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
+        device = 'cpu'
 
     log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@@ -606,6 +587,11 @@ def main(args):
         logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
 
+    log_writer = SummaryWriter(log_dir=log_folder,
+                               flush_secs=5,
+                               comment="EveryDream2FineTunes",
+                               )
+
     betas = (0.9, 0.999)
     epsilon = 1e-8
     if args.amp:
@@ -630,9 +616,14 @@ def main(args):
     )
 
     log_optimizer(optimizer, betas, epsilon)
 
     image_train_items = resolve_image_train_items(args, log_folder)
 
+    validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
+    # the validation dataset may need to steal some items from image_train_items
+    image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
+
     data_loader = DataLoaderMultiAspect(
         image_train_items=image_train_items,
         seed=seed,
@@ -668,12 +659,7 @@ def main(args):
     if args.wandb is not None and args.wandb:
         wandb.init(project=args.project_name, sync_tensorboard=True, )
 
-    log_writer = SummaryWriter(
-        log_dir=log_folder,
-        flush_secs=5,
-        comment="EveryDream2FineTunes",
-    )
-
     def log_args(log_writer, args):
         arglog = "args:\n"
@@ -729,15 +715,7 @@ def main(args):
         logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
         logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
 
-    train_dataloader = torch.utils.data.DataLoader(
-        train_batch,
-        batch_size=args.batch_size,
-        shuffle=False,
-        num_workers=4,
-        collate_fn=collate_fn,
-        pin_memory=True
-    )
+    train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
 
     unet.train() if not args.disable_unet_training else unet.eval()
     text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
@@ -775,7 +753,49 @@ def main(args):
     loss_log_step = []
 
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
 
+    # actual prediction function - shared between train and validate
+    def get_model_prediction_and_target(image, tokens):
+        with torch.no_grad():
+            with autocast(enabled=args.amp):
+                pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
+                latents = vae.encode(pixel_values, return_dict=False)
+            del pixel_values
+            latents = latents[0].sample() * 0.18215
+
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+
+            cuda_caption = tokens.to(text_encoder.device)
+
+        # with autocast(enabled=args.amp):
+        encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
+
+        if args.clip_skip > 0:
+            encoder_hidden_states = text_encoder.text_model.final_layer_norm(
+                encoder_hidden_states.hidden_states[-args.clip_skip])
+        else:
+            encoder_hidden_states = encoder_hidden_states.last_hidden_state
+
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        del noise, latents, cuda_caption
+
+        with autocast(enabled=args.amp):
+            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        return model_pred, target
+
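The training target depends on the scheduler's prediction type: "epsilon" models learn to predict the added noise itself, while v-prediction models learn a timestep-dependent mix of signal and noise. A small sketch of the same branch, assuming diffusers' DDPMScheduler and dummy latents:

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")
    latents = torch.randn(2, 4, 8, 8)  # dummy VAE latents
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (2,)).long()

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    if scheduler.config.prediction_type == "epsilon":
        target = noise
    else:  # v-prediction: velocity mixes signal and noise per timestep
        target = scheduler.get_velocity(latents, noise, timesteps)
    print(noisy_latents.shape, target.shape)  # both torch.Size([2, 4, 8, 8])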
     try:
         # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
         # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
@@ -809,41 +829,7 @@ def main(args):
             for step, batch in enumerate(train_dataloader):
                 step_start_time = time.time()
 
-                with torch.no_grad():
-                    with autocast(enabled=args.amp):
-                        pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
-                        latents = vae.encode(pixel_values, return_dict=False)
-                    del pixel_values
-                    latents = latents[0].sample() * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    cuda_caption = batch["tokens"].to(text_encoder.device)
-
-                #with autocast(enabled=args.amp):
-                encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
-
-                if args.clip_skip > 0:
-                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
-                else:
-                    encoder_hidden_states = encoder_hidden_states.last_hidden_state
-
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
-                elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                else:
-                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-                del noise, latents, cuda_caption
-
-                with autocast(enabled=args.amp):
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
                 #del timesteps, encoder_hidden_states, noisy_latents
                 #with autocast(enabled=args.amp):
@@ -952,6 +938,10 @@ def main(args):
             loss_local = sum(loss_epoch) / len(loss_epoch)
             log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
 
+            # validate
+            validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+
             gc.collect()
             # end of epoch


@@ -4,7 +4,7 @@
     "val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
     "val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
     "val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
-    "stabilize_training_loss": "If true, stabilise the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
+    "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
     "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
     "every_n_epochs": "How often to run validation (1=every epoch).",
     "seed": "The seed to use when running validation and stabilization passes."
@@ -13,7 +13,7 @@
     "val_split_mode": "automatic",
     "val_data_root": null,
     "val_split_proportion": 0.15,
-    "stabilize_training_loss": true,
+    "stabilize_training_loss": false,
     "stabilize_split_proportion": 0.15,
     "every_n_epochs": 1,
     "seed": 555