update EveryDreamValidator for noprompt's changes

commit 29396ec21b, parent 41c9f36ed7
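In short: the random train/validation split helpers are removed from DataLoaderMultiAspect and replaced by a standalone get_random_split function in the validation module; EveryDreamValidator now keeps its config on self.config, takes a default_batch_size instead of a whole EveryDreamBatch, and builds its own EveryDreamBatch and dataloader per split; train.py gains a get_model_prediction_and_target closure shared between the training loop and validation. The hunks below touch, in order: the data loader (class DataLoaderMultiAspect), the dataset module (class EveryDreamBatch and build_torch_dataloader), the validation module (class EveryDreamValidator), and train.py.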
@@ -169,24 +169,6 @@ class DataLoaderMultiAspect():
         return picked_images
 
-    def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
-        item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
-        # sort first, then shuffle, to ensure determinate outcome for the current random state
-        items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
-        random.shuffle(items_copy)
-        split_items = items_copy[:item_count]
-        if remove_from_dataset:
-            self.delete_items(split_items)
-        return split_items
-
-    def delete_items(self, items: list[ImageTrainItem]):
-        for item in items:
-            for i, other_item in enumerate(self.prepared_train_data):
-                if other_item.pathname == item.pathname:
-                    self.prepared_train_data.pop(i)
-                    break
-        self.__update_rating_sums()
-
     def __update_rating_sums(self):
         self.rating_overall_sum: float = 0.0
         self.ratings_summed: list[float] = []
@@ -143,12 +143,12 @@ class EveryDreamBatch(Dataset):
     def __update_image_train_items(self, dropout_fraction: float):
         self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
-def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
+def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
     dataloader = torch.utils.data.DataLoader(
-        items,
+        dataset,
         batch_size=batch_size,
         shuffle=False,
-        num_workers=0,
+        num_workers=4,
         collate_fn=collate_fn
     )
    return dataloader
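A minimal usage sketch of the updated helper, mirroring the call that train.py makes later in this commit (train_batch is an EveryDreamBatch, args.batch_size comes from the training config):

    from data.every_dream import build_torch_dataloader

    # wraps the dataset with the module's collate_fn; num_workers=4 now
    # prepares batches in worker processes instead of the main process
    train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)

The remaining hunks belong to the validation module (imported by train.py as data.every_dream_validation) and to train.py itself.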
@@ -1,4 +1,5 @@
 import json
+import math
 import random
 from typing import Callable, Any, Optional
 from argparse import Namespace
@@ -14,55 +15,65 @@ from data.every_dream import build_torch_dataloader, EveryDreamBatch
 from data.data_loader import DataLoaderMultiAspect
 from data import resolver
 from data import aspects
+from data.image_train_item import ImageTrainItem
 from utils.isolate_rng import isolate_rng
 
 
+def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
+        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
+    split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
+    # sort first, then shuffle, to ensure determinate outcome for the current random state
+    items_copy = list(sorted(items, key=lambda i: i.pathname))
+    random.shuffle(items_copy)
+    split_items = list(items_copy[:split_item_count])
+    remaining_items = list(items_copy[split_item_count:])
+    return split_items, remaining_items
+
+
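The new free function trims the split to a whole number of batches. A worked example of the size arithmetic (note that // already floors, so the outer ceil never rounds up; the expression effectively rounds the split down to a multiple of batch_size):

    import math

    n_items, proportion, batch_size = 1000, 0.15, 8
    # 0.15 * 1000 = 150.0; 150.0 // 8 = 18.0; ceil(18.0) * 8 = 144
    split_item_count = math.ceil(proportion * n_items // batch_size) * batch_size
    assert split_item_count == 144  # 144 items split off, 856 remain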
 class EveryDreamValidator:
     def __init__(self,
                  val_config_path: Optional[str],
-                 train_batch: EveryDreamBatch,
+                 default_batch_size: int,
                  log_writer: SummaryWriter):
+        self.val_dataloader = None
+        self.train_overlapping_dataloader = None
+
         self.log_writer = log_writer
 
-        val_config = {}
+        self.config = {}
         if val_config_path is not None:
             with open(val_config_path, 'rt') as f:
-                val_config = json.load(f)
+                self.config = json.load(f)
 
-        do_validation = val_config.get('validate_training', False)
-        val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
-        self.val_data_root = val_config.get('val_data_root', None)
-        val_split_proportion = val_config.get('val_split_proportion', 0.15)
-
-        stabilize_training_loss = val_config.get('stabilize_training_loss', False)
-        stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
-
-        self.every_n_epochs = val_config.get('every_n_epochs', 1)
-        self.seed = val_config.get('seed', 555)
+        self.batch_size = self.config.get('batch_size', default_batch_size)
+        self.every_n_epochs = self.config.get('every_n_epochs', 1)
+        self.seed = self.config.get('seed', 555)
+        self.val_data_root = self.config.get('val_data_root', None)
 
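Together with the split builders further down, the constructor reads the following keys. A hypothetical config file for illustration (scalar values here are the code's defaults, except batch_size, which normally falls back to the default_batch_size argument; the old validate_training flag is gone):

    {
        "batch_size": 4,
        "every_n_epochs": 1,
        "seed": 555,
        "val_split_mode": "automatic",
        "val_split_proportion": 0.15,
        "val_data_root": null,
        "stabilize_training_loss": false,
        "stabilize_split_proportion": 0.15
    }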
+    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
+        """
+        Build the validation splits as requested by the config passed at init.
+        This may steal some items from `train_items`.
+        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
+        Otherwise, the returned `list` is identical to the passed-in `train_items`.
+        """
         with isolate_rng():
-            self.val_dataloader = self._build_validation_dataloader(val_split_mode,
-                                                                    split_proportion=val_split_proportion,
-                                                                    val_data_root=self.val_data_root,
-                                                                    train_batch=train_batch)
+            self.val_dataloader, remaining_train_items = self._build_val_dataloader(train_items, tokenizer)
             # order is important - if we're removing images from train, this needs to happen before making
             # the overlapping dataloader
-            self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
-                                                                                            split_proportion=stabilize_split_proportion,
-                                                                                            name='train-stabilizer',
-                                                                                            enforce_split=False) if stabilize_training_loss else None
+            self.train_overlapping_dataloader = self._build_train_stabiliser_dataloader(remaining_train_items, tokenizer)
+            return remaining_train_items
 
     def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                      get_model_prediction_and_target_callable: Callable[
                                          [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         if (epoch % self.every_n_epochs) == 0:
             if self.train_overlapping_dataloader is not None:
-                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
+                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
+                                    get_model_prediction_and_target_callable)
             if self.val_dataloader is not None:
                 self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
 
     def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         with torch.no_grad(), isolate_rng():
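Both passes share the every_n_epochs cadence: 'stabilize-train' measures loss on a fixed subset that still overlaps the training data, which reads as a way to get a stable loss signal independent of the held-out split, while 'val' uses the stolen validation items.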
@@ -75,8 +86,6 @@ class EveryDreamValidator:
                 torch.manual_seed(self.seed + step)
                 model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
-                # del timesteps, encoder_hidden_states, noisy_latents
-                # with autocast(enabled=args.amp):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 del target, model_pred
@@ -91,80 +100,55 @@ class EveryDreamValidator:
         loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
         self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
 
-    def _build_validation_dataloader(self,
-                                     validation_split_mode: str,
-                                     split_proportion: float,
-                                     val_data_root: Optional[str],
-                                     train_batch: EveryDreamBatch) -> Optional[DataLoader]:
-        if validation_split_mode == 'none':
-            return None
-        elif validation_split_mode == 'automatic':
-            return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
-        elif validation_split_mode == 'manual':
-            if val_data_root is None:
-                raise ValueError("val_data_root is required for 'manual' validation split mode")
-            return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
-        else:
-            raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
-
-    def _build_dataloader_from_automatic_split(self,
-                                               train_batch: EveryDreamBatch,
-                                               split_proportion: float,
-                                               name: str,
-                                               enforce_split: bool=False) -> DataLoader:
-        """
-        Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
-        the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
-        """
-        with isolate_rng():
-            random.seed(self.seed)
-            val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
-            if enforce_split:
-                print(
-                    f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
-            if len(train_batch) == 0:
-                raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
-            val_batch = self._make_val_batch_with_train_batch_settings(
-                val_items,
-                train_batch,
-                name=name
-            )
-            return build_torch_dataloader(
-                items=val_batch,
-                batch_size=train_batch.batch_size,
-            )
-
-    def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
-        val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
-        return build_torch_dataloader(
-            items=val_batch,
-            batch_size=reference_train_batch.batch_size
-        )
-
-    def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
-        batch_size = reference_train_batch.batch_size
-        seed = reference_train_batch.seed
+    def _build_val_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
+            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
+        val_split_mode = self.config.get('val_split_mode', 'automatic')
+        val_split_proportion = self.config.get('val_split_proportion', 0.15)
+        remaining_train_items = image_train_items
+        if val_split_mode == 'none':
+            return None, image_train_items
+        elif val_split_mode == 'automatic':
+            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
+        elif val_split_mode == 'manual':
             args = Namespace(
                 aspects=aspects.get_aspect_buckets(512),
                 flip_p=0.0,
-                seed=seed,
+                seed=self.seed,
             )
-        image_train_items = resolver.resolve(data_root, args)
+            val_items = resolver.resolve_root(self.val_data_root, args)
+        else:
+            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
+        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
+        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
+        return val_dataloader, remaining_train_items
 
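The rewritten builder folds the three split modes into one method: 'none' returns no dataloader and the training items untouched, 'automatic' steals a batch-aligned random split via get_random_split, and 'manual' resolves a separate directory from val_data_root without touching the training items. Note the manual branch no longer guards against a missing val_data_root the way the old code did; resolver.resolve_root(None, ...) would presumably fail on its own.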
+    def _build_train_stabiliser_dataloader(self, image_train_items: list[ImageTrainItem], tokenizer) \
+            -> Optional[torch.utils.data.DataLoader]:
+        stabilize_training_loss = self.config.get('stabilize_training_loss', False)
+        if not stabilize_training_loss:
+            return None
+
+        stabilize_split_proportion = self.config.get('stabilize_split_proportion', 0.15)
+        stabilise_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
+        stabilize_ed_batch = self._build_ed_batch(stabilise_items, batch_size=self.batch_size, tokenizer=tokenizer,
+                                                  name='stabilize-train')
+        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
+        return stabilize_dataloader
+
+    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
+        batch_size = self.batch_size
+        seed = self.seed
         data_loader = DataLoaderMultiAspect(
-            image_train_items,
+            items,
             batch_size=batch_size,
             seed=seed,
         )
-        val_batch = EveryDreamBatch(
+        ed_batch = EveryDreamBatch(
             data_loader=data_loader,
             debug_level=1,
-            batch_size=batch_size,
             conditional_dropout=0,
-            tokenizer=reference_train_batch.tokenizer,
+            tokenizer=tokenizer,
             seed=seed,
             name=name,
         )
-        return val_batch
+        return ed_batch
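One wart worth flagging: _build_ed_batch accepts a batch_size parameter and then immediately shadows it with self.batch_size (and likewise uses self.seed), so the batch_size= arguments passed by the two callers above are effectively decorative.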
|
train.py (138 changed lines)
@@ -52,7 +52,8 @@ import wandb
 from torch.utils.tensorboard import SummaryWriter
 from data.data_loader import DataLoaderMultiAspect
-from data.every_dream import EveryDreamBatch
+from data.every_dream import EveryDreamBatch, build_torch_dataloader
+from data.every_dream_validation import EveryDreamValidator
 from data.image_train_item import ImageTrainItem
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
@@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
     return sample_prompts
 
 
-def collate_fn(batch):
-    """
-    Collates batches
-    """
-    images = [example["image"] for example in batch]
-    captions = [example["caption"] for example in batch]
-    tokens = [example["tokens"] for example in batch]
-    runt_size = batch[0]["runt_size"]
-
-    images = torch.stack(images)
-    images = images.to(memory_format=torch.contiguous_format).float()
-
-    ret = {
-        "tokens": torch.stack(tuple(tokens)),
-        "image": images,
-        "captions": captions,
-        "runt_size": runt_size,
-    }
-    del batch
-    return ret
-
-
 def main(args):
     """
     Main entry point
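collate_fn is only deleted here, yet build_torch_dataloader in the dataset module still references a module-level collate_fn, so the function has presumably moved there; the move itself falls outside the hunks shown.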
@@ -387,10 +365,13 @@ def main(args):
     seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
     logging.info(f" Seed: {seed}")
     set_seed(seed)
+    if torch.cuda.is_available():
         gpu = GPU()
         device = torch.device(f"cuda:{args.gpuid}")
 
         torch.backends.cudnn.benchmark = True
+    else:
+        logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
+        device = 'cpu'
 
     log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@@ -606,6 +587,11 @@ def main(args):
         logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
         params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
 
+    log_writer = SummaryWriter(log_dir=log_folder,
+                               flush_secs=5,
+                               comment="EveryDream2FineTunes",
+                               )
+
     betas = (0.9, 0.999)
     epsilon = 1e-8
     if args.amp:
@@ -631,8 +617,13 @@ def main(args):
 
     log_optimizer(optimizer, betas, epsilon)
 
     image_train_items = resolve_image_train_items(args, log_folder)
 
+    validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
+    # the validation dataset may need to steal some items from image_train_items
+    image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
+
     data_loader = DataLoaderMultiAspect(
         image_train_items=image_train_items,
         seed=seed,
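args.validation_config is new in this call; the matching argument-parser entry is presumably added elsewhere in the commit, outside the hunks shown.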
@@ -669,11 +660,6 @@ def main(args):
     if args.wandb is not None and args.wandb:
         wandb.init(project=args.project_name, sync_tensorboard=True, )
 
-    log_writer = SummaryWriter(
-        log_dir=log_folder,
-        flush_secs=5,
-        comment="EveryDream2FineTunes",
-    )
-
 
 def log_args(log_writer, args):
     arglog = "args:\n"
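This is the counterpart of the SummaryWriter hunk above: the writer now has to exist before the validator is constructed, so its creation moved up and this later copy is deleted.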
@@ -729,15 +715,7 @@ def main(args):
     logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
     logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
 
-    train_dataloader = torch.utils.data.DataLoader(
-        train_batch,
-        batch_size=args.batch_size,
-        shuffle=False,
-        num_workers=4,
-        collate_fn=collate_fn,
-        pin_memory=True
-    )
+    train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
 
     unet.train() if not args.disable_unet_training else unet.eval()
     text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
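Note that the old inline dataloader passed pin_memory=True while build_torch_dataloader (shown earlier) does not, so pinned host memory is quietly dropped by this refactor unless the helper is updated to restore it.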
@@ -776,6 +754,48 @@ def main(args):
 
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
 
+    # actual prediction function - shared between train and validate
+    def get_model_prediction_and_target(image, tokens):
+        with torch.no_grad():
+            with autocast(enabled=args.amp):
+                pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
+                latents = vae.encode(pixel_values, return_dict=False)
+                del pixel_values
+                latents = latents[0].sample() * 0.18215
+
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+
+            cuda_caption = tokens.to(text_encoder.device)
+
+        # with autocast(enabled=args.amp):
+        encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
+
+        if args.clip_skip > 0:
+            encoder_hidden_states = text_encoder.text_model.final_layer_norm(
+                encoder_hidden_states.hidden_states[-args.clip_skip])
+        else:
+            encoder_hidden_states = encoder_hidden_states.last_hidden_state
+
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        del noise, latents, cuda_caption
+
+        with autocast(enabled=args.amp):
+            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        return model_pred, target
+
+
     try:
         # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
         # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
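Two details worth noting: 0.18215 is the standard Stable Diffusion VAE latent scaling factor, and defining the closure inside main() lets it capture vae, unet, text_encoder, noise_scheduler, and args, so validation through do_validation_if_appropriate exercises exactly the same prediction path as the training loop.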
@@ -809,41 +829,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
 
-            with torch.no_grad():
-                with autocast(enabled=args.amp):
-                    pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
-                    latents = vae.encode(pixel_values, return_dict=False)
-                    del pixel_values
-                    latents = latents[0].sample() * 0.18215
-
-                noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
-
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
-
-                cuda_caption = batch["tokens"].to(text_encoder.device)
-
-            #with autocast(enabled=args.amp):
-            encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
-
-            if args.clip_skip > 0:
-                encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
-            else:
-                encoder_hidden_states = encoder_hidden_states.last_hidden_state
-
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-            if noise_scheduler.config.prediction_type == "epsilon":
-                target = noise
-            elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
-                target = noise_scheduler.get_velocity(latents, noise, timesteps)
-            else:
-                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-            del noise, latents, cuda_caption
-
-            with autocast(enabled=args.amp):
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
 
             #del timesteps, encoder_hidden_states, noisy_latents
             #with autocast(enabled=args.amp):
@@ -952,6 +938,10 @@ def main(args):
 
         loss_local = sum(loss_epoch) / len(loss_epoch)
         log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
 
+        # validate
+        validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+
         gc.collect()
         # end of epoch