update EveryDreamValidator for noprompt's changes

This commit is contained in:
damian 2023-02-07 17:32:54 +01:00
parent 41c9f36ed7
commit 29396ec21b
4 changed files with 148 additions and 192 deletions

View File

@ -169,24 +169,6 @@ class DataLoaderMultiAspect():
return picked_images
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
    """Pick a random, batch-aligned subset of the prepared training data.

    The data is sorted by pathname before shuffling so the outcome is
    deterministic for the current random state. When `remove_from_dataset`
    is True, the picked items are also deleted from this loader's dataset.
    """
    item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
    # deterministic base order, then shuffle under the caller-controlled RNG state
    shuffled = sorted(self.prepared_train_data, key=lambda entry: entry.pathname)
    random.shuffle(shuffled)
    picked = shuffled[:item_count]
    if remove_from_dataset:
        self.delete_items(picked)
    return picked
def delete_items(self, items: list[ImageTrainItem]):
    """Remove `items` from the prepared training data, matching by pathname.

    For each entry in `items`, at most one matching element is removed
    (earliest occurrence first), mirroring the previous first-match-and-pop
    behavior, but in a single O(n + m) pass instead of one O(n) rescan per
    item. Rating sums are recomputed afterwards.
    """
    from collections import Counter  # local import keeps this change self-contained
    remaining_removals = Counter(item.pathname for item in items)
    kept = []
    for candidate in self.prepared_train_data:
        if remaining_removals.get(candidate.pathname, 0) > 0:
            # consume one removal request for this pathname
            remaining_removals[candidate.pathname] -= 1
        else:
            kept.append(candidate)
    # mutate in place so any external references to the list stay valid
    self.prepared_train_data[:] = kept
    self.__update_rating_sums()
def __update_rating_sums(self):
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []

View File

