update EveryDreamValidator for noprompt's changes
This commit is contained in:
parent
41c9f36ed7
commit
29396ec21b
|
@ -169,24 +169,6 @@ class DataLoaderMultiAspect():
|
|||
|
||||
return picked_images
|
||||
|
||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
||||
item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
|
||||
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
||||
items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
|
||||
random.shuffle(items_copy)
|
||||
split_items = items_copy[:item_count]
|
||||
if remove_from_dataset:
|
||||
self.delete_items(split_items)
|
||||
return split_items
|
||||
|
||||
def delete_items(self, items: list[ImageTrainItem]):
|
||||
for item in items:
|
||||
for i, other_item in enumerate(self.prepared_train_data):
|
||||
if other_item.pathname == item.pathname:
|
||||
self.prepared_train_data.pop(i)
|
||||
break
|
||||
self.__update_rating_sums()
|
||||
|
||||
def __update_rating_sums(self):
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
|
|
|
@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset):
|
|||
def __update_image_train_items(self, dropout_fraction: float):
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||
|
||||
def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
items,
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
num_workers=4,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return dataloader
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import math
|
||||
import random
|
||||
from typing import Callable, Any, Optional
|
||||
from argparse import Namespace
|
||||
|
@ -14,55 +15,65 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch
|
|||
from data.data_loader import DataLoaderMultiAspect
|
||||
from data import resolver
|
||||
from data import aspects
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
||||
|
||||
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
|
||||
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
|
||||
split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
|
||||
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
||||
items_copy = list(sorted(items, key=lambda i: i.pathname))
|
||||
random.shuffle(items_copy)
|
||||
split_items = list(items_copy[:split_item_count])
|
||||
remaining_items = list(items_copy[split_item_count:])
|
||||
return split_items, remaining_items
|
||||
|
||||
|
||||
class EveryDreamValidator:
|
||||
def __init__(self,
|
||||
val_config_path: Optional[str],
|
||||
train_batch: EveryDreamBatch,
|
||||
default_batch_size: int,
|
||||
log_writer: SummaryWriter):
|
||||
self.val_dataloader = None
|
||||
self.train_overlapping_dataloader = None
|
||||
|
||||
self.log_writer = log_writer
|
||||
|
||||
val_config = {}
|
||||
self.config = {}
|
||||
if val_config_path is not None:
|
||||
with open(val_config_path, 'rt') as f:
|
||||
val_config = json.load(f)
|
||||
self.config = json.load(f)
|
||||
|
||||
do_validation = val_config.get('validate_training', False)
|
||||
val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
|
||||
self.val_data_root = val_config.get('val_data_root', None)
|
||||
val_split_proportion = val_config.get('val_split_proportion', 0.15)
|
||||
|
||||
stabilize_training_loss = val_config.get('stabilize_training_loss', False)
|
||||
stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
|
||||
|
||||
self.every_n_epochs = val_config.get('every_n_epochs', 1)
|
||||
self.seed = val_config.get('seed', 555)
|
||||
self.batch_size = self.config.get('batch_size', default_batch_size)
|
||||
self.every_n_epochs = self.config.get('every_n_epochs', 1)
|
||||
self.seed = self.config.get('seed', 555)
|
||||
self.val_data_root = self.config.get('val_data_root', None)
|
||||
|
||||
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Build the validation splits as requested by the config passed at init.
|
||||
This may steal some items from `train_items`.
|
||||
If this happens, the returned `list` contains the remaining items after the required items have been stolen.
|
||||
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
||||
"""
|
||||
with isolate_rng():
|
||||
self.val_dataloader = self._build_validation_dataloader(val_split_mode,
|
||||
split_proportion=val_split_proportion,
|
||||
val_data_root=self.val_data_root,
|
||||
train_batch=train_batch)
|
||||
self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer)
|
||||
# order is important - if we're removing images from train, this needs to happen before making
|
||||
# the overlapping dataloader
|
||||
self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
|
||||
split_proportion=stabilize_split_proportion,
|
||||
name='train-stabilizer',
|
||||
enforce_split=False) if stabilize_training_loss else None
|
||||
|
||||
self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer)
|
||||
return remaining_train_items
|
||||
|
||||
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
||||
get_model_prediction_and_target_callable: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
if (epoch % self.every_n_epochs) == 0:
|
||||
if self.train_overlapping_dataloader is not None:
|
||||
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
|
||||
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.val_dataloader is not None:
|
||||
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
||||
|
||||
|
||||
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
with torch.no_grad(), isolate_rng():
|
||||
|
@ -75,8 +86,6 @@ class EveryDreamValidator:
|
|||
torch.manual_seed(self.seed + step)
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
||||
|
||||
# del timesteps, encoder_hidden_states, noisy_latents
|
||||
# with autocast(enabled=args.amp):
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
del target, model_pred
|
||||
|
@ -91,80 +100,55 @@ class EveryDreamValidator:
|
|||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||
|
||||
|
||||
def _build_validation_dataloader(self,
|
||||
validation_split_mode: str,
|
||||
split_proportion: float,
|
||||
val_data_root: Optional[str],
|
||||
train_batch: EveryDreamBatch) -> Optional[DataLoader]:
|
||||
if validation_split_mode == 'none':
|
||||
return None
|
||||
elif validation_split_mode == 'automatic':
|
||||
return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
|
||||
elif validation_split_mode == 'manual':
|
||||
if val_data_root is None:
|
||||
raise ValueError("val_data_root is required for 'manual' validation split mode")
|
||||
return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
|
||||
else:
|
||||
raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
|
||||
|
||||
|
||||
def _build_dataloader_from_automatic_split(self,
|
||||
train_batch: EveryDreamBatch,
|
||||
split_proportion: float,
|
||||
name: str,
|
||||
enforce_split: bool=False) -> DataLoader:
|
||||
"""
|
||||
Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
|
||||
the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
|
||||
"""
|
||||
with isolate_rng():
|
||||
random.seed(self.seed)
|
||||
val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
|
||||
if enforce_split:
|
||||
print(
|
||||
f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
|
||||
if len(train_batch) == 0:
|
||||
raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
|
||||
val_batch = self._make_val_batch_with_train_batch_settings(
|
||||
val_items,
|
||||
train_batch,
|
||||
name=name
|
||||
)
|
||||
return build_torch_dataloader(
|
||||
items=val_batch,
|
||||
batch_size=train_batch.batch_size,
|
||||
)
|
||||
|
||||
|
||||
def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
|
||||
val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
|
||||
return build_torch_dataloader(
|
||||
items=val_batch,
|
||||
batch_size=reference_train_batch.batch_size
|
||||
)
|
||||
|
||||
def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
|
||||
batch_size = reference_train_batch.batch_size
|
||||
seed = reference_train_batch.seed
|
||||
def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||
val_split_mode = self.config.get('val_split_mode', 'automatic')
|
||||
val_split_proportion = self.config.get('val_split_proportion', 0.15)
|
||||
remaining_train_items = image_train_items
|
||||
if val_split_mode == 'none':
|
||||
return None, image_train_items
|
||||
elif val_split_mode == 'automatic':
|
||||
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
||||
elif val_split_mode == 'manual':
|
||||
args = Namespace(
|
||||
aspects=aspects.get_aspect_buckets(512),
|
||||
flip_p=0.0,
|
||||
seed=seed,
|
||||
seed=self.seed,
|
||||
)
|
||||
image_train_items = resolver.resolve(data_root, args)
|
||||
val_items = resolver.resolve_root(self.val_data_root, args)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
||||
return val_dataloader, remaining_train_items
|
||||
|
||||
def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||
-> Optional[torch.utils.data.DataLoader]:
|
||||
stabilize_training_loss = self.config.get('stabilize_training_loss', False)
|
||||
if not stabilize_training_loss:
|
||||
return None
|
||||
|
||||
stabilize_split_proportion = self.config.get('stabilize_split_proportion', 0.15)
|
||||
stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
||||
stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
||||
name='stabilize-train')
|
||||
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||
return stabilize_dataloader
|
||||
|
||||
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
|
||||
batch_size = self.batch_size
|
||||
seed = self.seed
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items,
|
||||
items,
|
||||
batch_size=batch_size,
|
||||
seed=seed,
|
||||
)
|
||||
val_batch = EveryDreamBatch(
|
||||
ed_batch = EveryDreamBatch(
|
||||
data_loader=data_loader,
|
||||
debug_level=1,
|
||||
batch_size=batch_size,
|
||||
conditional_dropout=0,
|
||||
tokenizer=reference_train_batch.tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
seed=seed,
|
||||
name=name,
|
||||
)
|
||||
return val_batch
|
||||
return ed_batch
|
||||
|
|
138
train.py
138
train.py
|
@ -52,7 +52,8 @@ import wandb
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
|
||||
from data.every_dream import EveryDreamBatch
|
||||
from data.every_dream import EveryDreamBatch, build_torch_dataloader
|
||||
from data.every_dream_validation import EveryDreamValidator
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
|
@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
|
|||
return sample_prompts
|
||||
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collates batches
|
||||
"""
|
||||
images = [example["image"] for example in batch]
|
||||
captions = [example["caption"] for example in batch]
|
||||
tokens = [example["tokens"] for example in batch]
|
||||
runt_size = batch[0]["runt_size"]
|
||||
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
ret = {
|
||||
"tokens": torch.stack(tuple(tokens)),
|
||||
"image": images,
|
||||
"captions": captions,
|
||||
"runt_size": runt_size,
|
||||
}
|
||||
del batch
|
||||
return ret
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Main entry point
|
||||
|
@ -387,10 +365,13 @@ def main(args):
|
|||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||
logging.info(f" Seed: {seed}")
|
||||
set_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
gpu = GPU()
|
||||
device = torch.device(f"cuda:{args.gpuid}")
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||
device = 'cpu'
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
|
||||
|
@ -606,6 +587,11 @@ def main(args):
|
|||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
|
||||
log_writer = SummaryWriter(log_dir=log_folder,
|
||||
flush_secs=5,
|
||||
comment="EveryDream2FineTunes",
|
||||
)
|
||||
|
||||
betas = (0.9, 0.999)
|
||||
epsilon = 1e-8
|
||||
if args.amp:
|
||||
|
@ -631,8 +617,13 @@ def main(args):
|
|||
|
||||
log_optimizer(optimizer, betas, epsilon)
|
||||
|
||||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||
# the validation dataset may need to steal some items from image_train_items
|
||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
|
@ -669,11 +660,6 @@ def main(args):
|
|||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||
|
||||
log_writer = SummaryWriter(
|
||||
log_dir=log_folder,
|
||||
flush_secs=5,
|
||||
comment="EveryDream2FineTunes",
|
||||
)
|
||||
|
||||
def log_args(log_writer, args):
|
||||
arglog = "args:\n"
|
||||
|
@ -729,15 +715,7 @@ def main(args):
|
|||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_batch,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=True
|
||||
)
|
||||
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
|
||||
|
||||
unet.train() if not args.disable_unet_training else unet.eval()
|
||||
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
||||
|
@ -776,6 +754,48 @@ def main(args):
|
|||
|
||||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||
|
||||
# actual prediction function - shared between train and validate
|
||||
def get_model_prediction_and_target(image, tokens):
|
||||
with torch.no_grad():
|
||||
with autocast(enabled=args.amp):
|
||||
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
latents = vae.encode(pixel_values, return_dict=False)
|
||||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
cuda_caption = tokens.to(text_encoder.device)
|
||||
|
||||
# with autocast(enabled=args.amp):
|
||||
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
|
||||
|
||||
if args.clip_skip > 0:
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(
|
||||
encoder_hidden_states.hidden_states[-args.clip_skip])
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
del noise, latents, cuda_caption
|
||||
|
||||
with autocast(enabled=args.amp):
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
return model_pred, target
|
||||
|
||||
|
||||
try:
|
||||
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
|
||||
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
|
||||
|
@ -809,41 +829,7 @@ def main(args):
|
|||
for step, batch in enumerate(train_dataloader):
|
||||
step_start_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
with autocast(enabled=args.amp):
|
||||
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
latents = vae.encode(pixel_values, return_dict=False)
|
||||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
cuda_caption = batch["tokens"].to(text_encoder.device)
|
||||
|
||||
#with autocast(enabled=args.amp):
|
||||
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
|
||||
|
||||
if args.clip_skip > 0:
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
del noise, latents, cuda_caption
|
||||
|
||||
with autocast(enabled=args.amp):
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
||||
|
||||
#del timesteps, encoder_hidden_states, noisy_latents
|
||||
#with autocast(enabled=args.amp):
|
||||
|
@ -952,6 +938,10 @@ def main(args):
|
|||
|
||||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
# validate
|
||||
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
||||
|
|
Loading…
Reference in New Issue