diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b99f83c..b5698b4 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -2,8 +2,8 @@ # `nvcc --version` to get CUDA version. # `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA. # Example Usage: -# Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=1 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True -# Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True import argparse import socket @@ -26,7 +26,6 @@ import numpy as np import json import re import traceback -import shutil try: pynvml.nvmlInit() @@ -39,7 +38,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.optimization import get_scheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from PIL import Image, ImageOps -from PIL.Image import Image as Img from typing import Dict, List, Generator, Tuple from scipy.interpolate import interp1d @@ -86,12 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits') parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') -parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.') -parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with _cropped.') -parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.') parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting') parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') - args = parser.parse_args() def setup(): @@ -151,137 +145,33 @@ def _sort_by_ratio(bucket: tuple) -> float: def _sort_by_area(bucket: tuple) -> float: return bucket[0] * bucket[1] -class Validation(): - def __init__(self, is_skipped: bool, is_extended: bool) -> None: - if is_skipped: - self.validate = self.__no_op - return print("Validation: Skipped") - - if is_extended: - self.validate = self.__extended_validate - return print("Validation: Extended") - - self.validate = self.__validate - print("Validation: Standard") - - def __validate(self, fp: str) -> bool: - try: - Image.open(fp) - return True - except: - print(f'WARNING: Image cannot be opened: {fp}') - return False - - def __extended_validate(self, fp: str) -> bool: - try: - Image.open(fp).load() - return True - except (OSError) as error: - if 'truncated' in str(error): - print(f'WARNING: Image truncated: {error}') - return False - print(f'WARNING: Image cannot be opened: {error}') - return False - except: - print(f'WARNING: Image cannot be opened: {error}') - return False - - def __no_op(self, fp: str) -> bool: - return True - -class Resize(): - def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None: - if not is_resizing: - self.resize = self.__no_op - return - - if not is_not_migrating: - self.resize = self.__migration - dataset_path = os.path.split(args.dataset) - self.__directory = os.path.join( - dataset_path[0], - f'{dataset_path[1]}_cropped' - ) - os.makedirs(self.__directory, exist_ok=True) - return print(f"Resizing: Performing migration to '{self.__directory}'.") - - self.resize = self.__no_migration - - def __no_migration(self, image_path: str, w: int, h: int) -> Img: - return ImageOps.fit( - Image.open(image_path), - (w, h), - bleed=0.0, - centering=(0.5, 0.5), - method=Image.Resampling.LANCZOS - ).convert(mode='RGB') - - def __migration(self, image_path: str, w: int, h: int) -> Img: - filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1]) - - image = ImageOps.fit( - Image.open(image_path), - (w, h), - bleed=0.0, - centering=(0.5, 0.5), - method=Image.Resampling.LANCZOS - ).convert(mode='RGB') - - image.save( - os.path.join(f'{self.__directory}', f'{filename}.jpg'), - optimize=True - ) - - try: - shutil.copy( - os.path.join(args.dataset, f'{filename}.txt'), - os.path.join(self.__directory, f'{filename}.txt'), - follow_symlinks=False - ) - except (FileNotFoundError): - f = open( - os.path.join(self.__directory, f'{filename}.txt'), - 'w', - encoding='UTF-8' - ) - f.close() - - return image - - def __no_op(self, image_path: str, w: int, h: int) -> Img: - return Image.open(image_path) - class ImageStore: def __init__(self, data_dir: str) -> None: self.data_dir = data_dir self.image_files = [] [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']] - - self.validator = Validation( - args.skip_validation, - args.extended_validation - ).validate - - self.resizer = Resize(args.resize, args.no_migration).resize - - self.image_files = [x for x in self.image_files if self.validator(x)] + self.image_files = [x for x in self.image_files if self.__valid_file(x)] def __len__(self) -> int: return len(self.image_files) + def __valid_file(self, f) -> bool: + try: + Image.open(f) + return True + except: + print(f'WARNING: Unable to open file: {f}') + return False + # iterator returns images as PIL images and their index in the store - def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]: + def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]: for f in range(len(self)): - yield Image.open(self.image_files[f]), f + yield Image.open(self.image_files[f]).convert(mode='RGB'), f # get image by index - def get_image(self, ref: Tuple[int, int, int]) -> Img: - return self.resizer( - self.image_files[ref[0]], - ref[1], - ref[2] - ) + def get_image(self, ref: Tuple[int, int, int]) -> Image.Image: + return Image.open(self.image_files[ref[0]]).convert(mode='RGB') # gets caption by removing the extension from the filename and replacing it with .txt def get_caption(self, ref: Tuple[int, int, int]) -> str: @@ -509,6 +399,15 @@ class AspectDataset(torch.utils.data.Dataset): image_file = self.store.get_image(item) + if args.resize: + image_file = ImageOps.fit( + image_file, + (item[1], item[2]), + bleed=0.0, + centering=(0.5, 0.5), + method=Image.Resampling.LANCZOS + ) + return_dict['pixel_values'] = self.transforms(image_file) if random.random() > self.ucg: caption_file = self.store.get_caption(item) @@ -700,7 +599,6 @@ def main(): beta_end=0.012, beta_schedule='scaled_linear', num_train_timesteps=1000, - clip_sample=False ) # load dataset @@ -722,13 +620,6 @@ def main(): collate_fn=dataset.collate_fn ) - # Migrate dataset - if args.resize and not args.no_migration: - for _, batch in enumerate(train_dataloader): - continue - print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") - exit(0) - weight_dtype = torch.float16 if args.fp16 else torch.float32 # move models to device