Add variable aspect ratio during training

This commit is contained in:
harubaru 2022-09-21 04:47:39 -07:00
parent c690d005bc
commit b080f33115
5 changed files with 68 additions and 25 deletions

View File

@ -74,7 +74,8 @@ data:
tar_base: "links.tar" tar_base: "links.tar"
batch_size: 1 batch_size: 1
num_workers: 1 num_workers: 1
size: 512 max_size: 768
resize: false
flip_p: 0.5 flip_p: 0.5
image_key: "image" image_key: "image"
copyright_rate: 0.9 copyright_rate: 0.9

View File

@ -1,3 +1,4 @@
from inspect import trace
import os import os
import json import json
import requests import requests
@ -10,7 +11,7 @@ import tarfile
import glob import glob
import uuid import uuid
from PIL import Image from PIL import Image, ImageOps
# downloads URLs from JSON # downloads URLs from JSON
@ -22,8 +23,27 @@ parser.add_argument('--file', '-f', type=str, required=False, default='links.jso
parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar') parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296) parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
parser.add_argument('--threads', '-p', required=False, default=16) parser.add_argument('--threads', '-p', required=False, default=16)
parser.add_argument('--resize', '-r', required=False, default=768)
args = parser.parse_args() args = parser.parse_args()
def resize_image(image: Image, max_size=(768,768)):
image = ImageOps.contain(image, max_size, Image.LANCZOS)
# resize to integer multiple of 64
w, h = image.size
w, h = map(lambda x: x - x % 64, (w, h))
ratio = w / h
src_ratio = image.width / image.height
src_w = w if ratio > src_ratio else image.width * h // image.height
src_h = h if ratio <= src_ratio else image.height * w // image.width
resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (w, h))
res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
return res
class DownloadManager(): class DownloadManager():
def __init__(self, max_threads: int = 32): def __init__(self, max_threads: int = 32):
self.failed_downloads = [] self.failed_downloads = []
@ -31,14 +51,16 @@ class DownloadManager():
self.uuid = str(uuid.uuid1()) self.uuid = str(uuid.uuid1())
# args = (post_id, link, caption_data) # args = (post_id, link, caption_data)
def download(self, args): def download(self, args_thread):
try: try:
image = Image.open(requests.get(args[1], stream=True).raw).convert('RGB') image = Image.open(requests.get(args_thread[1], stream=True).raw).convert('RGB')
if args.resize:
image = resize_image(image, max_size=(args.resize, args.resize))
image_bytes = io.BytesIO() image_bytes = io.BytesIO()
image.save(image_bytes, format='PNG') image.save(image_bytes, format='PNG')
__key__ = '%07d' % int(args[0]) __key__ = '%07d' % int(args_thread[0])
image = image_bytes.getvalue() image = image_bytes.getvalue()
caption = str(json.dumps(args[2])) caption = str(json.dumps(args_thread[2]))
with open(f'{self.uuid}/{__key__}.image', 'wb') as f: with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
f.write(image) f.write(image)
@ -46,8 +68,9 @@ class DownloadManager():
f.write(caption) f.write(caption)
except Exception as e: except Exception as e:
print(e) import traceback
self.failed_downloads.append((args[0], args[1], args[2])) print(e, traceback.print_exc())
self.failed_downloads.append((args_thread[0], args_thread[1], args_thread[2]))
def download_urls(self, file_path): def download_urls(self, file_path):
with open(file_path) as f: with open(file_path) as f:

View File

@ -13,6 +13,7 @@ parser.add_argument('--danbooru_key', '-key', type=str, required=False)
parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month") parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
parser.add_argument('--posts', '-p', required=False, type=int, default=10000) parser.add_argument('--posts', '-p', required=False, type=int, default=10000)
parser.add_argument('--output', '-o', required=False, default='links.json') parser.add_argument('--output', '-o', required=False, default='links.json')
parser.add_argument('--start_page', '-s', required=False, default=0, type=int)
args = parser.parse_args() args = parser.parse_args()
import re import re
@ -49,7 +50,11 @@ class DanbooruScraper():
print("Error: num_posts must be divisible by batch_size") print("Error: num_posts must be divisible by batch_size")
return return
for i in tqdm(range(num_posts//batch_size)): for i in tqdm(range(num_posts//batch_size)):
urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i) try:
urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i+args.start_page)
except Exception as e:
print(f'Skipping page {i} - {e}')
continue
if not urls: if not urls:
print(f'Empty results at {i}') print(f'Empty results at {i}')
break break

