Add variable aspect ratio during training
This commit is contained in:
parent
c690d005bc
commit
b080f33115
|
@ -74,7 +74,8 @@ data:
|
|||
tar_base: "links.tar"
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
size: 512
|
||||
max_size: 768
|
||||
resize: false
|
||||
flip_p: 0.5
|
||||
image_key: "image"
|
||||
copyright_rate: 0.9
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from inspect import trace
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
@ -10,7 +11,7 @@ import tarfile
|
|||
import glob
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
# downloads URLs from JSON
|
||||
|
||||
|
@ -22,8 +23,27 @@ parser.add_argument('--file', '-f', type=str, required=False, default='links.jso
|
|||
parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
|
||||
parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
|
||||
parser.add_argument('--threads', '-p', required=False, default=16)
|
||||
parser.add_argument('--resize', '-r', required=False, default=768)
|
||||
args = parser.parse_args()
|
||||
|
||||
def resize_image(image: Image, max_size=(768,768)):
|
||||
image = ImageOps.contain(image, max_size, Image.LANCZOS)
|
||||
# resize to integer multiple of 64
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 64, (w, h))
|
||||
|
||||
ratio = w / h
|
||||
src_ratio = image.width / image.height
|
||||
|
||||
src_w = w if ratio > src_ratio else image.width * h // image.height
|
||||
src_h = h if ratio <= src_ratio else image.height * w // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
|
||||
res = Image.new("RGB", (w, h))
|
||||
res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
|
||||
|
||||
return res
|
||||
|
||||
class DownloadManager():
|
||||
def __init__(self, max_threads: int = 32):
|
||||
self.failed_downloads = []
|
||||
|
@ -31,14 +51,16 @@ class DownloadManager():
|
|||
self.uuid = str(uuid.uuid1())
|
||||
|
||||
# args = (post_id, link, caption_data)
|
||||
def download(self, args):
|
||||
def download(self, args_thread):
|
||||
try:
|
||||
image = Image.open(requests.get(args[1], stream=True).raw).convert('RGB')
|
||||
image = Image.open(requests.get(args_thread[1], stream=True).raw).convert('RGB')
|
||||
if args.resize:
|
||||
image = resize_image(image, max_size=(args.resize, args.resize))
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, format='PNG')
|
||||
__key__ = '%07d' % int(args[0])
|
||||
__key__ = '%07d' % int(args_thread[0])
|
||||
image = image_bytes.getvalue()
|
||||
caption = str(json.dumps(args[2]))
|
||||
caption = str(json.dumps(args_thread[2]))
|
||||
|
||||
with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
|
||||
f.write(image)
|
||||
|
@ -46,8 +68,9 @@ class DownloadManager():
|
|||
f.write(caption)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.failed_downloads.append((args[0], args[1], args[2]))
|
||||
import traceback
|
||||
print(e, traceback.print_exc())
|
||||
self.failed_downloads.append((args_thread[0], args_thread[1], args_thread[2]))
|
||||
|
||||
def download_urls(self, file_path):
|
||||
with open(file_path) as f:
|
||||
|
|
|
@ -13,6 +13,7 @@ parser.add_argument('--danbooru_key', '-key', type=str, required=False)
|
|||
parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
|
||||
parser.add_argument('--posts', '-p', required=False, type=int, default=10000)
|
||||
parser.add_argument('--output', '-o', required=False, default='links.json')
|
||||
parser.add_argument('--start_page', '-s', required=False, default=0, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
import re
|
||||
|
@ -49,7 +50,11 @@ class DanbooruScraper():
|
|||
print("Error: num_posts must be divisible by batch_size")
|
||||
return
|
||||
for i in tqdm(range(num_posts//batch_size)):
|
||||
urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i)
|
||||
try:
|
||||
urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i+args.start_page)
|
||||
except Exception as e:
|
||||
print(f'Skipping page {i} - {e}')
|
||||
continue
|
||||
if not urls:
|
||||
print(f'Empty results at {i}')
|
||||
break
|
||||
|
|
|
@ -21,7 +21,6 @@ class LocalBase(Dataset):
|
|||
shuffle=False,
|
||||
mode='train',
|
||||
val_split=64,
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
import random
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 933120000
|
||||
|
@ -17,8 +17,26 @@ import re
|
|||
import json
|
||||
import io
|
||||
|
||||
def resize_image(image: Image, max_size=(768,768)):
|
||||
image = ImageOps.contain(image, max_size, Image.LANCZOS)
|
||||
# resize to integer multiple of 64
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 64, (w, h))
|
||||
|
||||
ratio = w / h
|
||||
src_ratio = image.width / image.height
|
||||
|
||||
src_w = w if ratio > src_ratio else image.width * h // image.height
|
||||
src_h = h if ratio <= src_ratio else image.height * w // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
|
||||
res = Image.new("RGB", (w, h))
|
||||
res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
|
||||
|
||||
return res
|
||||
|
||||
class CaptionProcessor(object):
|
||||
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms):
|
||||
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize):
|
||||
self.copyright_rate = copyright_rate
|
||||
self.character_rate = character_rate
|
||||
self.general_rate = general_rate
|
||||
|
@ -26,6 +44,8 @@ class CaptionProcessor(object):
|
|||
self.normalize = normalize
|
||||
self.caption_shuffle = caption_shuffle
|
||||
self.transforms = transforms
|
||||
self.max_size = max_size
|
||||
self.resize = resize
|
||||
|
||||
def clean(self, text: str):
|
||||
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
|
||||
|
@ -59,16 +79,9 @@ class CaptionProcessor(object):
|
|||
|
||||
# preprocess image
|
||||
image = sample['image']
|
||||
|
||||
image = Image.open(io.BytesIO(image))
|
||||
|
||||
img = np.array(image).astype(np.uint8)
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = img.shape[0], img.shape[1]
|
||||
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||
(w - crop) // 2:(w + crop) // 2]
|
||||
image = Image.fromarray(img)
|
||||
|
||||
if self.resize:
|
||||
image = resize_image(image, max_size=(self.max_size, self.max_size))
|
||||
image = self.transforms(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
@ -107,7 +120,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
|
|||
|
||||
class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(self, tar_base, batch_size, train=None, validation=None,
|
||||
test=None, num_workers=4, size=512, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True,
|
||||
test=None, num_workers=4, max_size=768, resize=False, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True, random_order=True,
|
||||
**kwargs):
|
||||
super().__init__(self)
|
||||
print(f'Setting tar base to {tar_base}')
|
||||
|
@ -117,7 +130,8 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
|||
self.train = train
|
||||
self.validation = validation
|
||||
self.test = test
|
||||
self.size = size
|
||||
self.max_size = max_size
|
||||
self.resize = resize
|
||||
self.flip_p = flip_p
|
||||
self.image_key = image_key
|
||||
self.copyright_rate = copyright_rate
|
||||
|
@ -126,16 +140,17 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
|||
self.artist_rate = artist_rate
|
||||
self.normalize = normalize
|
||||
self.caption_shuffle = caption_shuffle
|
||||
self.random_order = random_order
|
||||
|
||||
def make_loader(self, dataset_config, train=True):
|
||||
image_transforms = []
|
||||
image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
||||
image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
||||
image_transforms = torchvision.transforms.Compose(image_transforms)
|
||||
|
||||
transform_dict = {}
|
||||
transform_dict.update({self.image_key: image_transforms})
|
||||
|
||||
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms)
|
||||
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize)
|
||||
|
||||
|
||||
tars = os.path.join(self.tar_base)
|
||||
|
|
Loading…
Reference in New Issue