Merge pull request #1 from damian0815/add-support-for-val-split
update & cleanup EveryDreamValidator
This commit is contained in:
commit
fc79413224
|
@ -169,24 +169,6 @@ class DataLoaderMultiAspect():
|
||||||
|
|
||||||
return picked_images
|
return picked_images
|
||||||
|
|
||||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
|
||||||
item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
|
|
||||||
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
|
||||||
items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
|
|
||||||
random.shuffle(items_copy)
|
|
||||||
split_items = items_copy[:item_count]
|
|
||||||
if remove_from_dataset:
|
|
||||||
self.delete_items(split_items)
|
|
||||||
return split_items
|
|
||||||
|
|
||||||
def delete_items(self, items: list[ImageTrainItem]):
|
|
||||||
for item in items:
|
|
||||||
for i, other_item in enumerate(self.prepared_train_data):
|
|
||||||
if other_item.pathname == item.pathname:
|
|
||||||
self.prepared_train_data.pop(i)
|
|
||||||
break
|
|
||||||
self.__update_rating_sums()
|
|
||||||
|
|
||||||
def __update_rating_sums(self):
|
def __update_rating_sums(self):
|
||||||
self.rating_overall_sum: float = 0.0
|
self.rating_overall_sum: float = 0.0
|
||||||
self.ratings_summed: list[float] = []
|
self.ratings_summed: list[float] = []
|
||||||
|
|
|
@ -63,7 +63,7 @@ class EveryDreamBatch(Dataset):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
num_images = len(self.image_train_items)
|
num_images = len(self.image_train_items)
|
||||||
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
|
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||||
|
|
||||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
||||||
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
|
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
|
||||||
|
@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset):
|
||||||
def __update_image_train_items(self, dropout_fraction: float):
|
def __update_image_train_items(self, dropout_fraction: float):
|
||||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||||
|
|
||||||
def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
|
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
items,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=0,
|
num_workers=4,
|
||||||
collate_fn=collate_fn
|
collate_fn=collate_fn
|
||||||
)
|
)
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Callable, Any, Optional
|
from typing import Callable, Any, Optional
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
@ -14,57 +15,85 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch
|
||||||
from data.data_loader import DataLoaderMultiAspect
|
from data.data_loader import DataLoaderMultiAspect
|
||||||
from data import resolver
|
from data import resolver
|
||||||
from data import aspects
|
from data import aspects
|
||||||
|
from data.image_train_item import ImageTrainItem
|
||||||
from utils.isolate_rng import isolate_rng
|
from utils.isolate_rng import isolate_rng
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
|
||||||
|
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
|
||||||
|
split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
|
||||||
|
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
||||||
|
items_copy = list(sorted(items, key=lambda i: i.pathname))
|
||||||
|
random.shuffle(items_copy)
|
||||||
|
split_items = list(items_copy[:split_item_count])
|
||||||
|
remaining_items = list(items_copy[split_item_count:])
|
||||||
|
return split_items, remaining_items
|
||||||
|
|
||||||
|
|
||||||
class EveryDreamValidator:
|
class EveryDreamValidator:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
val_config_path: Optional[str],
|
val_config_path: Optional[str],
|
||||||
train_batch: EveryDreamBatch,
|
default_batch_size: int,
|
||||||
log_writer: SummaryWriter):
|
log_writer: SummaryWriter):
|
||||||
|
self.val_dataloader = None
|
||||||
|
self.train_overlapping_dataloader = None
|
||||||
|
|
||||||
self.log_writer = log_writer
|
self.log_writer = log_writer
|
||||||
|
|
||||||
val_config = {}
|
self.config = {
|
||||||
|
'batch_size': default_batch_size,
|
||||||
|
'every_n_epochs': 1,
|
||||||
|
'seed': 555,
|
||||||
|
|
||||||
|
'val_split_mode': 'automatic',
|
||||||
|
'val_split_proportion': 0.15,
|
||||||
|
|
||||||
|
'stabilize_training_loss': False,
|
||||||
|
'stabilize_split_proportion': 0.15
|
||||||
|
}
|
||||||
if val_config_path is not None:
|
if val_config_path is not None:
|
||||||
with open(val_config_path, 'rt') as f:
|
with open(val_config_path, 'rt') as f:
|
||||||
val_config = json.load(f)
|
self.config.update(json.load(f))
|
||||||
|
|
||||||
do_validation = val_config.get('validate_training', False)
|
@property
|
||||||
val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
|
def batch_size(self):
|
||||||
self.val_data_root = val_config.get('val_data_root', None)
|
return self.config['batch_size']
|
||||||
val_split_proportion = val_config.get('val_split_proportion', 0.15)
|
|
||||||
|
|
||||||
stabilize_training_loss = val_config.get('stabilize_training_loss', False)
|
@property
|
||||||
stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
|
def every_n_epochs(self):
|
||||||
|
return self.config['every_n_epochs']
|
||||||
|
|
||||||
self.every_n_epochs = val_config.get('every_n_epochs', 1)
|
@property
|
||||||
self.seed = val_config.get('seed', 555)
|
def seed(self):
|
||||||
|
return self.config['seed']
|
||||||
|
|
||||||
|
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
|
||||||
|
"""
|
||||||
|
Build the validation splits as requested by the config passed at init.
|
||||||
|
This may steal some items from `train_items`.
|
||||||
|
If this happens, the returned `list` contains the remaining items after the required items have been stolen.
|
||||||
|
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
||||||
|
"""
|
||||||
with isolate_rng():
|
with isolate_rng():
|
||||||
self.val_dataloader = self._build_validation_dataloader(val_split_mode,
|
self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
|
||||||
split_proportion=val_split_proportion,
|
|
||||||
val_data_root=self.val_data_root,
|
|
||||||
train_batch=train_batch)
|
|
||||||
# order is important - if we're removing images from train, this needs to happen before making
|
# order is important - if we're removing images from train, this needs to happen before making
|
||||||
# the overlapping dataloader
|
# the overlapping dataloader
|
||||||
self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
|
self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
|
||||||
split_proportion=stabilize_split_proportion,
|
remaining_train_items, tokenizer)
|
||||||
name='train-stabilizer',
|
return remaining_train_items
|
||||||
enforce_split=False) if stabilize_training_loss else None
|
|
||||||
|
|
||||||
|
|
||||||
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
||||||
get_model_prediction_and_target_callable: Callable[
|
get_model_prediction_and_target_callable: Callable[
|
||||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||||
if (epoch % self.every_n_epochs) == 0:
|
if (epoch % self.every_n_epochs) == 0:
|
||||||
if self.train_overlapping_dataloader is not None:
|
if self.train_overlapping_dataloader is not None:
|
||||||
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
|
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
|
||||||
|
get_model_prediction_and_target_callable)
|
||||||
if self.val_dataloader is not None:
|
if self.val_dataloader is not None:
|
||||||
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
||||||
|
|
||||||
|
|
||||||
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
|
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
|
||||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||||
with torch.no_grad(), isolate_rng():
|
with torch.no_grad(), isolate_rng():
|
||||||
loss_validation_epoch = []
|
loss_validation_epoch = []
|
||||||
steps_pbar = tqdm(range(len(dataloader)), position=1)
|
steps_pbar = tqdm(range(len(dataloader)), position=1)
|
||||||
|
@ -75,8 +104,6 @@ class EveryDreamValidator:
|
||||||
torch.manual_seed(self.seed + step)
|
torch.manual_seed(self.seed + step)
|
||||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
||||||
|
|
||||||
# del timesteps, encoder_hidden_states, noisy_latents
|
|
||||||
# with autocast(enabled=args.amp):
|
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
del target, model_pred
|
del target, model_pred
|
||||||
|
@ -91,80 +118,56 @@ class EveryDreamValidator:
|
||||||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||||
|
|
||||||
|
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||||
def _build_validation_dataloader(self,
|
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||||
validation_split_mode: str,
|
val_split_mode = self.config['val_split_mode']
|
||||||
split_proportion: float,
|
val_split_proportion = self.config['val_split_proportion']
|
||||||
val_data_root: Optional[str],
|
remaining_train_items = image_train_items
|
||||||
train_batch: EveryDreamBatch) -> Optional[DataLoader]:
|
if val_split_mode == 'none':
|
||||||
if validation_split_mode == 'none':
|
return None, image_train_items
|
||||||
return None
|
elif val_split_mode == 'automatic':
|
||||||
elif validation_split_mode == 'automatic':
|
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
||||||
return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
|
elif val_split_mode == 'manual':
|
||||||
elif validation_split_mode == 'manual':
|
args = Namespace(
|
||||||
if val_data_root is None:
|
aspects=aspects.get_aspect_buckets(512),
|
||||||
raise ValueError("val_data_root is required for 'manual' validation split mode")
|
flip_p=0.0,
|
||||||
return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
|
seed=self.seed,
|
||||||
|
)
|
||||||
|
val_data_root = self.config['val_data_root']
|
||||||
|
val_items = resolver.resolve_root(val_data_root, args)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
|
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||||
|
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||||
|
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
||||||
|
return val_dataloader, remaining_train_items
|
||||||
|
|
||||||
|
def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||||
|
-> Optional[torch.utils.data.DataLoader]:
|
||||||
|
stabilize_training_loss = self.config['stabilize_training_loss']
|
||||||
|
if not stabilize_training_loss:
|
||||||
|
return None
|
||||||
|
|
||||||
def _build_dataloader_from_automatic_split(self,
|
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
||||||
train_batch: EveryDreamBatch,
|
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
||||||
split_proportion: float,
|
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
||||||
name: str,
|
name='stabilize-train')
|
||||||
enforce_split: bool=False) -> DataLoader:
|
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||||
"""
|
return stabilize_dataloader
|
||||||
Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
|
|
||||||
the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
|
|
||||||
"""
|
|
||||||
with isolate_rng():
|
|
||||||
random.seed(self.seed)
|
|
||||||
val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
|
|
||||||
if enforce_split:
|
|
||||||
print(
|
|
||||||
f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
|
|
||||||
if len(train_batch) == 0:
|
|
||||||
raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
|
|
||||||
val_batch = self._make_val_batch_with_train_batch_settings(
|
|
||||||
val_items,
|
|
||||||
train_batch,
|
|
||||||
name=name
|
|
||||||
)
|
|
||||||
return build_torch_dataloader(
|
|
||||||
items=val_batch,
|
|
||||||
batch_size=train_batch.batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
|
||||||
def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
|
batch_size = self.batch_size
|
||||||
val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
|
seed = self.seed
|
||||||
return build_torch_dataloader(
|
|
||||||
items=val_batch,
|
|
||||||
batch_size=reference_train_batch.batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
|
|
||||||
batch_size = reference_train_batch.batch_size
|
|
||||||
seed = reference_train_batch.seed
|
|
||||||
args = Namespace(
|
|
||||||
aspects=aspects.get_aspect_buckets(512),
|
|
||||||
flip_p=0.0,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
image_train_items = resolver.resolve(data_root, args)
|
|
||||||
data_loader = DataLoaderMultiAspect(
|
data_loader = DataLoaderMultiAspect(
|
||||||
image_train_items,
|
items,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
val_batch = EveryDreamBatch(
|
ed_batch = EveryDreamBatch(
|
||||||
data_loader=data_loader,
|
data_loader=data_loader,
|
||||||
debug_level=1,
|
debug_level=1,
|
||||||
batch_size=batch_size,
|
|
||||||
conditional_dropout=0,
|
conditional_dropout=0,
|
||||||
tokenizer=reference_train_batch.tokenizer,
|
tokenizer=tokenizer,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
return val_batch
|
return ed_batch
|
||||||
|
|
144
train.py
144
train.py
|
@ -52,7 +52,8 @@ import wandb
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from data.data_loader import DataLoaderMultiAspect
|
from data.data_loader import DataLoaderMultiAspect
|
||||||
|
|
||||||
from data.every_dream import EveryDreamBatch
|
from data.every_dream import EveryDreamBatch, build_torch_dataloader
|
||||||
|
from data.every_dream_validation import EveryDreamValidator
|
||||||
from data.image_train_item import ImageTrainItem
|
from data.image_train_item import ImageTrainItem
|
||||||
from utils.huggingface_downloader import try_download_model_from_hf
|
from utils.huggingface_downloader import try_download_model_from_hf
|
||||||
from utils.convert_diff_to_ckpt import convert as converter
|
from utils.convert_diff_to_ckpt import convert as converter
|
||||||
|
@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
|
||||||
return sample_prompts
|
return sample_prompts
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(batch):
|
|
||||||
"""
|
|
||||||
Collates batches
|
|
||||||
"""
|
|
||||||
images = [example["image"] for example in batch]
|
|
||||||
captions = [example["caption"] for example in batch]
|
|
||||||
tokens = [example["tokens"] for example in batch]
|
|
||||||
runt_size = batch[0]["runt_size"]
|
|
||||||
|
|
||||||
images = torch.stack(images)
|
|
||||||
images = images.to(memory_format=torch.contiguous_format).float()
|
|
||||||
|
|
||||||
ret = {
|
|
||||||
"tokens": torch.stack(tuple(tokens)),
|
|
||||||
"image": images,
|
|
||||||
"captions": captions,
|
|
||||||
"runt_size": runt_size,
|
|
||||||
}
|
|
||||||
del batch
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
"""
|
"""
|
||||||
Main entry point
|
Main entry point
|
||||||
|
@ -387,10 +365,13 @@ def main(args):
|
||||||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||||
logging.info(f" Seed: {seed}")
|
logging.info(f" Seed: {seed}")
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
gpu = GPU()
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{args.gpuid}")
|
gpu = GPU()
|
||||||
|
device = torch.device(f"cuda:{args.gpuid}")
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
else:
|
||||||
|
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||||
|
device = 'cpu'
|
||||||
|
|
||||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||||
|
|
||||||
|
@ -606,6 +587,11 @@ def main(args):
|
||||||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||||
|
|
||||||
|
log_writer = SummaryWriter(log_dir=log_folder,
|
||||||
|
flush_secs=5,
|
||||||
|
comment="EveryDream2FineTunes",
|
||||||
|
)
|
||||||
|
|
||||||
betas = (0.9, 0.999)
|
betas = (0.9, 0.999)
|
||||||
epsilon = 1e-8
|
epsilon = 1e-8
|
||||||
if args.amp:
|
if args.amp:
|
||||||
|
@ -631,8 +617,13 @@ def main(args):
|
||||||
|
|
||||||
log_optimizer(optimizer, betas, epsilon)
|
log_optimizer(optimizer, betas, epsilon)
|
||||||
|
|
||||||
|
|
||||||
image_train_items = resolve_image_train_items(args, log_folder)
|
image_train_items = resolve_image_train_items(args, log_folder)
|
||||||
|
|
||||||
|
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||||
|
# the validation dataset may need to steal some items from image_train_items
|
||||||
|
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||||
|
|
||||||
data_loader = DataLoaderMultiAspect(
|
data_loader = DataLoaderMultiAspect(
|
||||||
image_train_items=image_train_items,
|
image_train_items=image_train_items,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
@ -669,11 +660,6 @@ def main(args):
|
||||||
if args.wandb is not None and args.wandb:
|
if args.wandb is not None and args.wandb:
|
||||||
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||||
|
|
||||||
log_writer = SummaryWriter(
|
|
||||||
log_dir=log_folder,
|
|
||||||
flush_secs=5,
|
|
||||||
comment="EveryDream2FineTunes",
|
|
||||||
)
|
|
||||||
|
|
||||||
def log_args(log_writer, args):
|
def log_args(log_writer, args):
|
||||||
arglog = "args:\n"
|
arglog = "args:\n"
|
||||||
|
@ -729,15 +715,7 @@ def main(args):
|
||||||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||||
|
|
||||||
|
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
|
||||||
train_batch,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=4,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
pin_memory=True
|
|
||||||
)
|
|
||||||
|
|
||||||
unet.train() if not args.disable_unet_training else unet.eval()
|
unet.train() if not args.disable_unet_training else unet.eval()
|
||||||
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
||||||
|
@ -776,6 +754,48 @@ def main(args):
|
||||||
|
|
||||||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||||
|
|
||||||
|
# actual prediction function - shared between train and validate
|
||||||
|
def get_model_prediction_and_target(image, tokens):
|
||||||
|
with torch.no_grad():
|
||||||
|
with autocast(enabled=args.amp):
|
||||||
|
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
||||||
|
latents = vae.encode(pixel_values, return_dict=False)
|
||||||
|
del pixel_values
|
||||||
|
latents = latents[0].sample() * 0.18215
|
||||||
|
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
bsz = latents.shape[0]
|
||||||
|
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
|
cuda_caption = tokens.to(text_encoder.device)
|
||||||
|
|
||||||
|
# with autocast(enabled=args.amp):
|
||||||
|
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
|
||||||
|
|
||||||
|
if args.clip_skip > 0:
|
||||||
|
encoder_hidden_states = text_encoder.text_model.final_layer_norm(
|
||||||
|
encoder_hidden_states.hidden_states[-args.clip_skip])
|
||||||
|
else:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||||
|
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
|
del noise, latents, cuda_caption
|
||||||
|
|
||||||
|
with autocast(enabled=args.amp):
|
||||||
|
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
|
return model_pred, target
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
|
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
|
||||||
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
|
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
|
||||||
|
@ -809,41 +829,7 @@ def main(args):
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
||||||
with torch.no_grad():
|
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
||||||
with autocast(enabled=args.amp):
|
|
||||||
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
|
|
||||||
latents = vae.encode(pixel_values, return_dict=False)
|
|
||||||
del pixel_values
|
|
||||||
latents = latents[0].sample() * 0.18215
|
|
||||||
|
|
||||||
noise = torch.randn_like(latents)
|
|
||||||
bsz = latents.shape[0]
|
|
||||||
|
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
|
||||||
timesteps = timesteps.long()
|
|
||||||
|
|
||||||
cuda_caption = batch["tokens"].to(text_encoder.device)
|
|
||||||
|
|
||||||
#with autocast(enabled=args.amp):
|
|
||||||
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
|
|
||||||
|
|
||||||
if args.clip_skip > 0:
|
|
||||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
|
|
||||||
else:
|
|
||||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
|
||||||
|
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
|
||||||
|
|
||||||
if noise_scheduler.config.prediction_type == "epsilon":
|
|
||||||
target = noise
|
|
||||||
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
|
||||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
|
||||||
del noise, latents, cuda_caption
|
|
||||||
|
|
||||||
with autocast(enabled=args.amp):
|
|
||||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
|
||||||
|
|
||||||
#del timesteps, encoder_hidden_states, noisy_latents
|
#del timesteps, encoder_hidden_states, noisy_latents
|
||||||
#with autocast(enabled=args.amp):
|
#with autocast(enabled=args.amp):
|
||||||
|
@ -952,6 +938,10 @@ def main(args):
|
||||||
|
|
||||||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||||
|
|
||||||
|
# validate
|
||||||
|
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
|
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
|
||||||
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||||
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
|
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
|
||||||
"stabilize_training_loss": "If true, stabilise the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
|
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
|
||||||
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||||
"every_n_epochs": "How often to run validation (1=every epoch).",
|
"every_n_epochs": "How often to run validation (1=every epoch).",
|
||||||
"seed": "The seed to use when running validation and stabilization passes."
|
"seed": "The seed to use when running validation and stabilization passes."
|
||||||
|
@ -13,7 +13,7 @@
|
||||||
"val_split_mode": "automatic",
|
"val_split_mode": "automatic",
|
||||||
"val_data_root": null,
|
"val_data_root": null,
|
||||||
"val_split_proportion": 0.15,
|
"val_split_proportion": 0.15,
|
||||||
"stabilize_training_loss": true,
|
"stabilize_training_loss": false,
|
||||||
"stabilize_split_proportion": 0.15,
|
"stabilize_split_proportion": 0.15,
|
||||||
"every_n_epochs": 1,
|
"every_n_epochs": 1,
|
||||||
"seed": 555
|
"seed": 555
|
||||||
|
|
Loading…
Reference in New Issue