@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset):
# Refresh self.image_train_items from the data loader's shuffled buckets;
# dropout_fraction is forwarded to get_shuffled_image_buckets (its exact
# semantics are defined there — presumably a conditional-dropout fraction).
def __update_image_train_items(self, dropout_fraction: float):
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
    """Wrap a dataset in the project's standard DataLoader configuration.

    NOTE: this span interleaved the pre-change (`items`, num_workers=0) and
    post-change (`dataset`, num_workers=4) lines of the diff; this is the
    post-change version. shuffle is False — presumably because ordering and
    bucketing are handled upstream by the data loader; TODO confirm.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn
    )
    return dataloader
@ -173,4 +173,4 @@ def collate_fn(batch):
"runt_size": runt_size,
}
del batch
return ret
return ret

View File

@ -1,4 +1,5 @@
import json
import math
import random
from typing import Callable, Any, Optional
from argparse import Namespace
@ -14,57 +15,67 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
    """Deterministically (for the current random state) split `items` in two.

    Returns `(split_items, remaining_items)`, where `split_items` holds
    `split_proportion` of the input rounded DOWN to a whole number of
    batches, and `remaining_items` holds everything else.
    """
    # NOTE(review): `x // batch_size` already floors, so the previous
    # math.ceil() around it was a no-op; the count rounds DOWN to a whole
    # batch. If rounding UP was intended, use math.ceil(x / batch_size).
    split_item_count = int(split_proportion * len(items) // batch_size) * batch_size
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    items_copy = sorted(items, key=lambda i: i.pathname)
    random.shuffle(items_copy)
    split_items = items_copy[:split_item_count]
    remaining_items = items_copy[split_item_count:]
    return split_items, remaining_items
class EveryDreamValidator:
def __init__(self,
             val_config_path: Optional[str],
             default_batch_size: int,
             log_writer: SummaryWriter):
    """Read the validation config (if any) and stash the common settings.

    NOTE: this span interleaved the removed (val_config local, train_batch
    param) and added (self.config, default_batch_size param) diff lines;
    this is the post-change version, matching the new call site in
    train.py: EveryDreamValidator(path, log_writer=..., default_batch_size=...).

    Dataloaders are built later by prepare_validation_splits().
    """
    self.val_dataloader = None
    self.train_overlapping_dataloader = None
    self.log_writer = log_writer

    # keep the whole config around so the _build_* helpers can read
    # their own settings on demand
    self.config = {}
    if val_config_path is not None:
        with open(val_config_path, 'rt') as f:
            self.config = json.load(f)

    self.batch_size = self.config.get('batch_size', default_batch_size)
    self.every_n_epochs = self.config.get('every_n_epochs', 1)
    self.seed = self.config.get('seed', 555)
    self.val_data_root = self.config.get('val_data_root', None)
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
    """
    Build the validation splits as requested by the config passed at init.
    This may steal some items from `train_items`.
    If this happens, the returned `list` contains the remaining items after the required items have been stolen.
    Otherwise, the returned `list` is identical to the passed-in `train_items`.

    NOTE: this span interleaved the removed _build_validation_dataloader /
    _build_dataloader_from_automatic_split calls with the new helper calls;
    this is the post-change version.
    """
    with isolate_rng():
        self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer)
        # order is important - if we're removing images from train, this needs to happen before making
        # the overlapping dataloader
        self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer)
        return remaining_train_items
def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                 get_model_prediction_and_target_callable: Callable[
                                     [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
    """Run the stabilize-train and/or val passes every `every_n_epochs` epochs.

    Either dataloader may be None (feature disabled), in which case the
    corresponding pass is skipped. (This span carried duplicate old/new
    wrappings of the same call from the diff; deduplicated here.)
    """
    if (epoch % self.every_n_epochs) == 0:
        if self.train_overlapping_dataloader is not None:
            self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
                                get_model_prediction_and_target_callable)
        if self.val_dataloader is not None:
            self._do_validation('val', global_step, self.val_dataloader,
                                get_model_prediction_and_target_callable)
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
with torch.no_grad(), isolate_rng():
loss_validation_epoch = []
steps_pbar = tqdm(range(len(dataloader)), position=1)
@ -75,8 +86,6 @@ class EveryDreamValidator:
torch.manual_seed(self.seed + step)
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
# del timesteps, encoder_hidden_states, noisy_latents
# with autocast(enabled=args.amp):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
@ -91,80 +100,55 @@ class EveryDreamValidator:
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
        -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
    """Build the 'val' dataloader per config; returns (dataloader, remaining_train_items).

    NOTE: this span interleaved the removed _build_validation_dataloader with
    the new _build_val_dataloader; this is the post-change version. The old
    version's guard that 'manual' mode requires val_data_root was dropped in
    the rewrite — restored here, since resolve_root on None would otherwise
    fail less clearly.
    """
    val_split_mode = self.config.get('val_split_mode', 'automatic')
    val_split_proportion = self.config.get('val_split_proportion', 0.15)
    remaining_train_items = image_train_items
    if val_split_mode == 'none':
        return None, image_train_items
    elif val_split_mode == 'automatic':
        # steal a batch-aligned subset from the training items
        val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion,
                                                            batch_size=self.batch_size)
    elif val_split_mode == 'manual':
        if self.val_data_root is None:
            raise ValueError("val_data_root is required for 'manual' validation split mode")
        args = Namespace(
            aspects=aspects.get_aspect_buckets(512),
            flip_p=0.0,
            seed=self.seed,
        )
        val_items = resolver.resolve_root(self.val_data_root, args)
    else:
        raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
    val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
    val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
    return val_dataloader, remaining_train_items
def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
        -> Optional[torch.utils.data.DataLoader]:
    """Build the 'stabilize-train' dataloader, or None if disabled in config.

    Unlike the val split, the picked items are NOT removed from the training
    set — this dataloader deliberately overlaps with training data.
    (This span interleaved the removed _build_dataloader_from_automatic_split
    with the new helper; this is the post-change version.)
    """
    stabilize_training_loss = self.config.get('stabilize_training_loss', False)
    if not stabilize_training_loss:
        return None

    stabilize_split_proportion = self.config.get('stabilize_split_proportion', 0.15)
    stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
    stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                              name='stabilize-train')
    stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
    return stabilize_dataloader
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
    """Wrap `items` in a DataLoaderMultiAspect + EveryDreamBatch pair.

    NOTE: this span interleaved the removed _build_dataloader_from_custom_split
    and _make_val_batch_with_train_batch_settings helpers with the new
    _build_ed_batch; this is the post-change version.

    NOTE(review): the `batch_size` parameter is immediately shadowed by
    self.batch_size (callers pass self.batch_size anyway) — consider
    dropping the parameter. Behavior kept as-is.
    """
    batch_size = self.batch_size
    seed = self.seed
    data_loader = DataLoaderMultiAspect(
        items,
        batch_size=batch_size,
        seed=seed,
    )
    ed_batch = EveryDreamBatch(
        data_loader=data_loader,
        debug_level=1,
        batch_size=batch_size,
        conditional_dropout=0,
        tokenizer=tokenizer,
        seed=seed,
        name=name,
    )
    return ed_batch

150
train.py
View File

@ -52,7 +52,8 @@ import wandb
from torch.utils.tensorboard import SummaryWriter
from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch
from data.every_dream import EveryDreamBatch, build_torch_dataloader
from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
return sample_prompts
def collate_fn(batch):
    """Collate a list of per-example dicts into one batched dict.

    Stacks images (as contiguous float tensors) and token tensors; captions
    stay a plain list. `runt_size` is taken from the first example.
    """
    stacked_images = torch.stack([sample["image"] for sample in batch])
    stacked_images = stacked_images.to(memory_format=torch.contiguous_format).float()
    stacked_tokens = torch.stack(tuple(sample["tokens"] for sample in batch))
    collated = {
        "tokens": stacked_tokens,
        "image": stacked_images,
        "captions": [sample["caption"] for sample in batch],
        "runt_size": batch[0]["runt_size"],
    }
    del batch
    return collated
def main(args):
"""
Main entry point
@ -387,10 +365,13 @@ def main(args):
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
logging.info(f" Seed: {seed}")
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
else:
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
device = 'cpu'
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@ -606,6 +587,11 @@ def main(args):
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
)
betas = (0.9, 0.999)
epsilon = 1e-8
if args.amp:
@ -630,9 +616,14 @@ def main(args):
)
log_optimizer(optimizer, betas, epsilon)
image_train_items = resolve_image_train_items(args, log_folder)
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
# the validation dataset may need to steal some items from image_train_items
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items,
seed=seed,
@ -668,12 +659,7 @@ def main(args):
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, )
log_writer = SummaryWriter(
log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
)
def log_args(log_writer, args):
arglog = "args:\n"
@ -729,15 +715,7 @@ def main(args):
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
train_dataloader = torch.utils.data.DataLoader(
train_batch,
batch_size=args.batch_size,
shuffle=False,
num_workers=4,
collate_fn=collate_fn,
pin_memory=True
)
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
unet.train() if not args.disable_unet_training else unet.eval()
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
@ -775,7 +753,49 @@ def main(args):
loss_log_step = []
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
# Shared prediction function for the training loop and the validator:
# encodes images to latents with the VAE, adds scheduler noise at random
# timesteps, runs the text encoder (with optional clip-skip) and the UNet,
# and returns (model_pred, target) ready for MSE loss.
# Closes over main()'s locals: args, unet, vae, text_encoder, noise_scheduler.
def get_model_prediction_and_target(image, tokens):
# VAE encoding and noise setup run without gradients; only the UNet
# forward at the bottom (under autocast) participates in backprop.
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
latents = vae.encode(pixel_values, return_dict=False)
del pixel_values
# 0.18215: the latent scaling factor conventionally used with the
# Stable Diffusion VAE — presumably matches the loaded checkpoint.
latents = latents[0].sample() * 0.18215
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# one random diffusion timestep per batch element
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = tokens.to(text_encoder.device)
# with autocast(enabled=args.amp):
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
# clip_skip > 0: take an earlier hidden state and re-apply the final layer norm
if args.clip_skip > 0:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(
encoder_hidden_states.hidden_states[-args.clip_skip])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# the regression target depends on the scheduler's prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# free large intermediates before the UNet forward
del noise, latents, cuda_caption
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
return model_pred, target
try:
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
@ -809,41 +829,7 @@ def main(args):
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
latents = vae.encode(pixel_values, return_dict=False)
del pixel_values
latents = latents[0].sample() * 0.18215
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = batch["tokens"].to(text_encoder.device)
#with autocast(enabled=args.amp):
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
if args.clip_skip > 0:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
#del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
@ -952,6 +938,10 @@ def main(args):
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# validate
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch