add fractional support for multiply.txt

This commit is contained in:
Victor Hall 2023-01-22 01:15:50 -05:00
parent 36ece59660
commit 24b00ab35b
3 changed files with 82 additions and 43 deletions

View File

@ -60,18 +60,48 @@ class DataLoaderMultiAspect():
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings() (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
def __pick_multiplied_set(self, randomizer):
"""
Deals with multiply.txt whole and fractional numbers
"""
prepared_train_data_local = self.prepared_train_data.copy()
epoch_size = len(self.prepared_train_data)
picked_images = []
# add by whole number part first and decrement multiplier in copy
for iti in prepared_train_data_local:
while iti.multiplier >= 1.0:
picked_images.append(iti)
iti.multiplier -= 1
if iti.multiplier == 0:
prepared_train_data_local.remove(iti)
remaining = epoch_size - len(picked_images)
assert remaining >= 0, "Something went wrong with the multiplier calculation"
# add by renaming fractional numbers by random chance
while remaining > 0:
for iti in prepared_train_data_local:
if randomizer.uniform(0.0, 1) < iti.multiplier:
picked_images.append(iti)
remaining -= 1
prepared_train_data_local.remove(iti)
if remaining <= 0:
break
return picked_images
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0): def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
""" """
returns the current list of images including their captions in a randomized order, returns the current list of images including their captions in a randomized order,
sorted into buckets with same sized images sorted into buckets with same sized images
if dropout_fraction < 1.0, only a subset of the images will be returned if dropout_fraction < 1.0, only a subset of the images will be returned
if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
:param dropout_fraction: must be between 0.0 and 1.0. :param dropout_fraction: must be between 0.0 and 1.0.
:return: randomized list of (image, caption) pairs, sorted into same sized buckets :return: randomized list of (image, caption) pairs, sorted into same sized buckets
""" """
"""
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
"""
# TODO: this is not terribly efficient but at least linear time
self.seed += 1 self.seed += 1
randomizer = random.Random(self.seed) randomizer = random.Random(self.seed)
@ -79,7 +109,7 @@ class DataLoaderMultiAspect():
if dropout_fraction < 1.0: if dropout_fraction < 1.0:
picked_images = self.__pick_random_subset(dropout_fraction, randomizer) picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
else: else:
picked_images = self.prepared_train_data picked_images = self.__pick_multiplied_set(randomizer)
randomizer.shuffle(picked_images) randomizer.shuffle(picked_images)
@ -207,6 +237,9 @@ class DataLoaderMultiAspect():
if not self.has_scanned: if not self.has_scanned:
undersized_images = [] undersized_images = []
multipliers = {}
skip_folders = []
for pathname in tqdm.tqdm(image_paths): for pathname in tqdm.tqdm(image_paths):
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename) caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
@ -216,6 +249,25 @@ class DataLoaderMultiAspect():
txt_file_path = file_path_without_ext + ".txt" txt_file_path = file_path_without_ext + ".txt"
caption_file_path = file_path_without_ext + ".caption" caption_file_path = file_path_without_ext + ".caption"
current_dir = os.path.dirname(pathname)
try:
if current_dir not in multipliers:
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
#print(current_dir, multiply_txt_path)
if os.path.exists(multiply_txt_path):
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
else:
skip_folders.append(current_dir)
multipliers[current_dir] = 1.0
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
skip_folders.append(current_dir)
multipliers[current_dir] = 1.0
if os.path.exists(yaml_file_path): if os.path.exists(yaml_file_path):
caption = self.__read_caption_from_yaml(yaml_file_path, caption) caption = self.__read_caption_from_yaml(yaml_file_path, caption)
elif os.path.exists(txt_file_path): elif os.path.exists(txt_file_path):
@ -233,7 +285,13 @@ class DataLoaderMultiAspect():
if width * height < target_wh[0] * target_wh[1]: if width * height < target_wh[0] * target_wh[1]:
undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}") undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p) image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
caption=caption,
target_wh=target_wh,
pathname=pathname,
flip_p=flip_p,
multiplier=multipliers[current_dir],
)
decorated_image_train_items.append(image_train_item) decorated_image_train_items.append(image_train_item)
@ -294,25 +352,12 @@ class DataLoaderMultiAspect():
@staticmethod @staticmethod
def __recurse_data_root(self, recurse_root): def __recurse_data_root(self, recurse_root):
multiply = 1
multiply_path = os.path.join(recurse_root, "multiply.txt")
if os.path.exists(multiply_path):
try:
with open(multiply_path, encoding='utf-8', mode='r') as f:
multiply = int(float(f.read().strip()))
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
except:
logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
pass
for f in os.listdir(recurse_root): for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f) current = os.path.join(recurse_root, f)
if os.path.isfile(current): if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower() ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']: if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
# add image multiplyrepeats number of times
for _ in range(multiply):
self.image_paths.append(current) self.image_paths.append(current)
sub_dirs = [] sub_dirs = []

