add fractional support for multiply.txt
This commit is contained in:
parent
36ece59660
commit
24b00ab35b
|
@ -60,18 +60,48 @@ class DataLoaderMultiAspect():
|
||||||
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
|
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
|
||||||
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
|
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
|
||||||
|
|
||||||
|
|
||||||
|
def __pick_multiplied_set(self, randomizer):
    """
    Deals with multiply.txt whole and fractional numbers

    Builds the image list for one epoch from the multiplier stored on each item of
    self.prepared_train_data:
      1. every item is added once per whole unit of its multiplier
      2. the list is then topped up to len(self.prepared_train_data) entries by sweeping
         the items that still have a fractional part left and adding each one with a
         probability equal to that fraction (each item at most once in this phase)

    The items in self.prepared_train_data are only read, never modified, so every call
    starts again from the multipliers that were assigned at prescan.

    :param randomizer: random number generator exposing uniform(a, b), e.g. random.Random
    :return: new list of picked items; an item can appear more than once
    :raises AssertionError: if the whole-number parts alone already exceed the epoch size
    """
    epoch_size = len(self.prepared_train_data)
    picked_images = []

    # (item, leftover fraction) pairs still eligible for the random top-up.
    # The leftover is tracked here instead of decrementing iti.multiplier: a shallow copy
    # of the list shares the item objects, so writing to them would leak into later epochs.
    fractional_candidates = []

    # add by whole number part first and keep the fractional remainder on the side
    for iti in self.prepared_train_data:
        leftover = iti.multiplier
        while leftover >= 1.0:
            picked_images.append(iti)
            leftover -= 1
        # zero (or negative / NaN) leftovers can never win the random draw, drop them now
        if leftover > 0:
            fractional_candidates.append((iti, leftover))

    remaining = epoch_size - len(picked_images)

    assert remaining >= 0, "Something went wrong with the multiplier calculation"

    # add by remaining fractional numbers by random chance
    # stop as soon as the epoch is full OR nobody is left to pick, otherwise the loop
    # would spin forever when the candidates cannot fill the epoch
    while remaining > 0 and fractional_candidates:
        # rebuild the candidate list each sweep rather than removing while iterating,
        # which would skip the element following every removal
        not_picked = []
        for iti, leftover in fractional_candidates:
            if randomizer.uniform(0.0, 1) < leftover:
                picked_images.append(iti)
                remaining -= 1
                if remaining <= 0:
                    break
            else:
                not_picked.append((iti, leftover))
        fractional_candidates = not_picked

    return picked_images
