From b080f33115d9f59879467199d411504105f42ec8 Mon Sep 17 00:00:00 2001
From: harubaru
Date: Wed, 21 Sep 2022 04:47:39 -0700
Subject: [PATCH] Add variable aspect ratio during training

---
 .../v1-finetune-danbooru-8gpu.yaml |  3 +-
 danbooru_data/download.py          | 37 ++++++++++++---
 danbooru_data/scrape.py            |  7 ++-
 ldm/data/local.py                  |  1 -
 ldm/data/localdanbooru.py          | 45 ++++++++++++-------
 5 files changed, 68 insertions(+), 25 deletions(-)

diff --git a/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml b/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml
index 22e6194..ff3b871 100644
--- a/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml
+++ b/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml
@@ -74,7 +74,8 @@ data:
     tar_base: "links.tar"
     batch_size: 1
     num_workers: 1
-    size: 512
+    max_size: 768
+    resize: false
     flip_p: 0.5
     image_key: "image"
     copyright_rate: 0.9
diff --git a/danbooru_data/download.py b/danbooru_data/download.py
index 11c9a00..af4b1d6 100644
--- a/danbooru_data/download.py
+++ b/danbooru_data/download.py
@@ -1,3 +1,4 @@
+from inspect import trace
 import os
 import json
 import requests
@@ -10,7 +11,7 @@
 import tarfile
 import glob
 import uuid
-from PIL import Image
+from PIL import Image, ImageOps
 
 # downloads URLs from JSON
 
@@ -22,8 +23,27 @@ parser.add_argument('--file', '-f', type=str, required=False, default='links.jso
 parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
 parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
 parser.add_argument('--threads', '-p', required=False, default=16)
+parser.add_argument('--resize', '-r', required=False, default=768)
 args = parser.parse_args()
 
+def resize_image(image: Image, max_size=(768,768)):
+    image = ImageOps.contain(image, max_size, Image.LANCZOS)
+    # resize to integer multiple of 64
+    w, h = image.size
+    w, h = map(lambda x: x - x % 64, (w, h))
+
+    ratio = w / h
+    src_ratio = image.width / image.height
+
+    src_w = w if ratio > src_ratio else image.width * h // image.height
+    src_h = h if ratio <= src_ratio else image.height * w // image.width
+
+    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
+    res = Image.new("RGB", (w, h))
+    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
+
+    return res
+
 class DownloadManager():
     def __init__(self, max_threads: int = 32):
         self.failed_downloads = []
@@ -31,14 +51,16 @@ class DownloadManager():
         self.uuid = str(uuid.uuid1())
 
     # args = (post_id, link, caption_data)
-    def download(self, args):
+    def download(self, args_thread):
         try:
-            image = Image.open(requests.get(args[1], stream=True).raw).convert('RGB')
+            image = Image.open(requests.get(args_thread[1], stream=True).raw).convert('RGB')
+            if args.resize:
+                image = resize_image(image, max_size=(args.resize, args.resize))
             image_bytes = io.BytesIO()
             image.save(image_bytes, format='PNG')
-            __key__ = '%07d' % int(args[0])
+            __key__ = '%07d' % int(args_thread[0])
             image = image_bytes.getvalue()
-            caption = str(json.dumps(args[2]))
+            caption = str(json.dumps(args_thread[2]))
 
             with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
                 f.write(image)
@@ -46,8 +68,9 @@ class DownloadManager():
                 f.write(caption)
 
         except Exception as e:
-            print(e)
-            self.failed_downloads.append((args[0], args[1], args[2]))
+            import traceback
+            print(e, traceback.print_exc())
+            self.failed_downloads.append((args_thread[0], args_thread[1], args_thread[2]))
 
     def download_urls(self, file_path):
         with open(file_path) as f:
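Note: the new `resize_image` helper letterboxes instead of center-cropping. `ImageOps.contain` shrinks the image to fit inside `max_size` while preserving its aspect ratio, both sides are then floored to the nearest multiple of 64 so the dimensions stay friendly to Stable Diffusion's downsampling stack, and the final resize-and-paste recenters the image on a canvas of that 64-aligned size, trimming the few overflow pixels rather than distorting. A standalone sketch of the intermediate sizes, with an illustrative 1000x800 input:

```python
# Illustrative sketch (not part of the patch) of the first two steps of
# resize_image: fit inside max_size, then floor both sides to a multiple of 64.
from PIL import Image, ImageOps

img = Image.new("RGB", (1000, 800))               # stand-in source image
fitted = ImageOps.contain(img, (768, 768), Image.LANCZOS)
print(fitted.size)                                # (768, 614): aspect ratio kept
w, h = (x - x % 64 for x in fitted.size)
print((w, h))                                     # (768, 576): now 64-aligned
```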
diff --git a/danbooru_data/scrape.py b/danbooru_data/scrape.py
index 07461dd..36f3780 100644
--- a/danbooru_data/scrape.py
+++ b/danbooru_data/scrape.py
@@ -13,6 +13,7 @@ parser.add_argument('--danbooru_key', '-key', type=str, required=False)
 parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
 parser.add_argument('--posts', '-p', required=False, type=int, default=10000)
 parser.add_argument('--output', '-o', required=False, default='links.json')
+parser.add_argument('--start_page', '-s', required=False, default=0, type=int)
 args = parser.parse_args()
 
 import re
@@ -49,7 +50,11 @@ class DanbooruScraper():
             print("Error: num_posts must be divisible by batch_size")
             return
         for i in tqdm(range(num_posts//batch_size)):
-            urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i)
+            try:
+                urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i+args.start_page)
+            except Exception as e:
+                print(f'Skipping page {i} - {e}')
+                continue
             if not urls:
                 print(f'Empty results at {i}')
                 break
diff --git a/ldm/data/local.py b/ldm/data/local.py
index f530c6a..5328b3b 100644
--- a/ldm/data/local.py
+++ b/ldm/data/local.py
@@ -21,7 +21,6 @@ class LocalBase(Dataset):
                  shuffle=False,
                  mode='train',
                  val_split=64,
-
                  ):
         super().__init__()
 
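Note: the scrape.py change makes long scrapes resumable and fault-tolerant — every page index is offset by the new `--start_page` flag, and a page whose API call raises is skipped instead of killing the run. Extracted as a standalone sketch (the `client` parameter stands in for the Pybooru client the scraper keeps in `self.dbclient`; the surrounding names are illustrative):

```python
# Illustrative sketch of the pagination loop this patch introduces.
from tqdm import tqdm

def scrape(client, tags, num_posts, batch_size, start_page=0):
    results = []
    for i in tqdm(range(num_posts // batch_size)):
        try:
            urls = client.post_list(tags=tags, limit=batch_size,
                                    random=False, page=i + start_page)
        except Exception as e:
            print(f'Skipping page {i} - {e}')   # transient API error: move on
            continue
        if not urls:
            print(f'Empty results at {i}')      # no more posts: stop early
            break
        results.extend(urls)
    return results
```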
diff --git a/ldm/data/localdanbooru.py b/ldm/data/localdanbooru.py
index 7617057..7624646 100644
--- a/ldm/data/localdanbooru.py
+++ b/ldm/data/localdanbooru.py
@@ -1,7 +1,7 @@
 import os
 import numpy as np
 import PIL
-from PIL import Image
+from PIL import Image, ImageOps
 import random
 
 PIL.Image.MAX_IMAGE_PIXELS = 933120000
@@ -17,8 +17,26 @@ import re
 import json
 import io
 
+def resize_image(image: Image, max_size=(768,768)):
+    image = ImageOps.contain(image, max_size, Image.LANCZOS)
+    # resize to integer multiple of 64
+    w, h = image.size
+    w, h = map(lambda x: x - x % 64, (w, h))
+
+    ratio = w / h
+    src_ratio = image.width / image.height
+
+    src_w = w if ratio > src_ratio else image.width * h // image.height
+    src_h = h if ratio <= src_ratio else image.height * w // image.width
+
+    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
+    res = Image.new("RGB", (w, h))
+    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
+
+    return res
+
 class CaptionProcessor(object):
-    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms):
+    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize):
         self.copyright_rate = copyright_rate
         self.character_rate = character_rate
         self.general_rate = general_rate
@@ -26,6 +44,8 @@
         self.normalize = normalize
         self.caption_shuffle = caption_shuffle
         self.transforms = transforms
+        self.max_size = max_size
+        self.resize = resize
 
     def clean(self, text: str):
         text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
@@ -59,16 +79,9 @@ class CaptionProcessor(object):
 
             # preprocess image
             image = sample['image']
-            image = Image.open(io.BytesIO(image))
-
-            img = np.array(image).astype(np.uint8)
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = img.shape[0], img.shape[1]
-            img = img[(h - crop) // 2:(h + crop) // 2,
-                (w - crop) // 2:(w + crop) // 2]
-            image = Image.fromarray(img)
-
+            if self.resize:
+                image = resize_image(image, max_size=(self.max_size, self.max_size))
             image = self.transforms(image)
             image = np.array(image).astype(np.uint8)
             sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
 
@@ -107,7 +120,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
 
 class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, tar_base, batch_size, train=None, validation=None,
-                 test=None, num_workers=4, size=512, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True,
+                 test=None, num_workers=4, max_size=768, resize=False, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True,
                  random_order=True, **kwargs):
         super().__init__(self)
         print(f'Setting tar base to {tar_base}')
@@ -117,7 +130,8 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
         self.train = train
         self.validation = validation
         self.test = test
-        self.size = size
+        self.max_size = max_size
+        self.resize = resize
         self.flip_p = flip_p
         self.image_key = image_key
         self.copyright_rate = copyright_rate
@@ -126,16 +140,17 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
         self.artist_rate = artist_rate
         self.normalize = normalize
         self.caption_shuffle = caption_shuffle
+        self.random_order = random_order
 
     def make_loader(self, dataset_config, train=True):
         image_transforms = []
-        image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
+        image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
         image_transforms = torchvision.transforms.Compose(image_transforms)
 
         transform_dict = {}
         transform_dict.update({self.image_key: image_transforms})
 
-        postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms)
+        postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize)
 
         tars = os.path.join(self.tar_base)
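Note: with `torchvision.transforms.Resize(self.size)` dropped from `make_loader`, samples now reach the model at whatever 64-aligned shape `resize_image` produced, so two images in the same batch can disagree in height and width; that is presumably why the updated config keeps `batch_size: 1`. A minimal illustration (not from the patch) of the final normalization and of the stacking constraint:

```python
# The [0, 255] -> [-1, 1] scaling CaptionProcessor applies, and why variable
# aspect ratios preclude stacking samples into a single batch tensor.
import numpy as np
from PIL import Image

def to_model_range(image):
    arr = np.array(image).astype(np.uint8)
    return (arr / 127.5 - 1.0).astype(np.float32)

a = to_model_range(Image.new("RGB", (512, 768)))  # portrait sample
b = to_model_range(Image.new("RGB", (768, 512)))  # landscape sample
print(a.shape, b.shape)  # (768, 512, 3) (512, 768, 3)
# np.stack([a, b]) would raise ValueError: the shapes differ
```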