fix multiplier issues with validation and refactor validation logic

This commit is contained in:
Damian Stewart 2023-02-08 11:28:45 +01:00
parent 1068d2dd2a
commit 4e37200dda
6 changed files with 44 additions and 53 deletions

View File: data/data_loader.py

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 import bisect
+import logging
 from functools import reduce
 import math
 import copy
@@ -41,40 +42,27 @@ class DataLoaderMultiAspect():
         self.rating_overall_sum: float = 0.0
         self.ratings_summed: list[float] = []
         self.__update_rating_sums()
+        count_including_multipliers = sum([math.floor(max(i.multiplier, 1)) for i in self.prepared_train_data])
+        if count_including_multipliers > len(self.prepared_train_data):
+            logging.info(f" * DLMA initialized with {len(image_train_items)} items ({count_including_multipliers} items total after applying multipliers)")
+        else:
+            logging.info(f" * DLMA initialized with {len(image_train_items)} items")
 
     def __pick_multiplied_set(self, randomizer):
         """
        Deals with multiply.txt whole and fractional numbers
        """
-        #print(f"Picking multiplied set from {len(self.prepared_train_data)}")
-        data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
-        epoch_size = len(self.prepared_train_data)
         picked_images = []
-
-        # add by whole number part first and decrement multiplier in copy
-        for iti in data_copy:
-            #print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
-            while iti.multiplier >= 1.0:
-                picked_images.append(iti)
-                #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
-                iti.multiplier -= 1.0
-
-        remaining = epoch_size - len(picked_images)
-        assert remaining >= 0, "Something went wrong with the multiplier calculation"
-
-        # add by remaining fractional numbers by random chance
-        while remaining > 0:
-            for iti in data_copy:
-                if randomizer.uniform(0.0, 1.0) < iti.multiplier:
-                    #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
-                    picked_images.append(iti)
-                    remaining -= 1
-                    iti.multiplier = 0.0
-                if remaining <= 0:
-                    break
-
-        del data_copy
+        for iti in self.prepared_train_data:
+            multiplier = iti.multiplier
+            while multiplier >= 1:
+                picked_images.append(iti)
+                multiplier -= 1
+            # deal with fractional remainder
+            if multiplier > randomizer.uniform(0, 1):
+                picked_images.append(iti)
         return picked_images
 
     def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
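The refactored pick is easier to follow in isolation. The sketch below restates the same whole-plus-fractional logic with a stand-in `Item` class rather than the real `ImageTrainItem`; names here are illustrative only. Because the multiplier is copied into a local variable, the item's own `multiplier` property is never mutated, which is what removes the old deepcopy workaround and lets fractional picks be re-rolled every epoch.

```python
import random

class Item:
    """Stand-in for ImageTrainItem; only the multiplier matters here."""
    def __init__(self, name: str, multiplier: float):
        self.name = name
        self.multiplier = multiplier

def pick_multiplied_set(items: list[Item], randomizer: random.Random) -> list[Item]:
    picked = []
    for item in items:
        multiplier = item.multiplier               # local copy: the item is never mutated
        while multiplier >= 1:                     # whole part: deterministic repeats
            picked.append(item)
            multiplier -= 1
        if multiplier > randomizer.uniform(0, 1):  # fractional part: random chance
            picked.append(item)
    return picked

# multiplier 2.3 -> always picked twice, a third time with p=0.3; 0.5 -> picked with p=0.5
items = [Item("a.jpg", 2.3), Item("b.jpg", 0.5), Item("c.jpg", 1.0)]
print([i.name for i in pick_multiplied_set(items, random.Random(555))])
```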

View File: data/every_dream.py

@@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset):
         num_images = len(self.image_train_items)
         logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
 
-    def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
-        items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
-        self.__update_image_train_items(1.0)
-        return items
-
     def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1

View File: data/every_dream_validation.py

@@ -1,7 +1,9 @@
 import copy
 import json
 import logging
+import math
+import random
-from typing import Callable, Any, Optional
+from typing import Callable, Any, Optional, Generator
 from argparse import Namespace
 
 import torch
@@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch
     remaining_items = list(items_copy[split_item_count:])
     return split_items, remaining_items
 
+def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
+    for i in items:
+        yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
+
 class EveryDreamValidator:
     def __init__(self,
                  val_config_path: Optional[str],
                  default_batch_size: int,
+                 resolution: int,
                  log_writer: SummaryWriter):
         self.val_dataloader = None
         self.train_overlapping_dataloader = None
 
         self.log_writer = log_writer
+        self.resolution = resolution
 
         self.config = {
             'batch_size': default_batch_size,
             'every_n_epochs': 1,
             'seed': 555,
+            'validate_training': True,
             'val_split_mode': 'automatic',
             'val_split_proportion': 0.15,
@@ -120,21 +128,24 @@ class EveryDreamValidator:
     def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
             -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
-        val_split_mode = self.config['val_split_mode']
+        val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
         val_split_proportion = self.config['val_split_proportion']
         remaining_train_items = image_train_items
-        if val_split_mode == 'none':
+        if val_split_mode is None or val_split_mode == 'none':
             return None, image_train_items
         elif val_split_mode == 'automatic':
             val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
+            val_items = list(disable_multiplier_and_flip(val_items))
+            logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
         elif val_split_mode == 'manual':
             args = Namespace(
-                aspects=aspects.get_aspect_buckets(512),
+                aspects=aspects.get_aspect_buckets(self.resolution),
                 flip_p=0.0,
                 seed=self.seed,
             )
             val_data_root = self.config['val_data_root']
             val_items = resolver.resolve_root(val_data_root, args)
+            logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
         else:
             raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
         val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
@@ -149,6 +160,7 @@ class EveryDreamValidator:
         stabilize_split_proportion = self.config['stabilize_split_proportion']
         stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
+        stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
         stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                                   name='stabilize-train')
         stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
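`disable_multiplier_and_flip` is the core of the validation fix: items stolen for validation (or for the stabilize-train set) should be evaluated exactly once per pass and never flip-augmented, regardless of how multiply.txt weights them for training. A toy sketch of the same idea with a stand-in dataclass (the real code constructs a fresh `ImageTrainItem` instead):

```python
from dataclasses import dataclass, replace
from typing import Generator

@dataclass(frozen=True)
class Item:                      # toy stand-in for ImageTrainItem
    pathname: str
    multiplier: float = 1.0
    flip_p: float = 0.0

def disable_multiplier_and_flip(items: list[Item]) -> Generator[Item, None, None]:
    for i in items:
        # validation images: seen exactly once, never flip-augmented
        yield replace(i, multiplier=1.0, flip_p=0.0)

val_items = list(disable_multiplier_and_flip([Item("a.jpg", multiplier=2.5, flip_p=0.5)]))
assert val_items[0].multiplier == 1.0 and val_items[0].flip_p == 0.0
```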

View File: data/image_train_item.py

@@ -263,7 +263,7 @@ class ImageTrainItem:
         self.multiplier = multiplier
         self.image_size = None
 
-        if image is None:
+        if image is None or len(image) == 0:
             self.image = []
         else:
             self.image = image
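The widened guard treats an empty placeholder the same as a missing image, so both normalize to `[]`. A tiny illustration with hypothetical values, not the real `ImageTrainItem` fields:

```python
# Both "no image" and an empty placeholder mean the pixels are not loaded yet.
for image in (None, [], "loaded-image-handle"):
    if image is None or len(image) == 0:
        print("unloaded -> store []")
    else:
        print("keep:", image)
```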

View File: data/resolver.py

@@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver):
                 with open(multiply_txt_path, 'r') as f:
                     val = float(f.read().strip())
                     multipliers[current_dir] = val
-                    logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
+                    logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
             except Exception as e:
                 logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
                 multipliers[current_dir] = 1.0
@@ -137,16 +137,8 @@
                 caption = ImageCaption.resolve(pathname)
                 item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
-                cur_file_multiplier = multipliers[current_dir]
-
-                while cur_file_multiplier >= 1.0:
-                    items.append(item)
-                    cur_file_multiplier -= 1
-
-                if cur_file_multiplier > 0:
-                    if randomizer.random() < cur_file_multiplier:
-                        items.append(item)
+                items.append(item)
         return items
 
     @staticmethod
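With duplication moved out of the resolver, each file is appended exactly once and carries its multiplier on the item; the per-epoch pick in `DataLoaderMultiAspect.__pick_multiplied_set` now does the repeating. Presumably this is also what the commit message means by "multiplier issues with validation": duplicating at resolve time baked one fractional roll into the whole run and let a random validation split capture some duplicates of an image while its other copies stayed in training. A hedged sketch of the per-directory multiply.txt read (the `read_multiplier` helper and paths are illustrative, not the repo's API):

```python
import logging
import os

def read_multiplier(current_dir: str) -> float:
    """Illustrative helper: read a per-directory multiply.txt, defaulting to 1.0."""
    multiply_txt_path = os.path.join(current_dir, "multiply.txt")
    if not os.path.exists(multiply_txt_path):
        return 1.0
    try:
        with open(multiply_txt_path, "r") as f:
            val = float(f.read().strip())
        logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
        return val
    except Exception as e:
        logging.warning(f"Error trying to read multiply.txt for {current_dir}: {e}")
        return 1.0
```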

View File: train.py

@@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
 from data.image_train_item import ImageTrainItem
 from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
-from utils.gpu import GPU
+if torch.cuda.is_available():
+    from utils.gpu import GPU
 import data.aspects as aspects
 import data.resolver as resolver
@@ -326,8 +327,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
     # Remove erroneous items
     image_train_items = [item for item in resolved_items if item.error is None]
-    print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
+    print (f" * Found {len(image_paths)} files in '{args.data_root}'")
 
     return image_train_items
@@ -620,9 +620,13 @@ def main(args):
     image_train_items = resolve_image_train_items(args, log_folder)
 
-    #validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
+    validator = EveryDreamValidator(args.validation_config,
+                                    default_batch_size=args.batch_size,
+                                    resolution=args.resolution,
+                                    log_writer=log_writer,
+                                    )
     # the validation dataset may need to steal some items from image_train_items
-    #image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
+    image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
 
     data_loader = DataLoaderMultiAspect(
         image_train_items=image_train_items,
@@ -940,7 +944,7 @@ def main(args):
                 log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
 
             # validate
-            #validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+            validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
 
             gc.collect()
             # end of epoch
@@ -1021,6 +1025,7 @@ if __name__ == "__main__":
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
+    argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
     argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
    argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")