Implementation of validation/resize classes.

This commit is contained in:
Maw-Fox 2022-11-11 17:14:46 -07:00
parent 624f0f14af
commit 120d406355
1 changed file with 143 additions and 22 deletions

View File

@ -26,6 +26,7 @@ import numpy as np
import json import json
import re import re
import traceback import traceback
import shutil
try: try:
pynvml.nvmlInit() pynvml.nvmlInit()
@ -38,6 +39,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.Image import Image as Img
from typing import Dict, List, Generator, Tuple from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
@ -83,6 +85,10 @@ parser.add_argument('--clip_penultimate', type=str, default='False', help='Use p
parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits') parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention') parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention')
parser.add_argument('--extended_validation', type=str, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
parser.add_argument('--data_migration', type=str, default='True', help='Perform migration of resized images into a directory relative to the dataset path. Saves into `<dataset_directory_name>_cropped`.')
parser.add_argument('--skip_validation', type=str, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
args = parser.parse_args() args = parser.parse_args()
for arg in vars(args): for arg in vars(args):
@ -149,39 +155,153 @@ def _sort_by_ratio(bucket: tuple) -> float:
def _sort_by_area(bucket: tuple) -> float: def _sort_by_area(bucket: tuple) -> float:
return bucket[0] * bucket[1] return bucket[0] * bucket[1]
class Validation():
    """Validates dataset image files before training.

    `validate(fp) -> bool` is bound at construction time to one of three
    strategies:
      - skipped:  always returns True (no file access),
      - extended: fully decodes the image to catch truncated/corrupt files,
      - standard: header-only open check (cheap, but misses truncation).
    """

    def __init__(self, is_skipped: bool, is_extended: bool) -> None:
        if is_skipped:
            self.validate = self.__no_op
            print("Validation: Skipped")
            return
        if is_extended:
            self.validate = self.__extended_validate
            print("Validation: Extended")
            return
        self.validate = self.__validate
        print("Validation: Standard")

    def completed(self) -> None:
        # After one full pass over the dataset, downgrade to a no-op so
        # later epochs pay no validation cost.
        self.validate = self.__no_op
        print('Validation complete. Skipping further validation.')

    def __validate(self, fp: str) -> bool:
        # Cheap check: Image.open only parses the header. Use a context
        # manager so the file handle is closed (the original leaked it).
        try:
            with Image.open(fp):
                return True
        except Exception:
            print(f'WARNING: Image cannot be opened: {fp}')
            return False

    def __extended_validate(self, fp: str) -> bool:
        # Thorough check: .load() decodes all pixel data, which is the only
        # way to detect a truncated file.
        try:
            with Image.open(fp) as image:
                image.load()
            return True
        except OSError as error:
            if 'truncated' in str(error):
                print(f'WARNING: Image truncated: {error}')
                return False
            print(f'WARNING: Image cannot be opened: {error}')
            return False
        except Exception as error:
            # Bug fix: the original bare `except:` referenced `error`,
            # which was unbound in that branch and raised a NameError.
            print(f'WARNING: Image cannot be opened: {error}')
            return False

    def __no_op(self, fp: str) -> bool:
        return True
class Resize():
    """Resizes dataset images to their bucket dimensions.

    `resize(image_path, w, h) -> Img` is bound at construction time:
      - not resizing:  open the image unchanged,
      - resizing only: center-crop/resize in memory,
      - migrating:     additionally write the result (plus its caption
                       file) into `<dataset>_cropped` next to the dataset.
    """

    def __init__(self, is_resizing: bool, is_migrating: bool) -> None:
        if not is_resizing:
            self.resize = self.__no_op
            return
        if is_migrating:
            self.resize = self.__migration
            dataset_path = os.path.split(args.dataset)
            # Save into `<dataset_directory_name>_cropped`, as documented
            # by the --data_migration argument.
            self.__directory = os.path.join(
                dataset_path[0],
                f'{dataset_path[1]}_cropped'
            )
            os.makedirs(self.__directory, exist_ok=True)
            print(f"Resizing: Performing migration to '{self.__directory}'.")
            return
        self.resize = self.__no_migration

    @staticmethod
    def __fit(image_path: str, w: int, h: int) -> 'Img':
        # Center-crop and resize to exactly (w, h) with a high-quality
        # filter; shared by the migrating and non-migrating paths. The
        # context manager closes the source file handle.
        with Image.open(image_path) as image:
            return ImageOps.fit(
                image,
                (w, h),
                bleed=0.0,
                centering=(0.5, 0.5),
                method=Image.Resampling.LANCZOS
            ).convert(mode='RGB')

    def __no_migration(self, image_path: str, w: int, h: int) -> 'Img':
        return self.__fit(image_path, w, h)

    def __migration(self, image_path: str, w: int, h: int) -> 'Img':
        # Basename without extension; raw string fixes the invalid-escape
        # warning in the original pattern.
        filename = re.sub(r'\.[^/.]+$', '', os.path.split(image_path)[1])
        image = self.__fit(image_path, w, h)
        # Bug fix: the original computed `filename` but never used it, so
        # every output was written to the same placeholder name and
        # clobbered the previous image.
        image.save(
            os.path.join(self.__directory, f'{filename}.jpg'),
            optimize=True
        )
        try:
            shutil.copy(
                os.path.join(args.dataset, f'{filename}.txt'),
                os.path.join(self.__directory, f'{filename}.txt'),
                follow_symlinks=False
            )
        except FileNotFoundError:
            # No caption alongside the source image: create an empty
            # caption file so the migrated dataset stays consistent.
            with open(
                os.path.join(self.__directory, f'{filename}.txt'),
                'w',
                encoding='UTF-8'
            ):
                pass
        return image

    def __no_op(self, image_path: str, w: int, h: int) -> 'Img':
        return Image.open(image_path)
class ImageStore:
    """Indexes, validates, and serves dataset images and captions."""

    def __init__(self, data_dir: str) -> None:
        self.data_dir = data_dir
        # Collect every supported image type in the dataset directory.
        # (Plain loop instead of a comprehension used for side effects.)
        self.image_files = []
        for ext in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            self.image_files.extend(glob.glob(f'{data_dir}/*.{ext}'))
        self.validator = Validation(
            args.skip_validation,
            args.extended_validation
        )
        self.resizer = Resize(args.resize, args.data_migration)
        # Drop files that fail validation up front.
        self.image_files = [x for x in self.image_files if self.__valid_file(x)]

    def __len__(self) -> int:
        return len(self.image_files)

    def __valid_file(self, f) -> bool:
        return self.validator.validate(f)

    # Iterator over (PIL image, index-in-store) pairs.
    def entries_iterator(self) -> Generator[Tuple['Img', int], None, None]:
        for i in range(len(self)):
            yield Image.open(self.image_files[i]), i

    # Get an image by bucket reference (index, width, height), resized or
    # migrated according to the configured resizer.
    def get_image(self, ref: Tuple[int, int, int]) -> 'Img':
        return self.resizer.resize(
            self.image_files[ref[0]],
            ref[1],
            ref[2]
        )

    # Caption lives next to the image, same basename with a .txt extension.
    def get_caption(self, ref: Tuple[int, int, int]) -> str:
        # Bug fix: caption loading had been commented out in favor of
        # `return ''`, which silently drops every caption during training.
        filename = re.sub(r'\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
        with open(filename, 'r', encoding='UTF-8') as f:
            return f.read()
# ====================================== # # ====================================== #
@ -403,15 +523,6 @@ class AspectDataset(torch.utils.data.Dataset):
image_file = self.store.get_image(item) image_file = self.store.get_image(item)
if args.resize:
image_file = ImageOps.fit(
image_file,
(item[1], item[2]),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
)
return_dict['pixel_values'] = self.transforms(image_file) return_dict['pixel_values'] = self.transforms(image_file)
if random.random() > self.ucg: if random.random() > self.ucg:
caption_file = self.store.get_caption(item) caption_file = self.store.get_caption(item)
@ -616,6 +727,16 @@ def main():
collate_fn=dataset.collate_fn collate_fn=dataset.collate_fn
) )
# Validate dataset and perform possible migration: one full pass over the
# dataloader forces every image through the validator/resizer.
for _, batch in enumerate(train_dataloader):
    continue
store.validator.completed()
# Bug fix: the argparse flag is --data_migration, so the parsed attribute
# is `args.data_migration`; `args.migration` raises AttributeError here.
if args.resize and args.data_migration:
    print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
    exit(0)
weight_dtype = torch.float16 if args.fp16 else torch.float32 weight_dtype = torch.float16 if args.fp16 else torch.float32
# move models to device # move models to device