View File

@ -110,13 +110,14 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0) flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images. rating: the relative rating of the images. The rating is measured in comparison to the other images.
""" """
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0): def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
self.caption = caption self.caption = caption
self.target_wh = target_wh self.target_wh = target_wh
self.pathname = pathname self.pathname = pathname
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.cropped_img = None self.cropped_img = None
self.runt_size = 0 self.runt_size = 0
self.multiplier = multiplier
if image is None: if image is None:
self.image = [] self.image = []

View File

@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import os import os
import sys import sys
import math import math
@ -24,6 +25,7 @@ import gc
import random import random
import shutil import shutil
import torch.nn.functional as F import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms import torchvision.transforms as transforms
@ -51,23 +53,9 @@ from data.every_dream import EveryDreamBatch
from utils.convert_diff_to_ckpt import convert as converter from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU from utils.gpu import GPU
_SIGTERM_EXIT_CODE = 130 _SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9 _VERY_LARGE_NUMBER = 1e9
# def is_notebook() -> bool:
# try:
# from IPython import get_ipython
# shell = get_ipython().__class__.__name__
# if shell == 'ZMQInteractiveShell':
# return True # Jupyter notebook or qtconsole
# elif shell == 'TerminalInteractiveShell':
# return False # Terminal running IPython
# else:
# return False # Other type (?)
# except NameError:
# return False # Probably standard Python interpreter
def clean_filename(filename): def clean_filename(filename):
""" """
removes all non-alphanumeric characters from a string so it is safe to use as a filename removes all non-alphanumeric characters from a string so it is safe to use as a filename
@ -275,22 +263,22 @@ def setup_args(args):
return args return args
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step): def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
if global_step == 250 or (epoch >= 2 and step == 1): if global_step == 250 or (epoch >= 4 and step == 1):
factor = 1.8 factor = 1.8
scaler.set_growth_factor(factor) scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor) scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50) scaler.set_growth_interval(50)
if global_step == 500 or (epoch >= 4 and step == 1): if global_step == 500 or (epoch >= 8 and step == 1):
factor = 1.6 factor = 1.6
scaler.set_growth_factor(factor) scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor) scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50) scaler.set_growth_interval(50)
if global_step == 1000 or (epoch >= 8 and step == 1): if global_step == 1000 or (epoch >= 10 and step == 1):
factor = 1.3 factor = 1.3
scaler.set_growth_factor(factor) scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor) scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100) scaler.set_growth_interval(100)
if global_step == 3000 or (epoch >= 15 and step == 1): if global_step == 3000 or (epoch >= 20 and step == 1):
factor = 1.15 factor = 1.15
scaler.set_growth_factor(factor) scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor) scaler.set_backoff_factor(1/factor)
@ -379,6 +367,7 @@ def main(args):
safety_checker=None, # save vram safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none of no safety checker feature_extractor=None, # must be none of no safety checker
disable_tqdm=True,
) )
return pipe return pipe
@ -486,10 +475,12 @@ def main(args):
unet.enable_xformers_memory_efficient_attention() unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers") logging.info("Enabled xformers")
except Exception as ex: except Exception as ex:
logging.warning("failed to load xformers, continuing without it") logging.warning("failed to load xformers, using attention slicing instead")
unet.set_attention_slice("auto")
pass pass
else: else:
logging.info("xformers not available or disabled") logging.info("xformers disabled, using attention slicing instead")
unet.set_attention_slice("auto")
default_lr = 2e-6 default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr curr_lr = args.lr if args.lr is not None else default_lr
@ -506,10 +497,10 @@ def main(args):
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters()) params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training: elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters()) params_to_train = itertools.chain(text_encoder.parameters())
else: else:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters()) params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
betas = (0.9, 0.999) betas = (0.9, 0.999)
@ -810,6 +801,7 @@ def main(args):
if (global_step + 1) % args.sample_steps == 0: if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae) pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device) pipe = pipe.to(device)
#pipe.set_progress_bar_config(progress_bar=False)
with torch.no_grad(): with torch.no_grad():
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1: if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
@ -853,6 +845,7 @@ def main(args):
loss_local = sum(loss_epoch) / len(loss_epoch) loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
gc.collect()
# end of epoch # end of epoch
# end of training # end of training