From 120d406355db362afe9777922000639d43f9d973 Mon Sep 17 00:00:00 2001
From: Maw-Fox
Date: Fri, 11 Nov 2022 17:14:46 -0700
Subject: [PATCH] Implementation of validation/resize classes.

---
 trainer/diffusers_trainer.py | 165 ++++++++++++++++++++++++++++++-----
 1 file changed, 143 insertions(+), 22 deletions(-)

diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index 75a5afe..7cad170 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -26,6 +26,7 @@ import numpy as np
 import json
 import re
 import traceback
+import shutil
 
 try:
     pynvml.nvmlInit()
@@ -38,6 +39,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from PIL import Image, ImageOps
+from PIL.Image import Image as Img
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
 
@@ -83,6 +85,10 @@ parser.add_argument('--clip_penultimate', type=str, default='False', help='Use penultimate CLIP layer')
 parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits')
 parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
 parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention')
+parser.add_argument('--extended_validation', type=str, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
+parser.add_argument('--data_migration', type=str, default='True', help='Perform migration of resized images into a directory relative to the dataset path. Saves into `_cropped`.')
+parser.add_argument('--skip_validation', type=str, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
+
 args = parser.parse_args()
 
 for arg in vars(args):
@@ -149,39 +155,153 @@ def _sort_by_ratio(bucket: tuple) -> float:
 def _sort_by_area(bucket: tuple) -> float:
     return bucket[0] * bucket[1]
 
+# Pluggable image validation: the constructor binds self.validate to the
+# strategy selected by the CLI flags (skip / extended / standard).
+class Validation():
+    def __init__(self, is_skipped: bool, is_extended: bool) -> None:
+        if is_skipped:
+            self.validate = self.__no_op
+            print("Validation: Skipped")
+            return
+
+        if is_extended:
+            self.validate = self.__extended_validate
+            return print("Validation: Extended")
+
+        self.validate = self.__validate
+        print("Validation: Standard")
+
+    # Swap in the no-op validator once the first full pass has finished.
+    def completed(self) -> None:
+        self.validate = self.__no_op
+        return print('Validation complete. Skipping further validation.')
+
+    # Standard validation: header-only open (cheap, catches unreadable files).
+    def __validate(self, fp: str) -> bool:
+        try:
+            Image.open(fp)
+            return True
+        except:
+            print(f'WARNING: Image cannot be opened: {fp}')
+            return False
+
+    # Extended validation: fully decode pixel data to catch truncation.
+    def __extended_validate(self, fp: str) -> bool:
+        try:
+            Image.open(fp).load()
+            return True
+        except (OSError) as error:
+            if 'truncated' in str(error):
+                print(f'WARNING: Image truncated: {error}')
+                return False
+            print(f'WARNING: Image cannot be opened: {error}')
+            return False
+        except Exception as error:
+            print(f'WARNING: Image cannot be opened: {error}')
+            return False
+
+    def __no_op(self, fp: str) -> bool:
+        return True
+
+# Pluggable resize: binds self.resize to no-op / in-memory resize /
+# resize-and-migrate-to-`<dataset>_cropped` depending on the CLI flags.
+class Resize():
+    def __init__(self, is_resizing: bool, is_migrating: bool) -> None:
+        if not is_resizing:
+            self.resize = self.__no_op
+            return
+
+        if is_migrating:
+            self.resize = self.__migration
+            dataset_path = os.path.split(args.dataset)
+            self.__directory = os.path.join(
+                dataset_path[0],
+                f'{dataset_path[1]}_cropped'
+            )
+            os.makedirs(self.__directory, exist_ok=True)
+            return print(f"Resizing: Performing migration to '{self.__directory}'.")
+
+        self.resize = self.__no_migration
+
+    # Resize in memory only; nothing is written to disk.
+    def __no_migration(self, image_path: str, w: int, h: int) -> Img:
+        return ImageOps.fit(
+            Image.open(image_path),
+            (w, h),
+            bleed=0.0,
+            centering=(0.5, 0.5),
+            method=Image.Resampling.LANCZOS
+        ).convert(mode='RGB')
+
+    # Resize and persist the crop (plus its caption file) into the
+    # migration directory; missing captions get an empty placeholder.
+    def __migration(self, image_path: str, w: int, h: int) -> Img:
+        filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1])
+
+        image = ImageOps.fit(
+            Image.open(image_path),
+            (w, h),
+            bleed=0.0,
+            centering=(0.5, 0.5),
+            method=Image.Resampling.LANCZOS
+        ).convert(mode='RGB')
+
+        image.save(
+            os.path.join(f'{self.__directory}', f'{filename}.jpg'),
+            optimize=True
+        )
+
+        try:
+            shutil.copy(
+                os.path.join(args.dataset, f'{filename}.txt'),
+                os.path.join(self.__directory, f'{filename}.txt'),
+                follow_symlinks=False
+            )
+        except (FileNotFoundError):
+            f = open(
+                os.path.join(self.__directory, f'{filename}.txt'),
+                'w',
+                encoding='UTF-8'
+            )
+            f.close()
+
+        return image
+
+    def __no_op(self, image_path: str, w: int, h: int) -> Img:
+        return Image.open(image_path)
+
 class ImageStore:
     def __init__(self, data_dir: str) -> None:
         self.data_dir = data_dir
 
         self.image_files = []
         [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
+
+        self.validator = Validation(
+            args.skip_validation,
+            args.extended_validation
+        )
+
+        self.resizer = Resize(args.resize, args.data_migration)
+
         self.image_files = [x for x in self.image_files if self.__valid_file(x)]
 
     def __len__(self) -> int:
         return len(self.image_files)
 
     def __valid_file(self, f) -> bool:
-        try:
-            Image.open(f)
-            return True
-        except:
-            print(f'WARNING: Unable to open file: {f}')
-            return False
+        return self.validator.validate(f)
+
+
     # iterator returns images as PIL images and their index in the store
-    def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
+    def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
         for f in range(len(self)):
-            yield Image.open(self.image_files[f]).convert(mode='RGB'), f
+            yield Image.open(self.image_files[f]), f
 
     # get image by index
-    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
-        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
+    def get_image(self, ref: Tuple[int, int, int]) -> Img:
+        return self.resizer.resize(
+            self.image_files[ref[0]],
+            ref[1],
+            ref[2]
+        )
 
     # gets caption by removing the extension from the filename and replacing it with .txt
     def get_caption(self, ref: Tuple[int, int, int]) -> str:
-        filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
-        with open(filename, 'r', encoding='UTF-8') as f:
-            return f.read()
+        filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
+        with open(filename, 'r', encoding='UTF-8') as f:
+            return f.read()
 
 # ====================================== #
@@ -403,15 +523,6 @@ class AspectDataset(torch.utils.data.Dataset):
 
         image_file = self.store.get_image(item)
 
-        if args.resize:
-            image_file = ImageOps.fit(
-                image_file,
-                (item[1], item[2]),
-                bleed=0.0,
-                centering=(0.5, 0.5),
-                method=Image.Resampling.LANCZOS
-            )
-
         return_dict['pixel_values'] = self.transforms(image_file)
         if random.random() > self.ucg:
             caption_file = self.store.get_caption(item)
@@ -616,6 +727,16 @@ def main():
         collate_fn=dataset.collate_fn
     )
 
+    # Validate dataset and perform possible migration
+    for _, batch in enumerate(train_dataloader):
+        continue
+
+    store.validator.completed()
+
+    if args.resize and args.data_migration:
+        print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
+        exit(0)
+
     weight_dtype = torch.float16 if args.fp16 else torch.float32
 
     # move models to device