Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer into main
This commit is contained in:
commit
f275f20412
|
@ -101,7 +101,7 @@
|
|||
"!pip install -q protobuf==3.20.1\n",
|
||||
"!pip install -q wandb==0.13.6\n",
|
||||
"!pip install -q pyre-extensions==0.0.23\n",
|
||||
"!pip install -q xformers==0.0.17.dev435\n",
|
||||
"!pip install -q xformers==0.0.16\n",
|
||||
"!pip install -q pytorch-lightning==1.6.5\n",
|
||||
"!pip install -q OmegaConf==2.2.3\n",
|
||||
"!pip install -q numpy==1.23.5\n",
|
||||
|
|
|
@ -14,13 +14,14 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
import bisect
|
||||
from functools import reduce
|
||||
import logging
|
||||
import os.path
|
||||
from collections import defaultdict
|
||||
import math
|
||||
import copy
|
||||
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||
import PIL
|
||||
from data.image_train_item import ImageTrainItem
|
||||
import PIL.Image
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||
|
||||
|
@ -38,43 +39,40 @@ class DataLoaderMultiAspect():
|
|||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||
if expected_epoch_size != len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
|
||||
else:
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
|
||||
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
self.__update_rating_sums()
|
||||
|
||||
def __pick_multiplied_set(self, randomizer):
|
||||
|
||||
def __pick_multiplied_set(self, randomizer: random.Random):
|
||||
"""
|
||||
Deals with multiply.txt whole and fractional numbers
|
||||
"""
|
||||
#print(f"Picking multiplied set from {len(self.prepared_train_data)}")
|
||||
data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
|
||||
epoch_size = len(self.prepared_train_data)
|
||||
picked_images = []
|
||||
|
||||
# add by whole number part first and decrement multiplier in copy
|
||||
for iti in data_copy:
|
||||
#print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
|
||||
while iti.multiplier >= 1.0:
|
||||
fractional_images_per_directory = defaultdict(list[ImageTrainItem])
|
||||
for iti in self.prepared_train_data:
|
||||
multiplier = iti.multiplier
|
||||
while multiplier >= 1:
|
||||
picked_images.append(iti)
|
||||
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
|
||||
iti.multiplier -= 1.0
|
||||
multiplier -= 1
|
||||
# fractional remainders must be dealt with separately
|
||||
if multiplier > 0:
|
||||
directory = os.path.dirname(iti.pathname)
|
||||
fractional_images_per_directory[directory].append(iti)
|
||||
|
||||
remaining = epoch_size - len(picked_images)
|
||||
# resolve fractional parts per-directory
|
||||
for _, fractional_items in fractional_images_per_directory.items():
|
||||
randomizer.shuffle(fractional_items)
|
||||
multiplier = fractional_items[0].multiplier % 1.0
|
||||
count_to_take = math.ceil(multiplier * len(fractional_items))
|
||||
picked_images.extend(fractional_items[:count_to_take])
|
||||
|
||||
assert remaining >= 0, "Something went wrong with the multiplier calculation"
|
||||
|
||||
# add by remaining fractional numbers by random chance
|
||||
while remaining > 0:
|
||||
for iti in data_copy:
|
||||
if randomizer.uniform(0.0, 1.0) < iti.multiplier:
|
||||
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
|
||||
picked_images.append(iti)
|
||||
remaining -= 1
|
||||
iti.multiplier = 0.0
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
del data_copy
|
||||
return picked_images
|
||||
|
||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
|
||||
|
@ -110,20 +108,19 @@ class DataLoaderMultiAspect():
|
|||
buckets[(target_wh[0],target_wh[1])] = []
|
||||
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
||||
|
||||
if len(buckets) > 1:
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
for item in runt_bucket:
|
||||
item.runt_size = truncate_count
|
||||
while len(runt_bucket) < batch_size:
|
||||
runt_bucket.append(random.choice(runt_bucket))
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
for item in runt_bucket:
|
||||
item.runt_size = truncate_count
|
||||
while len(runt_bucket) < batch_size:
|
||||
runt_bucket.append(random.choice(runt_bucket))
|
||||
|
||||
current_bucket_size = len(buckets[bucket])
|
||||
current_bucket_size = len(buckets[bucket])
|
||||
|
||||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
# flatten the buckets
|
||||
items: list[ImageTrainItem] = []
|
||||
|
|
|
@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset):
|
|||
num_images = len(self.image_train_items)
|
||||
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
|
||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
||||
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
|
||||
self.__update_image_train_items(1.0)
|
||||
return items
|
||||
|
||||
|
||||
def shuffle(self, epoch_n: int, max_epochs: int):
|
||||
self.seed += 1
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Callable, Any, Optional, Generator
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
|
@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch
|
|||
remaining_items = list(items_copy[split_item_count:])
|
||||
return split_items, remaining_items
|
||||
|
||||
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
|
||||
for i in items:
|
||||
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
|
||||
|
||||
class EveryDreamValidator:
|
||||
def __init__(self,
|
||||
val_config_path: Optional[str],
|
||||
default_batch_size: int,
|
||||
resolution: int,
|
||||
log_writer: SummaryWriter):
|
||||
self.val_dataloader = None
|
||||
self.train_overlapping_dataloader = None
|
||||
|
||||
self.log_writer = log_writer
|
||||
self.resolution = resolution
|
||||
|
||||
self.config = {
|
||||
'batch_size': default_batch_size,
|
||||
'every_n_epochs': 1,
|
||||
'seed': 555,
|
||||
|
||||
'validate_training': True,
|
||||
'val_split_mode': 'automatic',
|
||||
'val_split_proportion': 0.15,
|
||||
|
||||
|
@ -120,21 +128,24 @@ class EveryDreamValidator:
|
|||
|
||||
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||
val_split_mode = self.config['val_split_mode']
|
||||
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
|
||||
val_split_proportion = self.config['val_split_proportion']
|
||||
remaining_train_items = image_train_items
|
||||
if val_split_mode == 'none':
|
||||
if val_split_mode is None or val_split_mode == 'none':
|
||||
return None, image_train_items
|
||||
elif val_split_mode == 'automatic':
|
||||
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
||||
val_items = list(disable_multiplier_and_flip(val_items))
|
||||
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
|
||||
elif val_split_mode == 'manual':
|
||||
args = Namespace(
|
||||
aspects=aspects.get_aspect_buckets(512),
|
||||
aspects=aspects.get_aspect_buckets(self.resolution),
|
||||
flip_p=0.0,
|
||||
seed=self.seed,
|
||||
)
|
||||
val_data_root = self.config['val_data_root']
|
||||
val_items = resolver.resolve_root(val_data_root, args)
|
||||
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
|
||||
else:
|
||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||
|
@ -149,6 +160,7 @@ class EveryDreamValidator:
|
|||
|
||||
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
||||
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
||||
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
|
||||
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
||||
name='stabilize-train')
|
||||
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||
|
|
|
@ -263,7 +263,7 @@ class ImageTrainItem:
|
|||
self.multiplier = multiplier
|
||||
|
||||
self.image_size = None
|
||||
if image is None:
|
||||
if image is None or len(image) == 0:
|
||||
self.image = []
|
||||
else:
|
||||
self.image = image
|
||||
|
|
|
@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver):
|
|||
with open(multiply_txt_path, 'r') as f:
|
||||
val = float(f.read().strip())
|
||||
multipliers[current_dir] = val
|
||||
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
|
||||
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
|
||||
except Exception as e:
|
||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||
multipliers[current_dir] = 1.0
|
||||
|
@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver):
|
|||
|
||||
caption = ImageCaption.resolve(pathname)
|
||||
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
||||
items.append(item)
|
||||
|
||||
cur_file_multiplier = multipliers[current_dir]
|
||||
|
||||
while cur_file_multiplier >= 1.0:
|
||||
items.append(item)
|
||||
cur_file_multiplier -= 1
|
||||
|
||||
if cur_file_multiplier > 0:
|
||||
if randomizer.random() < cur_file_multiplier:
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
|
|
51
train.py
51
train.py
|
@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
|
|||
from data.image_train_item import ImageTrainItem
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.gpu import GPU
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
import data.aspects as aspects
|
||||
import data.resolver as resolver
|
||||
|
||||
|
@ -159,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
|||
"""
|
||||
updates the vram usage for the epoch
|
||||
"""
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
|
||||
epoch_mem_color = Style.RESET_ALL
|
||||
if gpu_used_mem > 0.93 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTRED_EX
|
||||
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
||||
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTGREEN_EX
|
||||
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTBLUE_EX
|
||||
if gpu is not None:
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
|
||||
epoch_mem_color = Style.RESET_ALL
|
||||
if gpu_used_mem > 0.93 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTRED_EX
|
||||
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
||||
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTGREEN_EX
|
||||
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTBLUE_EX
|
||||
|
||||
if logs is not None:
|
||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||
if logs is not None:
|
||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||
|
||||
|
||||
def set_args_12gb(args):
|
||||
|
@ -326,8 +328,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
|
|||
|
||||
# Remove erroneous items
|
||||
image_train_items = [item for item in resolved_items if item.error is None]
|
||||
|
||||
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
|
||||
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
|
||||
|
||||
return image_train_items
|
||||
|
||||
|
@ -372,6 +373,7 @@ def main(args):
|
|||
else:
|
||||
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||
device = 'cpu'
|
||||
gpu = None
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
|
||||
|
@ -548,6 +550,7 @@ def main(args):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logging.error(" * Failed to load checkpoint *")
|
||||
raise
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
@ -620,9 +623,13 @@ def main(args):
|
|||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||
validator = EveryDreamValidator(args.validation_config,
|
||||
default_batch_size=args.batch_size,
|
||||
resolution=args.resolution,
|
||||
log_writer=log_writer,
|
||||
)
|
||||
# the validation dataset may need to steal some items from image_train_items
|
||||
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
|
@ -710,8 +717,9 @@ def main(args):
|
|||
if not os.path.exists(f"{log_folder}/samples/"):
|
||||
os.makedirs(f"{log_folder}/samples/")
|
||||
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||
if gpu is not None:
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||
|
||||
|
@ -940,7 +948,7 @@ def main(args):
|
|||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
# validate
|
||||
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
@ -1021,6 +1029,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
|
|
Loading…
Reference in New Issue