Implementation of validation/resize classes.

Maw-Fox 2022-11-11 17:14:46 -07:00
parent 624f0f14af
commit 120d406355
1 changed file with 143 additions and 22 deletions


@@ -26,6 +26,7 @@ import numpy as np
import json
import re
import traceback
import shutil
try:
    pynvml.nvmlInit()
@@ -38,6 +39,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps
from PIL.Image import Image as Img
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d
@@ -83,6 +85,10 @@ parser.add_argument('--clip_penultimate', type=str, default='False', help='Use p
parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention')
parser.add_argument('--extended_validation', type=str, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
parser.add_argument('--data_migration', type=str, default='True', help='Perform migration of resized images into a directory relative to the dataset path. Saves into `<dataset_directory_name>_cropped`.')
parser.add_argument('--skip_validation', type=str, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
args = parser.parse_args()
for arg in vars(args):
@@ -149,39 +155,153 @@ def _sort_by_ratio(bucket: tuple) -> float:
def _sort_by_area(bucket: tuple) -> float:
    return bucket[0] * bucket[1]
class Validation():
    # Picks a validation strategy at construction time and exposes it as self.validate:
    # skipped (no-op), extended (full decode), or standard (header check only).
    def __init__(self, is_skipped: bool, is_extended: bool) -> None:
        if is_skipped:
            self.validate = self.__no_op
            print("Validation: Skipped")
            return

        if is_extended:
            self.validate = self.__extended_validate
            return print("Validation: Extended")

        self.validate = self.__validate
        print("Validation: Standard")

    def completed(self) -> None:
        self.validate = self.__no_op
        return print('Validation complete. Skipping further validation.')

    def __validate(self, fp: str) -> bool:
        try:
            Image.open(fp)
            return True
        except:
            print(f'WARNING: Image cannot be opened: {fp}')
            return False
    def __extended_validate(self, fp: str) -> bool:
        try:
            Image.open(fp).load()
            return True
        except OSError as error:
            if 'truncated' in str(error):
                print(f'WARNING: Image truncated: {error}')
                return False
            print(f'WARNING: Image cannot be opened: {error}')
            return False
        except Exception as error:
            print(f'WARNING: Image cannot be opened: {error}')
            return False

    def __no_op(self, fp: str) -> bool:
        return True
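# Illustrative usage sketch of the Validation class above -- not part of this commit.
# The flag values and image path are hypothetical.
# validator = Validation(is_skipped=False, is_extended=True)   # prints "Validation: Extended"
# if not validator.validate('/data/example.png'):              # hypothetical image path
#     print('image would be dropped from the dataset')
# validator.completed()   # swaps in the no-op so later calls return True without re-checking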
class Resize():
    # Chooses a resize strategy at construction time and exposes it as self.resize:
    # no-op, in-memory crop/resize, or crop/resize plus migration to '<dataset>_cropped'.
    def __init__(self, is_resizing: bool, is_migrating: bool) -> None:
        if not is_resizing:
            self.resize = self.__no_op
            return

        if is_migrating:
            self.resize = self.__migration
            dataset_path = os.path.split(args.dataset)
            self.__directory = os.path.join(
                dataset_path[0],
                f'{dataset_path[1]}_cropped'
            )
            os.makedirs(self.__directory, exist_ok=True)
            return print(f"Resizing: Performing migration to '{self.__directory}'.")

        self.resize = self.__no_migration

    def __no_migration(self, image_path: str, w: int, h: int) -> Img:
        return ImageOps.fit(
            Image.open(image_path),
            (w, h),
            bleed=0.0,
            centering=(0.5, 0.5),
            method=Image.Resampling.LANCZOS
        ).convert(mode='RGB')
    def __migration(self, image_path: str, w: int, h: int) -> Img:
        filename = re.sub(r'\.[^/.]+$', '', os.path.split(image_path)[1])

        image = ImageOps.fit(
            Image.open(image_path),
            (w, h),
            bleed=0.0,
            centering=(0.5, 0.5),
            method=Image.Resampling.LANCZOS
        ).convert(mode='RGB')

        image.save(
            os.path.join(self.__directory, f'{filename}.jpg'),
            optimize=True
        )

        try:
            shutil.copy(
                os.path.join(args.dataset, f'{filename}.txt'),
                os.path.join(self.__directory, f'{filename}.txt'),
                follow_symlinks=False
            )
        except FileNotFoundError:
            # no matching caption file; create an empty one next to the migrated image
            f = open(
                os.path.join(self.__directory, f'{filename}.txt'),
                'w',
                encoding='UTF-8'
            )
            f.close()

        return image

    def __no_op(self, image_path: str, w: int, h: int) -> Img:
        return Image.open(image_path)
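# Illustrative usage sketch of the Resize class above -- not part of this commit.
# The path and bucket dimensions are hypothetical; with is_migrating=True the class
# also relies on args.dataset pointing at the dataset directory.
# resizer = Resize(is_resizing=True, is_migrating=False)   # crop/resize in memory, no migration
# image = resizer.resize('/data/example.png', 512, 768)    # hypothetical path and bucket (w, h)
# image.save('/tmp/preview.jpg')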
class ImageStore:
    def __init__(self, data_dir: str) -> None:
        self.data_dir = data_dir

        self.image_files = []
        for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            self.image_files.extend(glob.glob(f'{data_dir}/*.{e}'))

        self.validator = Validation(
            args.skip_validation,
            args.extended_validation
        )

        self.resizer = Resize(args.resize, args.data_migration)

        self.image_files = [x for x in self.image_files if self.__valid_file(x)]

    def __len__(self) -> int:
        return len(self.image_files)
    def __valid_file(self, f) -> bool:
-        try:
-            Image.open(f)
-            return True
-        except:
-            print(f'WARNING: Unable to open file: {f}')
-            return False
+        return self.validator.validate(f)
    # iterator returns images as PIL images and their index in the store
-    def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
+    def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
         for f in range(len(self)):
-            yield Image.open(self.image_files[f]).convert(mode='RGB'), f
+            yield Image.open(self.image_files[f]), f
    # get image by index
-    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
-        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
+    def get_image(self, ref: Tuple[int, int, int]) -> Img:
+        return self.resizer.resize(
+            self.image_files[ref[0]],
+            ref[1],
+            ref[2]
+        )
    # gets caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, ref: Tuple[int, int, int]) -> str:
-        filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
-        with open(filename, 'r', encoding='UTF-8') as f:
-            return f.read()
+        #filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
+        #with open(filename, 'r', encoding='UTF-8') as f:
+        return ''
# ====================================== #
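As the get_image/get_caption changes above show, a store entry is addressed by a (file index, bucket width, bucket height) tuple. A rough sketch of reading one entry, assuming the aspect-bucketing code supplies the tuple; the dataset path is hypothetical:

# Illustrative sketch only -- not part of the commit.
store = ImageStore('/data/my_dataset')   # hypothetical dataset directory
ref = (0, 512, 768)                      # (image index, bucket width, bucket height)
image = store.get_image(ref)             # routed through Resize, so it may also migrate the file
caption = store.get_caption(ref)         # currently returns '' while the caption path is commented out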
@@ -403,15 +523,6 @@ class AspectDataset(torch.utils.data.Dataset):
        image_file = self.store.get_image(item)

-        if args.resize:
-            image_file = ImageOps.fit(
-                image_file,
-                (item[1], item[2]),
-                bleed=0.0,
-                centering=(0.5, 0.5),
-                method=Image.Resampling.LANCZOS
-            )

        return_dict['pixel_values'] = self.transforms(image_file)

        if random.random() > self.ucg:
            caption_file = self.store.get_caption(item)
@@ -616,6 +727,16 @@ def main():
        collate_fn=dataset.collate_fn
    )

    # Validate dataset and perform possible migration: one pass over the
    # dataloader pushes every image through the validator (and resizer).
    for _, batch in enumerate(train_dataloader):
        continue
    store.validator.completed()

    if args.resize and args.data_migration:
        print(f"Completed resize and migration to '{args.dataset}_cropped'. Please relaunch the trainer without the --resize argument and train on the migrated dataset.")
        exit(0)

    weight_dtype = torch.float16 if args.fp16 else torch.float32

    # move models to device