|
||||||
|
|
||||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
|
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
|
||||||
"""
|
"""
|
||||||
returns the current list of images including their captions in a randomized order,
|
returns the current list of images including their captions in a randomized order,
|
||||||
sorted into buckets with same sized images
|
sorted into buckets with same sized images
|
||||||
if dropout_fraction < 1.0, only a subset of the images will be returned
|
if dropout_fraction < 1.0, only a subset of the images will be returned
|
||||||
|
if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
|
||||||
:param dropout_fraction: must be between 0.0 and 1.0.
|
:param dropout_fraction: must be between 0.0 and 1.0.
|
||||||
:return: randomized list of (image, caption) pairs, sorted into same sized buckets
|
:return: randomized list of (image, caption) pairs, sorted into same sized buckets
|
||||||
"""
|
"""
|
||||||
"""
|
|
||||||
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
|
|
||||||
"""
|
|
||||||
# TODO: this is not terribly efficient but at least linear time
|
|
||||||
|
|
||||||
self.seed += 1
|
self.seed += 1
|
||||||
randomizer = random.Random(self.seed)
|
randomizer = random.Random(self.seed)
|
||||||
|
@ -79,7 +109,7 @@ class DataLoaderMultiAspect():
|
||||||
if dropout_fraction < 1.0:
|
if dropout_fraction < 1.0:
|
||||||
picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
|
picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
|
||||||
else:
|
else:
|
||||||
picked_images = self.prepared_train_data
|
picked_images = self.__pick_multiplied_set(randomizer)
|
||||||
|
|
||||||
randomizer.shuffle(picked_images)
|
randomizer.shuffle(picked_images)
|
||||||
|
|
||||||
|
@ -207,6 +237,9 @@ class DataLoaderMultiAspect():
|
||||||
if not self.has_scanned:
|
if not self.has_scanned:
|
||||||
undersized_images = []
|
undersized_images = []
|
||||||
|
|
||||||
|
multipliers = {}
|
||||||
|
skip_folders = []
|
||||||
|
|
||||||
for pathname in tqdm.tqdm(image_paths):
|
for pathname in tqdm.tqdm(image_paths):
|
||||||
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
||||||
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
|
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
|
||||||
|
@ -216,6 +249,25 @@ class DataLoaderMultiAspect():
|
||||||
txt_file_path = file_path_without_ext + ".txt"
|
txt_file_path = file_path_without_ext + ".txt"
|
||||||
caption_file_path = file_path_without_ext + ".caption"
|
caption_file_path = file_path_without_ext + ".caption"
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(pathname)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if current_dir not in multipliers:
|
||||||
|
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
|
||||||
|
#print(current_dir, multiply_txt_path)
|
||||||
|
if os.path.exists(multiply_txt_path):
|
||||||
|
with open(multiply_txt_path, 'r') as f:
|
||||||
|
val = float(f.read().strip())
|
||||||
|
multipliers[current_dir] = val
|
||||||
|
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
|
||||||
|
else:
|
||||||
|
skip_folders.append(current_dir)
|
||||||
|
multipliers[current_dir] = 1.0
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||||
|
skip_folders.append(current_dir)
|
||||||
|
multipliers[current_dir] = 1.0
|
||||||
|
|
||||||
if os.path.exists(yaml_file_path):
|
if os.path.exists(yaml_file_path):
|
||||||
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
|
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
|
||||||
elif os.path.exists(txt_file_path):
|
elif os.path.exists(txt_file_path):
|
||||||
|
@ -233,7 +285,13 @@ class DataLoaderMultiAspect():
|
||||||
if width * height < target_wh[0] * target_wh[1]:
|
if width * height < target_wh[0] * target_wh[1]:
|
||||||
undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
|
undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
|
||||||
|
|
||||||
image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
|
image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
|
||||||
|
caption=caption,
|
||||||
|
target_wh=target_wh,
|
||||||
|
pathname=pathname,
|
||||||
|
flip_p=flip_p,
|
||||||
|
multiplier=multipliers[current_dir],
|
||||||
|
)
|
||||||
|
|
||||||
decorated_image_train_items.append(image_train_item)
|
decorated_image_train_items.append(image_train_item)
|
||||||
|
|
||||||
|
@ -294,25 +352,12 @@ class DataLoaderMultiAspect():
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __recurse_data_root(self, recurse_root):
|
def __recurse_data_root(self, recurse_root):
|
||||||
multiply = 1
|
|
||||||
multiply_path = os.path.join(recurse_root, "multiply.txt")
|
|
||||||
if os.path.exists(multiply_path):
|
|
||||||
try:
|
|
||||||
with open(multiply_path, encoding='utf-8', mode='r') as f:
|
|
||||||
multiply = int(float(f.read().strip()))
|
|
||||||
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
|
|
||||||
except:
|
|
||||||
logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
|
|
||||||
pass
|
|
||||||
|
|
||||||
for f in os.listdir(recurse_root):
|
for f in os.listdir(recurse_root):
|
||||||
current = os.path.join(recurse_root, f)
|
current = os.path.join(recurse_root, f)
|
||||||
|
|
||||||
if os.path.isfile(current):
|
if os.path.isfile(current):
|
||||||
ext = os.path.splitext(f)[1].lower()
|
ext = os.path.splitext(f)[1].lower()
|
||||||
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
|
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
|
||||||
# add image multiplyrepeats number of times
|
|
||||||
for _ in range(multiply):
|
|
||||||
self.image_paths.append(current)
|
self.image_paths.append(current)
|
||||||
|
|
||||||
sub_dirs = []
|
sub_dirs = []
|
||||||
|
|
|
@ -110,13 +110,14 @@ class ImageTrainItem:
|
||||||
flip_p: probability of flipping image (0.0 to 1.0)
|
flip_p: probability of flipping image (0.0 to 1.0)
|
||||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||||
"""
|
"""
|
||||||
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
|
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
|
||||||
self.caption = caption
|
self.caption = caption
|
||||||
self.target_wh = target_wh
|
self.target_wh = target_wh
|
||||||
self.pathname = pathname
|
self.pathname = pathname
|
||||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
self.cropped_img = None
|
self.cropped_img = None
|
||||||
self.runt_size = 0
|
self.runt_size = 0
|
||||||
|
self.multiplier = multiplier
|
||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
self.image = []
|
self.image = []
|
||||||
|
|
37
train.py
37
train.py
|
@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import math
|
import math
|
||||||
|
@ -24,6 +25,7 @@ import gc
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
@ -51,23 +53,9 @@ from data.every_dream import EveryDreamBatch
|
||||||
from utils.convert_diff_to_ckpt import convert as converter
|
from utils.convert_diff_to_ckpt import convert as converter
|
||||||
from utils.gpu import GPU
|
from utils.gpu import GPU
|
||||||
|
|
||||||
|
|
||||||
_SIGTERM_EXIT_CODE = 130
|
_SIGTERM_EXIT_CODE = 130
|
||||||
_VERY_LARGE_NUMBER = 1e9
|
_VERY_LARGE_NUMBER = 1e9
|
||||||
|
|
||||||
# def is_notebook() -> bool:
|
|
||||||
# try:
|
|
||||||
# from IPython import get_ipython
|
|
||||||
# shell = get_ipython().__class__.__name__
|
|
||||||
# if shell == 'ZMQInteractiveShell':
|
|
||||||
# return True # Jupyter notebook or qtconsole
|
|
||||||
# elif shell == 'TerminalInteractiveShell':
|
|
||||||
# return False # Terminal running IPython
|
|
||||||
# else:
|
|
||||||
# return False # Other type (?)
|
|
||||||
# except NameError:
|
|
||||||
# return False # Probably standard Python interpreter
|
|
||||||
|
|
||||||
def clean_filename(filename):
|
def clean_filename(filename):
|
||||||
"""
|
"""
|
||||||
removes all non-alphanumeric characters from a string so it is safe to use as a filename
|
removes all non-alphanumeric characters from a string so it is safe to use as a filename
|
||||||
|
@ -275,22 +263,22 @@ def setup_args(args):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||||
if global_step == 250 or (epoch >= 2 and step == 1):
|
if global_step == 250 or (epoch >= 4 and step == 1):
|
||||||
factor = 1.8
|
factor = 1.8
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(50)
|
scaler.set_growth_interval(50)
|
||||||
if global_step == 500 or (epoch >= 4 and step == 1):
|
if global_step == 500 or (epoch >= 8 and step == 1):
|
||||||
factor = 1.6
|
factor = 1.6
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(50)
|
scaler.set_growth_interval(50)
|
||||||
if global_step == 1000 or (epoch >= 8 and step == 1):
|
if global_step == 1000 or (epoch >= 10 and step == 1):
|
||||||
factor = 1.3
|
factor = 1.3
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(100)
|
scaler.set_growth_interval(100)
|
||||||
if global_step == 3000 or (epoch >= 15 and step == 1):
|
if global_step == 3000 or (epoch >= 20 and step == 1):
|
||||||
factor = 1.15
|
factor = 1.15
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
|
@ -379,6 +367,7 @@ def main(args):
|
||||||
safety_checker=None, # save vram
|
safety_checker=None, # save vram
|
||||||
requires_safety_checker=None, # avoid nag
|
requires_safety_checker=None, # avoid nag
|
||||||
feature_extractor=None, # must be none of no safety checker
|
feature_extractor=None, # must be none of no safety checker
|
||||||
|
disable_tqdm=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
@ -486,10 +475,12 @@ def main(args):
|
||||||
unet.enable_xformers_memory_efficient_attention()
|
unet.enable_xformers_memory_efficient_attention()
|
||||||
logging.info("Enabled xformers")
|
logging.info("Enabled xformers")
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logging.warning("failed to load xformers, continuing without it")
|
logging.warning("failed to load xformers, using attention slicing instead")
|
||||||
|
unet.set_attention_slice("auto")
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logging.info("xformers not available or disabled")
|
logging.info("xformers disabled, using attention slicing instead")
|
||||||
|
unet.set_attention_slice("auto")
|
||||||
|
|
||||||
default_lr = 2e-6
|
default_lr = 2e-6
|
||||||
curr_lr = args.lr if args.lr is not None else default_lr
|
curr_lr = args.lr if args.lr is not None else default_lr
|
||||||
|
@ -506,10 +497,10 @@ def main(args):
|
||||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||||
params_to_train = itertools.chain(unet.parameters())
|
params_to_train = itertools.chain(unet.parameters())
|
||||||
elif args.disable_unet_training:
|
elif args.disable_unet_training:
|
||||||
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
|
||||||
params_to_train = itertools.chain(text_encoder.parameters())
|
params_to_train = itertools.chain(text_encoder.parameters())
|
||||||
else:
|
else:
|
||||||
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||||
|
|
||||||
betas = (0.9, 0.999)
|
betas = (0.9, 0.999)
|
||||||
|
@ -810,6 +801,7 @@ def main(args):
|
||||||
if (global_step + 1) % args.sample_steps == 0:
|
if (global_step + 1) % args.sample_steps == 0:
|
||||||
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
|
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
|
#pipe.set_progress_bar_config(progress_bar=False)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
||||||
|
@ -853,6 +845,7 @@ def main(args):
|
||||||
|
|
||||||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||||
|
gc.collect()
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
# end of training
|
# end of training
|
||||||
|
|
Loading…
Reference in New Issue