diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3d8e00d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+*_cropped/
+**/nsfw-ids.txt
+**/*.image
+**/*.caption
+**/dataset*.tar
+**/*.json
+**/*.png
+**/*.jpg
diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index cc645a7..9a6e2b7 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -2,8 +2,8 @@
 # `nvcc --version` to get CUDA version.
 # `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
 # Example Usage:
-# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
-# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
+# Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam --gradient_checkpointing --batch_size=1 --fp16 --image_log_steps=250 --epochs=20 --resolution=768 --use_ema
+# Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam --gradient_checkpointing --batch_size=10 --fp16 --image_log_steps=250 --epochs=20 --resolution=768 --use_ema
 
 import argparse
 import socket
@@ -26,6 +26,7 @@ import numpy as np
 import json
 import re
 import traceback
+import shutil
 
 try:
     pynvml.nvmlInit()
@@ -38,6 +39,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from PIL import Image, ImageOps
+from PIL.Image import Image as Img
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
@@ -84,6 +86,10 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
 parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
 parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
 parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
+parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
+parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of the dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with the _cropped suffix.')
+parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
+
 args = parser.parse_args()
 
 def setup():
@@ -143,33 +149,137 @@ def _sort_by_ratio(bucket: tuple) -> float:
 def _sort_by_area(bucket: tuple) -> float:
     return bucket[0] * bucket[1]
 
+class Validation():
+    def __init__(self, is_skipped: bool, is_extended: bool) -> None:
+        if is_skipped:
+            self.validate = self.__no_op
+            return print("Validation: Skipped")
+
+        if is_extended:
+            self.validate = self.__extended_validate
+            return print("Validation: Extended")
+
+        self.validate = self.__validate
+        print("Validation: Standard")
+
+    def __validate(self, fp: str) -> bool:
+        try:
+            Image.open(fp)
+            return True
+        except:
+            print(f'WARNING: Image cannot be opened: {fp}')
+            return False
+
+    def __extended_validate(self, fp: str) -> bool:
+        try:
+            Image.open(fp).load()
+            return True
+        except (OSError) as error:
+            if 'truncated' in str(error):
+                print(f'WARNING: Image truncated: {error}')
+                return False
+            print(f'WARNING: Image cannot be opened: {error}')
+            return False
+        except Exception as error:
+            print(f'WARNING: Image cannot be opened: {error}')
+            return False
+
+    def __no_op(self, fp: str) -> bool:
+        return True
+
+class Resize():
+    def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None:
+        if not is_resizing:
+            self.resize = self.__no_op
+            return
+
+        if not is_not_migrating:
+            self.resize = self.__migration
+            dataset_path = os.path.split(args.dataset)
+            self.__directory = os.path.join(
+                dataset_path[0],
+                f'{dataset_path[1]}_cropped'
+            )
+            os.makedirs(self.__directory, exist_ok=True)
+            return print(f"Resizing: Performing migration to '{self.__directory}'.")
+
+        self.resize = self.__no_migration
+
+    def __no_migration(self, image_path: str, w: int, h: int) -> Img:
+        return ImageOps.fit(
+            Image.open(image_path),
+            (w, h),
+            bleed=0.0,
+            centering=(0.5, 0.5),
+            method=Image.Resampling.LANCZOS
+        ).convert(mode='RGB')

+    def __migration(self, image_path: str, w: int, h: int) -> Img:
+        filename = re.sub(r'\.[^/.]+$', '', os.path.split(image_path)[1])
+
+        image = ImageOps.fit(
+            Image.open(image_path),
+            (w, h),
+            bleed=0.0,
+            centering=(0.5, 0.5),
+            method=Image.Resampling.LANCZOS
+        ).convert(mode='RGB')
+
+        image.save(
+            os.path.join(f'{self.__directory}', f'{filename}.jpg'),
+            optimize=True
+        )
+
+        try:
+            shutil.copy(
+                os.path.join(args.dataset, f'{filename}.txt'),
+                os.path.join(self.__directory, f'{filename}.txt'),
+                follow_symlinks=False
+            )
+        except (FileNotFoundError):
+            f = open(
+                os.path.join(self.__directory, f'{filename}.txt'),
+                'w',
+                encoding='UTF-8'
+            )
+            f.close()
+
+        return image
+
+    def __no_op(self, image_path: str, w: int, h: int) -> Img:
+        return Image.open(image_path)
+
 class ImageStore:
     def __init__(self, data_dir: str) -> None:
         self.data_dir = data_dir
 
         self.image_files = []
         [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
-        self.image_files = [x for x in self.image_files if self.__valid_file(x)]
+
+        self.validator = Validation(
+            args.skip_validation,
+            args.extended_validation
+        ).validate
+
+        self.resizer = Resize(args.resize, args.no_migration).resize
+
+        self.image_files = [x for x in self.image_files if self.validator(x)]
 
     def __len__(self) -> int:
         return len(self.image_files)
 
-    def __valid_file(self, f) -> bool:
-        try:
-            Image.open(f)
-            return True
-        except:
-            print(f'WARNING: Unable to open file: {f}')
-            return False
-
     # iterator returns images as PIL images and their index in the store
-    def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
+    def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
         for f in range(len(self)):
-            yield Image.open(self.image_files[f]).convert(mode='RGB'), f
+            yield Image.open(self.image_files[f]), f
 
     # get image by index
-    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
-        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
+    def get_image(self, ref: Tuple[int, int, int]) -> Img:
+        return self.resizer(
+            self.image_files[ref[0]],
+            ref[1],
+            ref[2]
+        )
 
     # gets caption by removing the extension from the filename and replacing it with .txt
     def get_caption(self, ref: Tuple[int, int, int]) -> str:
@@ -397,15 +507,6 @@ class AspectDataset(torch.utils.data.Dataset):
 
         image_file = self.store.get_image(item)
 
-        if args.resize:
-            image_file = ImageOps.fit(
-                image_file,
-                (item[1], item[2]),
-                bleed=0.0,
-                centering=(0.5, 0.5),
-                method=Image.Resampling.LANCZOS
-            )
-
         return_dict['pixel_values'] = self.transforms(image_file)
         if random.random() > self.ucg:
             caption_file = self.store.get_caption(item)
@@ -609,6 +710,13 @@ def main():
         num_workers=0,
         collate_fn=dataset.collate_fn
     )
+
+    # Migrate dataset
+    if args.resize and not args.no_migration:
+        for _, batch in enumerate(train_dataloader):
+            continue
+        print(f"Completed resize and migration to '{args.dataset}_cropped'. Please relaunch the trainer without the --resize argument and train on the migrated dataset.")
+        exit(0)
 
     weight_dtype = torch.float16 if args.fp16 else torch.float32
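
Note (not part of the patch): with --resize set and --no_migration unset, the new block at the end of main() simply iterates the dataloader once so that Resize.__migration() writes each fitted image and its caption .txt into the adjacent <dataset>_cropped folder, then exits; training is relaunched afterwards on the migrated folder without --resize. A minimal sketch of how the helpers are wired, mirroring ImageStore.__init__ (file path and sizes below are illustrative):

    # Sketch only; assumes the Validation/Resize classes from this patch are in scope.
    validate = Validation(is_skipped=False, is_extended=True).validate   # full-decode check, catches truncated files
    resize = Resize(is_resizing=True, is_not_migrating=True).resize      # in-memory crop/fit, no _cropped copy written

    fp = 'liminal-dataset/0001.png'   # illustrative path
    if validate(fp):
        image = resize(fp, 512, 512)  # 512x512 RGB PIL image via ImageOps.fit(..., LANCZOS)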