Merge pull request #50 from noprompt/add-support-for-val-split
Add support for validation split
commit 165525a71c
@@ -14,11 +14,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
from functools import reduce
import math
import copy

import random
-from data.image_train_item import ImageTrainItem
+from data.image_train_item import ImageTrainItem, ImageCaption
import PIL

PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default

@@ -27,23 +28,19 @@ class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

-    image_train_items: list of `lImageTrainItem` objects
+    image_train_items: list of `ImageTrainItem` objects
    seed: random seed
    batch_size: number of images per batch
    """
    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
        self.seed = seed
        self.batch_size = batch_size
        # Prepare data
        self.prepared_train_data = image_train_items
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
        # Initialize ratings
-        self.rating_overall_sum: float = 0.0
-        self.ratings_summed: list[float] = []
-        for image in self.prepared_train_data:
-            self.rating_overall_sum += image.caption.rating()
-            self.ratings_summed.append(self.rating_overall_sum)
+        self.__update_rating_sums()

    def __pick_multiplied_set(self, randomizer):
        """

@@ -80,14 +77,17 @@ class DataLoaderMultiAspect():
        del data_copy
        return picked_images

-    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
+    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
        """
-        returns the current list of images including their captions in a randomized order,
-        sorted into buckets with same sized images
-        if dropout_fraction < 1.0, only a subset of the images will be returned
-        if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
+        Returns the current list of `ImageTrainItem` in randomized order,
+        sorted into buckets with same sized images.

+        If dropout_fraction < 1.0, only a subset of the images will be returned.

+        If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan.

        :param dropout_fraction: must be between 0.0 and 1.0.
-        :return: randomized list of (image, caption) pairs, sorted into same sized buckets
+        :return: Randomized list of `ImageTrainItem` objects
        """

        self.seed += 1

@@ -126,11 +126,11 @@ class DataLoaderMultiAspect():
            buckets[bucket].extend(runt_bucket)

        # flatten the buckets
-        image_caption_pairs = []
+        items: list[ImageTrainItem] = []
        for bucket in buckets:
-            image_caption_pairs.extend(buckets[bucket])
+            items.extend(buckets[bucket])

-        return image_caption_pairs
+        return items

    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
        """

@@ -168,3 +168,10 @@ class DataLoaderMultiAspect():
            prepared_train_data.pop(pos)

        return picked_images

+    def __update_rating_sums(self):
+        self.rating_overall_sum: float = 0.0
+        self.ratings_summed: list[float] = []
+        for item in self.prepared_train_data:
+            self.rating_overall_sum += item.caption.rating()
+            self.ratings_summed.append(self.rating_overall_sum)
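The `bisect` import and the cumulative `ratings_summed` list maintained by `__update_rating_sums` go together: a running sum lets a uniform draw pick an image with probability proportional to its rating. A minimal, self-contained sketch of that idea (the helper name and the standalone framing are illustrative, not the class's actual `__pick_random_subset` body):

    import bisect
    import random

    def pick_weighted_index(ratings_summed: list[float], picker: random.Random) -> int:
        # ratings_summed[i] is the cumulative rating of items 0..i, so a uniform
        # draw in [0, total) bisected into the list lands on index i with
        # probability proportional to item i's own rating.
        total = ratings_summed[-1]
        point = picker.uniform(0.0, total)
        return bisect.bisect_left(ratings_summed, point)

    picker = random.Random(555)
    summed = [0.5, 1.5, 3.0]   # three items with ratings 0.5, 1.0 and 1.5
    print(pick_weighted_index(summed, picker))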
@@ -1,2 +0,0 @@
-# stop lightning's repeated instantiation of batch train/val/test classes causing multiple sweeps of the same data off disk
-shared_dataloader = None
@@ -1,70 +0,0 @@
"""
Copyright [2022] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np
from torch.utils.data import Dataset
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
from ldm.data.image_train_item import ImageTrainItem

class EDValidateBatch(Dataset):
    def __init__(self,
                 data_root,
                 flip_p=0.0,
                 repeats=1,
                 debug_level=0,
                 batch_size=1,
                 set='val',
                 ):
        self.data_root = data_root
        self.batch_size = batch_size

        if not dls.shared_dataloader:
            print("Creating new dataloader singleton")
            dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p)

        self.image_train_items = dls.shared_dataloader.get_all_images()

        self.num_images = len(self.image_train_items)

        self._length = max(math.trunc(self.num_images * repeats), batch_size) - self.num_images % self.batch_size

        print()
        print(f" ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} ")
        print()

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        idx = i % self.num_images
        image_train_item = self.image_train_items[idx]

        example = self.__get_image_for_trainer(image_train_item)
        return example

    @staticmethod
    def __get_image_for_trainer(image_train_item: ImageTrainItem):
        example = {}

        image_train_tmp = image_train_item.hydrate()

        example["image"] = image_train_tmp.image
        example["caption"] = image_train_tmp.caption

        return example
@@ -41,7 +41,8 @@ class EveryDreamBatch(Dataset):
                 retain_contrast=False,
                 shuffle_tags=False,
                 rated_dataset=False,
-                 rated_dataset_dropout_target=0.5
+                 rated_dataset_dropout_target=0.5,
+                 name='train'
                 ):
        self.data_loader = data_loader
        self.batch_size = data_loader.batch_size

@@ -57,10 +58,18 @@ class EveryDreamBatch(Dataset):
        self.rated_dataset = rated_dataset
        self.rated_dataset_dropout_target = rated_dataset_dropout_target
        # First epoch always trains on all images
-        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
+        self.image_train_items = []
+        self.__update_image_train_items(1.0)
+        self.name = name

        num_images = len(self.image_train_items)
-        logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
+        logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")

+    def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
+        items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
+        self.__update_image_train_items(1.0)
+        return items


    def shuffle(self, epoch_n: int, max_epochs: int):
        self.seed += 1

@@ -70,7 +79,7 @@ class EveryDreamBatch(Dataset):
        else:
            dropout_fraction = 1.0

-        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
+        self.__update_image_train_items(dropout_fraction)

    def __len__(self):
        return len(self.image_train_items)

@@ -130,3 +139,38 @@ class EveryDreamBatch(Dataset):
        example["caption"] = image_train_tmp.caption
        example["runt_size"] = image_train_tmp.runt_size
        return example

+    def __update_image_train_items(self, dropout_fraction: float):
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
+
+def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        collate_fn=collate_fn
+    )
+    return dataloader
+
+
+def collate_fn(batch):
+    """
+    Collates batches
+    """
+    images = [example["image"] for example in batch]
+    captions = [example["caption"] for example in batch]
+    tokens = [example["tokens"] for example in batch]
+    runt_size = batch[0]["runt_size"]
+
+    images = torch.stack(images)
+    images = images.to(memory_format=torch.contiguous_format).float()
+
+    ret = {
+        "tokens": torch.stack(tuple(tokens)),
+        "image": images,
+        "captions": captions,
+        "runt_size": runt_size,
+    }
+    del batch
+    return ret
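As a quick check of the two helpers above, a hedged usage sketch with a stand-in dataset (the ToyBatch class is illustrative only; EveryDreamBatch yields items with the same keys):

    import torch
    from torch.utils.data import Dataset

    from data.every_dream import build_torch_dataloader


    class ToyBatch(Dataset):
        # stand-in for EveryDreamBatch: every item carries the keys collate_fn expects
        def __len__(self):
            return 8

        def __getitem__(self, i):
            return {
                "image": torch.zeros(3, 64, 64),
                "caption": f"caption {i}",
                "tokens": torch.zeros(77, dtype=torch.long),
                "runt_size": 0,
            }


    if __name__ == "__main__":
        dl = build_torch_dataloader(ToyBatch(), batch_size=4)
        batch = next(iter(dl))
        # images and tokens are stacked along a new batch dimension; captions stay a plain list
        print(batch["image"].shape, batch["tokens"].shape, len(batch["captions"]))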
@@ -0,0 +1,173 @@
import json
import math
import random
from typing import Callable, Any, Optional
from argparse import Namespace

import torch
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng


def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
    split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    items_copy = list(sorted(items, key=lambda i: i.pathname))
    random.shuffle(items_copy)
    split_items = list(items_copy[:split_item_count])
    remaining_items = list(items_copy[split_item_count:])
    return split_items, remaining_items


class EveryDreamValidator:
    def __init__(self,
                 val_config_path: Optional[str],
                 default_batch_size: int,
                 log_writer: SummaryWriter):
        self.val_dataloader = None
        self.train_overlapping_dataloader = None

        self.log_writer = log_writer

        self.config = {
            'batch_size': default_batch_size,
            'every_n_epochs': 1,
            'seed': 555,

            'val_split_mode': 'automatic',
            'val_split_proportion': 0.15,

            'stabilize_training_loss': False,
            'stabilize_split_proportion': 0.15
        }
        if val_config_path is not None:
            with open(val_config_path, 'rt') as f:
                self.config.update(json.load(f))

    @property
    def batch_size(self):
        return self.config['batch_size']

    @property
    def every_n_epochs(self):
        return self.config['every_n_epochs']

    @property
    def seed(self):
        return self.config['seed']

    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
        """
        Build the validation splits as requested by the config passed at init.
        This may steal some items from `train_items`.
        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
        Otherwise, the returned `list` is identical to the passed-in `train_items`.
        """
        with isolate_rng():
            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
            # order is important - if we're removing images from train, this needs to happen before making
            # the overlapping dataloader
            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
                remaining_train_items, tokenizer)
            return remaining_train_items

    def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                     get_model_prediction_and_target_callable: Callable[
                                         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        if (epoch % self.every_n_epochs) == 0:
            if self.train_overlapping_dataloader is not None:
                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
                                    get_model_prediction_and_target_callable)
            if self.val_dataloader is not None:
                self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)

    def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
            [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        with torch.no_grad(), isolate_rng():
            loss_validation_epoch = []
            steps_pbar = tqdm(range(len(dataloader)), position=1)
            steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")

            for step, batch in enumerate(dataloader):
                # ok to override seed here because we are in a `with isolate_rng():` block
                torch.manual_seed(self.seed + step)
                model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                del target, model_pred

                loss_step = loss.detach().item()
                loss_validation_epoch.append(loss_step)

                steps_pbar.update(1)

            steps_pbar.close()

            loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
            self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)

    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
        val_split_mode = self.config['val_split_mode']
        val_split_proportion = self.config['val_split_proportion']
        remaining_train_items = image_train_items
        if val_split_mode == 'none':
            return None, image_train_items
        elif val_split_mode == 'automatic':
            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
        elif val_split_mode == 'manual':
            args = Namespace(
                aspects=aspects.get_aspect_buckets(512),
                flip_p=0.0,
                seed=self.seed,
            )
            val_data_root = self.config['val_data_root']
            val_items = resolver.resolve_root(val_data_root, args)
        else:
            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
        return val_dataloader, remaining_train_items

    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> Optional[torch.utils.data.DataLoader]:
        stabilize_training_loss = self.config['stabilize_training_loss']
        if not stabilize_training_loss:
            return None

        stabilize_split_proportion = self.config['stabilize_split_proportion']
        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                                  name='stabilize-train')
        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
        return stabilize_dataloader

    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
        batch_size = self.batch_size
        seed = self.seed
        data_loader = DataLoaderMultiAspect(
            items,
            batch_size=batch_size,
            seed=seed,
        )
        ed_batch = EveryDreamBatch(
            data_loader=data_loader,
            debug_level=1,
            conditional_dropout=0,
            tokenizer=tokenizer,
            seed=seed,
            name=name,
        )
        return ed_batch
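A quick worked check of the split size that get_random_split above computes, using hypothetical numbers; note that the floor division runs before math.ceil, so the validation split rounds down to a whole number of batches:

    import math

    items_len = 100          # hypothetical dataset size
    split_proportion = 0.15
    batch_size = 4

    # mirrors: math.ceil(split_proportion * len(items) // batch_size) * batch_size
    split_item_count = math.ceil(split_proportion * items_len // batch_size) * batch_size
    print(split_item_count)  # 12 -> 15 requested items become 3 whole batches of 4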
train.py

@@ -52,7 +52,8 @@ import wandb
from torch.utils.tensorboard import SummaryWriter
from data.data_loader import DataLoaderMultiAspect

-from data.every_dream import EveryDreamBatch
+from data.every_dream import EveryDreamBatch, build_torch_dataloader
+from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter

@@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
    return sample_prompts


-def collate_fn(batch):
-    """
-    Collates batches
-    """
-    images = [example["image"] for example in batch]
-    captions = [example["caption"] for example in batch]
-    tokens = [example["tokens"] for example in batch]
-    runt_size = batch[0]["runt_size"]
-
-    images = torch.stack(images)
-    images = images.to(memory_format=torch.contiguous_format).float()
-
-    ret = {
-        "tokens": torch.stack(tuple(tokens)),
-        "image": images,
-        "captions": captions,
-        "runt_size": runt_size,
-    }
-    del batch
-    return ret
-
-
def main(args):
    """
    Main entry point

@@ -387,10 +365,13 @@
    seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
    logging.info(f" Seed: {seed}")
    set_seed(seed)
    if torch.cuda.is_available():
        gpu = GPU()
        device = torch.device(f"cuda:{args.gpuid}")

        torch.backends.cudnn.benchmark = True
    else:
        logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
        device = 'cpu'

    log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")

@@ -606,6 +587,11 @@
        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

+    log_writer = SummaryWriter(log_dir=log_folder,
+                               flush_secs=5,
+                               comment="EveryDream2FineTunes",
+                               )
+
    betas = (0.9, 0.999)
    epsilon = 1e-8
    if args.amp:

@@ -631,8 +617,13 @@

    log_optimizer(optimizer, betas, epsilon)

    image_train_items = resolve_image_train_items(args, log_folder)

+    validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
+    # the validation dataset may need to steal some items from image_train_items
+    image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
+
    data_loader = DataLoaderMultiAspect(
        image_train_items=image_train_items,
        seed=seed,

@@ -669,11 +660,6 @@
    if args.wandb is not None and args.wandb:
        wandb.init(project=args.project_name, sync_tensorboard=True, )

-    log_writer = SummaryWriter(
-        log_dir=log_folder,
-        flush_secs=5,
-        comment="EveryDream2FineTunes",
-    )

    def log_args(log_writer, args):
        arglog = "args:\n"

@@ -729,15 +715,7 @@
    logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
    logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")

-    train_dataloader = torch.utils.data.DataLoader(
-        train_batch,
-        batch_size=args.batch_size,
-        shuffle=False,
-        num_workers=4,
-        collate_fn=collate_fn,
-        pin_memory=True
-    )
+    train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)

    unet.train() if not args.disable_unet_training else unet.eval()
    text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()

@@ -776,6 +754,48 @@

    assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"

+    # actual prediction function - shared between train and validate
+    def get_model_prediction_and_target(image, tokens):
+        with torch.no_grad():
+            with autocast(enabled=args.amp):
+                pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
+                latents = vae.encode(pixel_values, return_dict=False)
+            del pixel_values
+            latents = latents[0].sample() * 0.18215
+
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+
+            cuda_caption = tokens.to(text_encoder.device)
+
+        # with autocast(enabled=args.amp):
+        encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
+
+        if args.clip_skip > 0:
+            encoder_hidden_states = text_encoder.text_model.final_layer_norm(
+                encoder_hidden_states.hidden_states[-args.clip_skip])
+        else:
+            encoder_hidden_states = encoder_hidden_states.last_hidden_state
+
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        del noise, latents, cuda_caption
+
+        with autocast(enabled=args.amp):
+            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        return model_pred, target
+
    try:
        # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
        # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))

@@ -809,41 +829,7 @@
        for step, batch in enumerate(train_dataloader):
            step_start_time = time.time()

-            with torch.no_grad():
-                with autocast(enabled=args.amp):
-                    pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
-                    latents = vae.encode(pixel_values, return_dict=False)
-                del pixel_values
-                latents = latents[0].sample() * 0.18215
-
-                noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
-
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
-
-                cuda_caption = batch["tokens"].to(text_encoder.device)
-
-            #with autocast(enabled=args.amp):
-            encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
-
-            if args.clip_skip > 0:
-                encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
-            else:
-                encoder_hidden_states = encoder_hidden_states.last_hidden_state
-
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-            if noise_scheduler.config.prediction_type == "epsilon":
-                target = noise
-            elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
-                target = noise_scheduler.get_velocity(latents, noise, timesteps)
-            else:
-                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-            del noise, latents, cuda_caption
-
-            with autocast(enabled=args.amp):
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])

            #del timesteps, encoder_hidden_states, noisy_latents
            #with autocast(enabled=args.amp):

@@ -952,6 +938,10 @@

        loss_local = sum(loss_epoch) / len(loss_epoch)
        log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)

+        # validate
+        validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+
        gc.collect()
        # end of epoch
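get_model_prediction_and_target above chooses its regression target from the scheduler's prediction_type; the training loop and the validator now call the same function. A self-contained illustration of just that branch, assuming diffusers' DDPMScheduler (which exposes add_noise and get_velocity):

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")

    latents = torch.randn(2, 4, 64, 64)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (2,)).long()

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    if scheduler.config.prediction_type == "epsilon":
        target = noise                                              # predict the added noise
    elif scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
        target = scheduler.get_velocity(latents, noise, timesteps)  # predict velocity instead
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    print(noisy_latents.shape, target.shape)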
@@ -0,0 +1,73 @@
# copy/pasted from pytorch lightning
# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
# and
# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py

# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Generator, Dict, Any

import torch
import numpy as np
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state


def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
    """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
    states = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": python_get_rng_state(),
    }
    if include_cuda:
        states["torch.cuda"] = torch.cuda.get_rng_state_all()
    return states


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
    process."""
    torch.set_rng_state(rng_state_dict["torch"])
    # torch.cuda rng_state is only included since v1.8.
    if "torch.cuda" in rng_state_dict:
        torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
    np.random.set_state(rng_state_dict["numpy"])
    version, state, gauss = rng_state_dict["python"]
    python_set_rng_state((version, tuple(state), gauss))


@contextmanager
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
    """A context manager that resets the global random state on exit to what it was before entering.

    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.

    Args:
        include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
            Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
            prohibited.

    Example:
        >>> import torch
        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
        <torch._C.Generator object at ...>
        >>> with isolate_rng():
        ...     [torch.rand(1) for _ in range(3)]
        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
        >>> torch.rand(1)
        tensor([0.7576])
    """
    states = _collect_rng_states(include_cuda)
    yield
    _set_rng_states(states)
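This is why _do_validation can freely call torch.manual_seed on every step: inside the context manager the reseeding is invisible to the training run. A small sketch:

    import torch
    from utils.isolate_rng import isolate_rng

    torch.manual_seed(1234)              # stands in for the training run's RNG state
    before = torch.get_rng_state()

    with isolate_rng():
        torch.manual_seed(555)           # fixed seed used only for the validation pass
        _ = torch.randn(2, 2)            # identical values on every validation run

    after = torch.get_rng_state()
    print(torch.equal(before, after))    # True: the training random stream is unaffected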
@@ -0,0 +1,65 @@
import argparse
import math
import os.path
import random
import shutil
from typing import Optional

from tqdm.auto import tqdm

IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']


def gather_captioned_images(root_dir: str) -> list[tuple[str, Optional[str]]]:
    for directory, _, filenames in os.walk(root_dir):
        image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
        for image_filename in image_filenames:
            caption_filename = os.path.splitext(image_filename)[0] + '.txt'
            image_path = os.path.join(directory, image_filename)
            caption_path = os.path.join(directory, caption_filename)
            yield (image_path, caption_path if os.path.exists(caption_path) else None)


def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
    image_path = image_caption_pair[0]
    caption_path = image_caption_pair[1]

    # make target folder if necessary
    relative_folder = os.path.dirname(os.path.relpath(image_path, source_root))
    target_folder = os.path.join(target_root, relative_folder)
    os.makedirs(target_folder, exist_ok=True)

    # copy files
    shutil.copy2(image_path, os.path.join(target_folder, os.path.basename(image_path)))
    if caption_path is not None:
        shutil.copy2(caption_path, os.path.join(target_folder, os.path.basename(caption_path)))


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('source_root', type=str, help='Source root folder containing images')
    parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset")
    parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset")
    parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)")
    parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
    args = parser.parse_args()

    # materialize the generator so it can be counted and shuffled below
    images = list(gather_captioned_images(args.source_root))
    print(f"Found {len(images)} captioned images in {args.source_root}")
    val_split_count = math.ceil(len(images) * args.split_proportion)
    if val_split_count == 0:
        raise ValueError(f"No images in validation split with source count {len(images)} and split proportion {args.split_proportion}")

    random.seed(args.seed)
    random.shuffle(images)
    val_split = images[0:val_split_count]
    train_split = images[val_split_count:]
    print(f"Split to 'train' set with {len(train_split)} images and 'val' set with {len(val_split)}")
    print(f"Copying 'val' set to {args.val_output_folder}...")
    for v in tqdm(val_split):
        copy_captioned_image(v, args.source_root, args.val_output_folder)
    print(f"Copying 'train' set to {args.train_output_folder}...")
    for v in tqdm(train_split):
        copy_captioned_image(v, args.source_root, args.train_output_folder)
    print("Done.")
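Run from the command line, the script takes the source folder as its positional argument, e.g. `python split_dataset.py /data/my_images --train_output_folder /data/train --val_output_folder /data/val --split_proportion 0.15` (the script filename and the paths here are illustrative; the diff does not show where the file lives).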
@@ -0,0 +1,20 @@
{
    "documentation": {
        "validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
        "val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
        "val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
        "val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
        "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
        "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
        "every_n_epochs": "How often to run validation (1=every epoch).",
        "seed": "The seed to use when running validation and stabilization passes."
    },
    "validate_training": true,
    "val_split_mode": "automatic",
    "val_data_root": null,
    "val_split_proportion": 0.15,
    "stabilize_training_loss": false,
    "stabilize_split_proportion": 0.15,
    "every_n_epochs": 1,
    "seed": 555
}
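EveryDreamValidator (shown earlier) layers a file like this over its built-in defaults with a plain dict update, so only the keys you care about need to appear in the file. A minimal sketch of that merge, assuming the config above is saved as validation_default.json and its path is supplied to the trainer as args.validation_config:

    import json

    defaults = {
        'batch_size': 2,                 # normally taken from the trainer's batch size
        'every_n_epochs': 1,
        'seed': 555,
        'val_split_mode': 'automatic',
        'val_split_proportion': 0.15,
        'stabilize_training_loss': False,
        'stabilize_split_proportion': 0.15,
    }

    with open('validation_default.json', 'rt') as f:
        defaults.update(json.load(f))    # keys present in the file win; missing keys keep their defaults

    print(defaults['val_split_mode'], defaults['every_n_epochs'])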