View File

@ -21,7 +21,6 @@ class LocalBase(Dataset):
shuffle=False, shuffle=False,
mode='train', mode='train',
val_split=64, val_split=64,
): ):
super().__init__() super().__init__()

View File

@ -1,7 +1,7 @@
import os import os
import numpy as np import numpy as np
import PIL import PIL
from PIL import Image from PIL import Image, ImageOps
import random import random
PIL.Image.MAX_IMAGE_PIXELS = 933120000 PIL.Image.MAX_IMAGE_PIXELS = 933120000
@ -17,8 +17,26 @@ import re
import json import json
import io import io
def resize_image(image: Image, max_size=(768,768)):
image = ImageOps.contain(image, max_size, Image.LANCZOS)
# resize to integer multiple of 64
w, h = image.size
w, h = map(lambda x: x - x % 64, (w, h))
ratio = w / h
src_ratio = image.width / image.height
src_w = w if ratio > src_ratio else image.width * h // image.height
src_h = h if ratio <= src_ratio else image.height * w // image.width
resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (w, h))
res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
return res
class CaptionProcessor(object): class CaptionProcessor(object):
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms): def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize):
self.copyright_rate = copyright_rate self.copyright_rate = copyright_rate
self.character_rate = character_rate self.character_rate = character_rate
self.general_rate = general_rate self.general_rate = general_rate
@ -26,6 +44,8 @@ class CaptionProcessor(object):
self.normalize = normalize self.normalize = normalize
self.caption_shuffle = caption_shuffle self.caption_shuffle = caption_shuffle
self.transforms = transforms self.transforms = transforms
self.max_size = max_size
self.resize = resize
def clean(self, text: str): def clean(self, text: str):
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip() text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
@ -59,16 +79,9 @@ class CaptionProcessor(object):
# preprocess image # preprocess image
image = sample['image'] image = sample['image']
image = Image.open(io.BytesIO(image)) image = Image.open(io.BytesIO(image))
if self.resize:
img = np.array(image).astype(np.uint8) image = resize_image(image, max_size=(self.max_size, self.max_size))
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
image = self.transforms(image) image = self.transforms(image)
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
sample['image'] = (image / 127.5 - 1.0).astype(np.float32) sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
@ -107,7 +120,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
class DanbooruWebDataModuleFromConfig(pl.LightningDataModule): class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
def __init__(self, tar_base, batch_size, train=None, validation=None, def __init__(self, tar_base, batch_size, train=None, validation=None,
test=None, num_workers=4, size=512, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True, test=None, num_workers=4, max_size=768, resize=False, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True, random_order=True,
**kwargs): **kwargs):
super().__init__(self) super().__init__(self)
print(f'Setting tar base to {tar_base}') print(f'Setting tar base to {tar_base}')
@ -117,7 +130,8 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
self.train = train self.train = train
self.validation = validation self.validation = validation
self.test = test self.test = test
self.size = size self.max_size = max_size
self.resize = resize
self.flip_p = flip_p self.flip_p = flip_p
self.image_key = image_key self.image_key = image_key
self.copyright_rate = copyright_rate self.copyright_rate = copyright_rate
@ -126,16 +140,17 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
self.artist_rate = artist_rate self.artist_rate = artist_rate
self.normalize = normalize self.normalize = normalize
self.caption_shuffle = caption_shuffle self.caption_shuffle = caption_shuffle
self.random_order = random_order
def make_loader(self, dataset_config, train=True): def make_loader(self, dataset_config, train=True):
image_transforms = [] image_transforms = []
image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],) image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
image_transforms = torchvision.transforms.Compose(image_transforms) image_transforms = torchvision.transforms.Compose(image_transforms)
transform_dict = {} transform_dict = {}
transform_dict.update({self.image_key: image_transforms}) transform_dict.update({self.image_key: image_transforms})
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms) postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize)
tars = os.path.join(self.tar_base) tars = os.path.join(self.tar_base)