Merge pull request #53 from damian0815/fix_validation_multiplier_logic

Fix multiplier issues with validation, and refactor validation logic
Victor Hall 2023-02-08 11:06:46 -05:00 committed by GitHub
commit b6f918daaf
6 changed files with 88 additions and 84 deletions


@ -14,13 +14,14 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
from functools import reduce
import logging
import os.path
from collections import defaultdict
import math
import copy
import random
from data.image_train_item import ImageTrainItem, ImageCaption
import PIL
from data.image_train_item import ImageTrainItem
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@ -38,43 +39,40 @@ class DataLoaderMultiAspect():
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
if expected_epoch_size != len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
else:
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
self.__update_rating_sums()
def __pick_multiplied_set(self, randomizer: random.Random):
    """
    Deals with multiply.txt whole and fractional numbers
    """
    #print(f"Picking multiplied set from {len(self.prepared_train_data)}")
    picked_images = []
    fractional_images_per_directory = defaultdict(list[ImageTrainItem])
    for iti in self.prepared_train_data:
        multiplier = iti.multiplier
        # add by whole number part first
        while multiplier >= 1:
            picked_images.append(iti)
            multiplier -= 1
        # fractional remainders must be dealt with separately
        if multiplier > 0:
            directory = os.path.dirname(iti.pathname)
            fractional_images_per_directory[directory].append(iti)
    # resolve fractional parts per-directory
    for _, fractional_items in fractional_images_per_directory.items():
        randomizer.shuffle(fractional_items)
        multiplier = fractional_items[0].multiplier % 1.0
        count_to_take = math.ceil(multiplier * len(fractional_items))
        picked_images.extend(fractional_items[:count_to_take])
    return picked_images
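For readers unfamiliar with multiply.txt: the whole-number part of a multiplier repeats every image in the directory deterministically, and the fractional remainder is now resolved per directory by taking a random ceil(fraction * count) subset, so a directory of 10 images with a multiplier of 2.5 contributes about 25 images per epoch (this is also where the expected_epoch_size logged in __init__ above comes from). A minimal standalone sketch of that arithmetic, with made-up numbers rather than code from the repository:

import math
import random

multiplier = 2.5                                  # e.g. a multiply.txt containing "2.5"
images = [f"img_{i:02}.jpg" for i in range(10)]   # 10 images in that directory

picked = []
picked += images * int(multiplier)                # whole part: every image taken floor(2.5) = 2 times
fractional = list(images)
random.shuffle(fractional)                        # fractional part: a random subset of the directory
picked += fractional[:math.ceil((multiplier % 1.0) * len(fractional))]   # ceil(0.5 * 10) = 5 more

print(len(picked))                                # 25 images from this directory per epoch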
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
@ -110,20 +108,19 @@ class DataLoaderMultiAspect():
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
if len(buckets) > 1:
    for bucket in buckets:
        truncate_count = len(buckets[bucket]) % batch_size
        if truncate_count > 0:
            runt_bucket = buckets[bucket][-truncate_count:]
            for item in runt_bucket:
                item.runt_size = truncate_count
            while len(runt_bucket) < batch_size:
                runt_bucket.append(random.choice(runt_bucket))

            current_bucket_size = len(buckets[bucket])

            buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
            buckets[bucket].extend(runt_bucket)
# flatten the buckets
items: list[ImageTrainItem] = []
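In the hunk above, images left over after filling whole batches in an aspect bucket (the "runt") are tagged with runt_size and their short batch is padded up to batch_size by re-sampling from the runt itself, and the bucket is only truncated and extended when there actually is a remainder. A minimal standalone sketch of the padding arithmetic, with made-up values:

import random

batch_size = 4
bucket = [f"img_{i}.jpg" for i in range(10)]   # 10 images sharing one aspect bucket

truncate_count = len(bucket) % batch_size      # 2 images would not fill a final batch
if truncate_count > 0:
    runt_bucket = bucket[-truncate_count:]
    while len(runt_bucket) < batch_size:       # pad by duplicating items already in the runt
        runt_bucket.append(random.choice(runt_bucket))
    bucket = bucket[:len(bucket) - truncate_count] + runt_bucket

print(len(bucket))                             # 12 -> three full batches of 4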


@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset):
num_images = len(self.image_train_items)
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
self.__update_image_train_items(1.0)
return items
def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1


@ -1,7 +1,9 @@
import copy
import json
import logging
import math
import random
from typing import Callable, Any, Optional, Generator
from argparse import Namespace
import torch
@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch
remaining_items = list(items_copy[split_item_count:])
return split_items, remaining_items
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
for i in items:
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
class EveryDreamValidator:
def __init__(self,
val_config_path: Optional[str],
default_batch_size: int,
resolution: int,
log_writer: SummaryWriter):
self.val_dataloader = None
self.train_overlapping_dataloader = None
self.log_writer = log_writer
self.resolution = resolution
self.config = {
'batch_size': default_batch_size,
'every_n_epochs': 1,
'seed': 555,
'validate_training': True,
'val_split_mode': 'automatic',
'val_split_proportion': 0.15,
@ -120,21 +128,24 @@ class EveryDreamValidator:
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
val_split_proportion = self.config['val_split_proportion']
remaining_train_items = image_train_items
if val_split_mode is None or val_split_mode == 'none':
return None, image_train_items
elif val_split_mode == 'automatic':
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
val_items = list(disable_multiplier_and_flip(val_items))
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
elif val_split_mode == 'manual':
args = Namespace(
aspects=aspects.get_aspect_buckets(self.resolution),
flip_p=0.0,
seed=self.seed,
)
val_data_root = self.config['val_data_root']
val_items = resolver.resolve_root(val_data_root, args)
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
else:
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
@ -149,6 +160,7 @@ class EveryDreamValidator:
stabilize_split_proportion = self.config['stabilize_split_proportion']
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
name='stabilize-train')
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
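Taken together with the train.py changes later in this diff, the refactored validator is driven by three calls. The outline below is assembled from the calls visible in this commit; the epoch loop and args.max_epochs are illustrative stand-ins, not code from the repository:

validator = EveryDreamValidator(args.validation_config,
                                default_batch_size=args.batch_size,
                                resolution=args.resolution,
                                log_writer=log_writer)

# the validation split may remove items from the training set
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)

for epoch in range(args.max_epochs):
    ...  # training steps for this epoch
    validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)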


@ -263,7 +263,7 @@ class ImageTrainItem:
self.multiplier = multiplier
self.image_size = None
if image is None or len(image) == 0:
self.image = []
else:
self.image = image


@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver):
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
multipliers[current_dir] = 1.0
@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver):
caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
items.append(item)
return items
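With this change the resolver only records the per-directory multiplier on each ImageTrainItem; the repetition itself now happens in DataLoaderMultiAspect.__pick_multiplied_set (first file in this diff) instead of being baked into the resolved item list. On the input side, multiply.txt is still a single number per directory; a minimal sketch with a hypothetical dataset path:

import os

os.makedirs("my_dataset/rare_class", exist_ok=True)
with open("my_dataset/rare_class/multiply.txt", "w") as f:
    f.write("2.5")   # images in this folder will count roughly 2.5x per epoch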
@staticmethod


@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
if torch.cuda.is_available():
    from utils.gpu import GPU
import data.aspects as aspects
import data.resolver as resolver
@ -159,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
"""
updates the vram usage for the epoch
"""
if gpu is not None:
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
epoch_mem_color = Style.RESET_ALL
if gpu_used_mem > 0.93 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTRED_EX
elif gpu_used_mem > 0.85 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTYELLOW_EX
elif gpu_used_mem > 0.7 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTGREEN_EX
elif gpu_used_mem < 0.5 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTBLUE_EX
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
def set_args_12gb(args):
@ -326,8 +328,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
# Remove erroneous items
image_train_items = [item for item in resolved_items if item.error is None]
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
return image_train_items
@ -372,6 +373,7 @@ def main(args):
else:
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@ -548,6 +550,7 @@ def main(args):
except Exception as e:
traceback.print_exc()
logging.error(" * Failed to load checkpoint *")
raise
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
@ -620,9 +623,13 @@ def main(args):
image_train_items = resolve_image_train_items(args, log_folder)
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
validator = EveryDreamValidator(args.validation_config,
default_batch_size=args.batch_size,
resolution=args.resolution,
log_writer=log_writer,
)
# the validation dataset may need to steal some items from image_train_items
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items,
@ -710,8 +717,9 @@ def main(args):
if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/")
if gpu is not None:
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
@ -940,7 +948,7 @@ def main(args):
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# validate
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch
@ -1021,6 +1029,7 @@ if __name__ == "__main__":
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")