add fractional support for multiply.txt
This commit is contained in:
parent 36ece59660
commit 24b00ab35b
@@ -60,18 +60,48 @@ class DataLoaderMultiAspect():
        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()

+    def __pick_multiplied_set(self, randomizer):
+        """
+        Deals with whole and fractional multiply.txt numbers
+        """
+        prepared_train_data_local = self.prepared_train_data.copy()
+        epoch_size = len(self.prepared_train_data)
+        picked_images = []
+
+        # add by the whole-number part first and decrement the multiplier in the copy
+        for iti in list(prepared_train_data_local):  # iterate a snapshot; the local list is mutated below
+            while iti.multiplier >= 1.0:
+                picked_images.append(iti)
+                iti.multiplier -= 1
+                if iti.multiplier == 0:
+                    prepared_train_data_local.remove(iti)
+
+        remaining = epoch_size - len(picked_images)
+
+        assert remaining >= 0, "Something went wrong with the multiplier calculation"
+
+        # fill the remaining slots from fractional multipliers by random chance
+        while remaining > 0:
+            for iti in list(prepared_train_data_local):  # snapshot again; picked items are removed
+                if randomizer.uniform(0.0, 1.0) < iti.multiplier:
+                    picked_images.append(iti)
+                    remaining -= 1
+                    prepared_train_data_local.remove(iti)
+                if remaining <= 0:
+                    break
+
+        return picked_images

    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
        """
        Returns the current list of images, including their captions, in a randomized order,
        sorted into buckets of same-sized images.

        If dropout_fraction < 1.0, only a subset of the images will be returned.
+        If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values read at prescan.

        Images are put into buckets by aspect ratio, with batch_size*n images per bucket; the remainder is discarded.

        :param dropout_fraction: must be between 0.0 and 1.0.
        :return: randomized list of (image, caption) pairs, sorted into same-sized buckets
        """
        # TODO: this is not terribly efficient but at least linear time

        self.seed += 1
        randomizer = random.Random(self.seed)
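For illustration, the picking logic above can be exercised in isolation. A minimal sketch (the `Item` class and sample values are hypothetical stand-ins for `ImageTrainItem`, not part of this commit): the whole part of a multiplier buys guaranteed repeats, the fractional part becomes a per-epoch probability of one more appearance, and the epoch size stays fixed at the base image count.

```python
import random

class Item:
    """Hypothetical stand-in for ImageTrainItem: a name plus a multiplier."""
    def __init__(self, name: str, multiplier: float):
        self.name = name
        self.multiplier = multiplier

def pick_multiplied_set(items, randomizer):
    local = [Item(i.name, i.multiplier) for i in items]  # fresh copies so multipliers reset each epoch
    epoch_size = len(items)
    picked = []
    # whole part: deterministic repeats
    for item in list(local):
        while item.multiplier >= 1.0:
            picked.append(item.name)
            item.multiplier -= 1
            if item.multiplier == 0:
                local.remove(item)
    # fractional part: random chance for the remaining slots
    remaining = epoch_size - len(picked)
    while remaining > 0:
        for item in list(local):
            if randomizer.uniform(0.0, 1.0) < item.multiplier:
                picked.append(item.name)
                remaining -= 1
                local.remove(item)
            if remaining <= 0:
                break
    return picked

items = [Item("a.jpg", 2.0), Item("b.jpg", 1.0), Item("c.jpg", 0.5), Item("d.jpg", 0.5)]
print(pick_multiplied_set(items, random.Random(0)))
# four picks per epoch: a.jpg twice, b.jpg once, and one of c.jpg/d.jpg at random
```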
@@ -79,7 +109,7 @@ class DataLoaderMultiAspect():
        if dropout_fraction < 1.0:
            picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
        else:
-            picked_images = self.prepared_train_data
+            picked_images = self.__pick_multiplied_set(randomizer)

        randomizer.shuffle(picked_images)
@@ -207,6 +237,9 @@ class DataLoaderMultiAspect():
        if not self.has_scanned:
            undersized_images = []

+            multipliers = {}
+            skip_folders = []
+
            for pathname in tqdm.tqdm(image_paths):
                caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
@@ -216,6 +249,25 @@ class DataLoaderMultiAspect():
                txt_file_path = file_path_without_ext + ".txt"
                caption_file_path = file_path_without_ext + ".caption"

+                current_dir = os.path.dirname(pathname)
+
+                try:
+                    if current_dir not in multipliers:
+                        multiply_txt_path = os.path.join(current_dir, "multiply.txt")
+                        #print(current_dir, multiply_txt_path)
+                        if os.path.exists(multiply_txt_path):
+                            with open(multiply_txt_path, 'r') as f:
+                                val = float(f.read().strip())
+                                multipliers[current_dir] = val
+                                logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
+                        else:
+                            skip_folders.append(current_dir)
+                            multipliers[current_dir] = 1.0
+                except Exception as e:
+                    logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
+                    skip_folders.append(current_dir)
+                    multipliers[current_dir] = 1.0
+
                if os.path.exists(yaml_file_path):
                    caption = self.__read_caption_from_yaml(yaml_file_path, caption)
                elif os.path.exists(txt_file_path):
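Condensing the prescan block above into a standalone helper (a sketch, not this commit's API; the function name and error handling are illustrative):

```python
import logging
import os

def read_multiplier(folder: str) -> float:
    """Return the folder's multiply.txt value as a float, defaulting to 1.0
    when the file is missing or unreadable, as the prescan loop above does."""
    path = os.path.join(folder, "multiply.txt")
    try:
        with open(path, "r") as f:
            return float(f.read().strip())
    except FileNotFoundError:
        return 1.0  # no multiply.txt: folder keeps its normal weight
    except (OSError, ValueError) as e:
        logging.warning(f"could not read {path}: {e}")
        return 1.0

# e.g. a hypothetical data/cats/multiply.txt containing "1.5" means every cat
# image is seen once per epoch with a 50% chance of a second appearance:
# read_multiplier("data/cats")  ->  1.5
```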
@@ -233,7 +285,13 @@ class DataLoaderMultiAspect():
                if width * height < target_wh[0] * target_wh[1]:
                    undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")

-                image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
+                image_train_item = ImageTrainItem(image=None,  # image loaded at runtime to apply jitter
+                                                  caption=caption,
+                                                  target_wh=target_wh,
+                                                  pathname=pathname,
+                                                  flip_p=flip_p,
+                                                  multiplier=multipliers[current_dir],
+                                                  )

                decorated_image_train_items.append(image_train_item)
@@ -294,25 +352,12 @@ class DataLoaderMultiAspect():

    @staticmethod
    def __recurse_data_root(self, recurse_root):
-        multiply = 1
-        multiply_path = os.path.join(recurse_root, "multiply.txt")
-        if os.path.exists(multiply_path):
-            try:
-                with open(multiply_path, encoding='utf-8', mode='r') as f:
-                    multiply = int(float(f.read().strip()))
-                    logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
-            except:
-                logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
-                pass
-
        for f in os.listdir(recurse_root):
            current = os.path.join(recurse_root, f)

            if os.path.isfile(current):
                ext = os.path.splitext(f)[1].lower()
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
-                    # add the image `multiply` repeats number of times
-                    for _ in range(multiply):
-                        self.image_paths.append(current)
+                    self.image_paths.append(current)

        sub_dirs = []
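The net effect of removing the scan-time loop: previously the multiplier was applied while walking the directory tree, and `int(float(...))` silently truncated fractions; now the raw float rides along on each item and is split per epoch. A hypothetical comparison:

```python
# before (scan time): fractional part truncated when duplicating paths
int(float("2.5"))               # -> 2 copies of each path; the 0.5 is lost

# after (epoch time): the fraction survives as a per-epoch probability
whole, frac = divmod(2.5, 1.0)  # -> (2.0, 0.5): two guaranteed picks plus a 50% chance of a third
```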
@@ -110,13 +110,14 @@ class ImageTrainItem:
    flip_p: probability of flipping image (0.0 to 1.0)
    rating: the relative rating of the image, measured in comparison to the other images
    """
-    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
+    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
        self.caption = caption
        self.target_wh = target_wh
        self.pathname = pathname
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.cropped_img = None
        self.runt_size = 0
+        self.multiplier = multiplier

        if image is None:
            self.image = []
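A construction sketch under the new signature (the argument values are illustrative, not from the commit): `multiplier` defaults to 1.0, so existing call sites keep their old behavior unchanged.

```python
item = ImageTrainItem(image=None,                     # loaded lazily at runtime
                      caption=caption,                # an ImageCaption built during prescan
                      target_wh=[512, 512],
                      pathname="/data/cats/cat1.jpg",
                      flip_p=0.0,
                      multiplier=1.5)                 # new: consumed by __pick_multiplied_set
```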
train.py
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import sys
import math
@@ -24,6 +25,7 @@ import gc
import random
import shutil


import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
@@ -51,23 +53,9 @@ from data.every_dream import EveryDreamBatch
from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU


_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9

-# def is_notebook() -> bool:
-#     try:
-#         from IPython import get_ipython
-#         shell = get_ipython().__class__.__name__
-#         if shell == 'ZMQInteractiveShell':
-#             return True   # Jupyter notebook or qtconsole
-#         elif shell == 'TerminalInteractiveShell':
-#             return False  # Terminal running IPython
-#         else:
-#             return False  # Other type (?)
-#     except NameError:
-#         return False      # Probably standard Python interpreter
-
def clean_filename(filename):
    """
    removes all non-alphanumeric characters from a string so it is safe to use as a filename
@@ -275,22 +263,22 @@ def setup_args(args):
    return args

def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
-    if global_step == 250 or (epoch >= 2 and step == 1):
+    if global_step == 250 or (epoch >= 4 and step == 1):
        factor = 1.8
        scaler.set_growth_factor(factor)
        scaler.set_backoff_factor(1/factor)
        scaler.set_growth_interval(50)
-    if global_step == 500 or (epoch >= 4 and step == 1):
+    if global_step == 500 or (epoch >= 8 and step == 1):
        factor = 1.6
        scaler.set_growth_factor(factor)
        scaler.set_backoff_factor(1/factor)
        scaler.set_growth_interval(50)
-    if global_step == 1000 or (epoch >= 8 and step == 1):
+    if global_step == 1000 or (epoch >= 10 and step == 1):
        factor = 1.3
        scaler.set_growth_factor(factor)
        scaler.set_backoff_factor(1/factor)
        scaler.set_growth_interval(100)
-    if global_step == 3000 or (epoch >= 15 and step == 1):
+    if global_step == 3000 or (epoch >= 20 and step == 1):
        factor = 1.15
        scaler.set_growth_factor(factor)
        scaler.set_backoff_factor(1/factor)
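For context on the revised thresholds: `set_growth_factor`, `set_backoff_factor`, and `set_growth_interval` are standard `torch.cuda.amp.GradScaler` knobs. Keeping the backoff at the reciprocal of the growth factor means one overflow and one subsequent growth step cancel out, and shrinking the factor at each milestone makes the loss scale oscillate less as training matures. A condensed restatement (the 3000-step milestone's growth interval is not visible in the hunk, so it is omitted here):

```python
from torch.cuda.amp import GradScaler

# (step milestone, epoch milestone, growth factor, growth interval)
_SCALER_SCHEDULE = [(250, 4, 1.8, 50), (500, 8, 1.6, 50), (1000, 10, 1.3, 100)]

def update_grad_scaler_sketch(scaler: GradScaler, global_step: int, epoch: int, step: int):
    for at_step, at_epoch, factor, interval in _SCALER_SCHEDULE:
        if global_step == at_step or (epoch >= at_epoch and step == 1):
            scaler.set_growth_factor(factor)
            scaler.set_backoff_factor(1 / factor)  # reciprocal: one backoff + one growth cancel out
            scaler.set_growth_interval(interval)
```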
@@ -379,6 +367,7 @@ def main(args):
        safety_checker=None,  # save vram
        requires_safety_checker=None,  # avoid nag
        feature_extractor=None,  # must be None if there is no safety checker
+        disable_tqdm=True,
    )

    return pipe
@@ -486,10 +475,12 @@ def main(args):
            unet.enable_xformers_memory_efficient_attention()
            logging.info("Enabled xformers")
        except Exception as ex:
-            logging.warning("failed to load xformers, continuing without it")
+            logging.warning("failed to load xformers, using attention slicing instead")
+            unet.set_attention_slice("auto")
            pass
    else:
-        logging.info("xformers not available or disabled")
+        logging.info("xformers disabled, using attention slicing instead")
+        unet.set_attention_slice("auto")

    default_lr = 2e-6
    curr_lr = args.lr if args.lr is not None else default_lr
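Both branches now fall back to the same thing: try xformers, and on failure or explicit disable switch to attention slicing rather than running full attention. A condensed sketch of the pattern (the wrapper function name is illustrative; `enable_xformers_memory_efficient_attention` and `set_attention_slice` are the diffusers UNet methods used above):

```python
import logging

def setup_attention(unet, disable_xformers: bool = False):
    if not disable_xformers:
        try:
            unet.enable_xformers_memory_efficient_attention()
            logging.info("Enabled xformers")
            return
        except Exception:
            logging.warning("failed to load xformers, using attention slicing instead")
    else:
        logging.info("xformers disabled, using attention slicing instead")
    unet.set_attention_slice("auto")  # slower than xformers, but caps attention VRAM use
```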
@@ -506,10 +497,10 @@ def main(args):
        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters())
    elif args.disable_unet_training:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
+        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
        params_to_train = itertools.chain(text_encoder.parameters())
    else:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
+        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

    betas = (0.9, 0.999)
@@ -810,6 +801,7 @@ def main(args):
            if (global_step + 1) % args.sample_steps == 0:
                pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
                pipe = pipe.to(device)
+                #pipe.set_progress_bar_config(progress_bar=False)

                with torch.no_grad():
                    if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
@@ -853,6 +845,7 @@ def main(args):

        loss_local = sum(loss_epoch) / len(loss_epoch)
        log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
+        gc.collect()
        # end of epoch

    # end of training