Add variable aspect ratio during training
parent c690d005bc
commit b080f33115
@@ -74,7 +74,8 @@ data:
   tar_base: "links.tar"
   batch_size: 1
   num_workers: 1
-  size: 512
+  max_size: 768
+  resize: false
   flip_p: 0.5
   image_key: "image"
   copyright_rate: 0.9
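Note on the config hunk above: the fixed square edge size: 512 is replaced by max_size: 768 plus a resize switch that defaults to off. In this config-driven setup the keys under data: are passed as keyword arguments to the data module class updated later in this commit. A minimal sketch, reusing the values from the hunk; the class name refers to the loader changed further down and its import path is deliberately left out:

# Hedged sketch: the YAML keys above as constructor keyword arguments.
params = {
    "tar_base": "links.tar",
    "batch_size": 1,
    "num_workers": 1,
    "max_size": 768,   # was size: 512
    "resize": False,   # new flag: enable aspect-preserving resizing
    "flip_p": 0.5,
    "image_key": "image",
    "copyright_rate": 0.9,
}
# Import path depends on the repo layout, so the call is left commented out:
# datamodule = DanbooruWebDataModuleFromConfig(**params)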
@@ -1,3 +1,4 @@
+from inspect import trace
 import os
 import json
 import requests
@@ -10,7 +11,7 @@ import tarfile
 import glob
 import uuid

-from PIL import Image
+from PIL import Image, ImageOps

 # downloads URLs from JSON

@@ -22,8 +23,27 @@ parser.add_argument('--file', '-f', type=str, required=False, default='links.jso
 parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
 parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
 parser.add_argument('--threads', '-p', required=False, default=16)
+parser.add_argument('--resize', '-r', required=False, default=768)
 args = parser.parse_args()

+def resize_image(image: Image, max_size=(768,768)):
+    image = ImageOps.contain(image, max_size, Image.LANCZOS)
+    # resize to integer multiple of 64
+    w, h = image.size
+    w, h = map(lambda x: x - x % 64, (w, h))
+
+    ratio = w / h
+    src_ratio = image.width / image.height
+
+    src_w = w if ratio > src_ratio else image.width * h // image.height
+    src_h = h if ratio <= src_ratio else image.height * w // image.width
+
+    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
+    res = Image.new("RGB", (w, h))
+    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
+
+    return res
+
 class DownloadManager():
     def __init__(self, max_threads: int = 32):
         self.failed_downloads = []
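The resize_image helper added above fits an image inside max_size with ImageOps.contain (keeping aspect ratio), rounds both sides down to multiples of 64, and pastes the rescaled image onto an RGB canvas of that snapped size. A minimal usage sketch, assuming the helper above is in scope; the input dimensions are only illustrative:

# Hedged usage sketch for the resize_image helper added in this hunk.
from PIL import Image

img = Image.new("RGB", (1024, 1536))          # dummy portrait image
out = resize_image(img, max_size=(768, 768))  # fit inside 768x768, keep aspect ratio
print(out.size)                               # (512, 768): both sides are multiples of 64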
@@ -31,14 +51,16 @@ class DownloadManager():
         self.uuid = str(uuid.uuid1())

     # args = (post_id, link, caption_data)
-    def download(self, args):
+    def download(self, args_thread):
         try:
-            image = Image.open(requests.get(args[1], stream=True).raw).convert('RGB')
+            image = Image.open(requests.get(args_thread[1], stream=True).raw).convert('RGB')
+            if args.resize:
+                image = resize_image(image, max_size=(args.resize, args.resize))
             image_bytes = io.BytesIO()
             image.save(image_bytes, format='PNG')
-            __key__ = '%07d' % int(args[0])
+            __key__ = '%07d' % int(args_thread[0])
             image = image_bytes.getvalue()
-            caption = str(json.dumps(args[2]))
+            caption = str(json.dumps(args_thread[2]))

             with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
                 f.write(image)

@@ -46,8 +68,9 @@ class DownloadManager():
                 f.write(caption)

         except Exception as e:
-            print(e)
-            self.failed_downloads.append((args[0], args[1], args[2]))
+            import traceback
+            print(e, traceback.print_exc())
+            self.failed_downloads.append((args_thread[0], args_thread[1], args_thread[2]))

     def download_urls(self, file_path):
         with open(file_path) as f:
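In the download method the worker argument is renamed from args to args_thread so the per-task tuple no longer shadows the module-level argparse namespace; that is what lets the new if args.resize: check read the CLI flag. Note also that traceback.print_exc() prints the traceback itself and returns None, so the new print call shows the exception followed by None. A small sketch of the shadowing point, with illustrative names only:

# Hedged sketch of the name-shadowing issue the rename avoids.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--resize', type=int, default=768)
args = parser.parse_args([])      # module-level namespace, as in the downloader script

def download(args_thread):
    # args_thread is the (post_id, link, caption) work item; args still refers
    # to the CLI namespace because the parameter no longer shadows it.
    post_id, link, caption = args_thread
    if args.resize:
        print(f'would resize {link} to fit {args.resize}x{args.resize}')

download((1, 'https://example.com/a.png', {'tag_string': 'example'}))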
@@ -13,6 +13,7 @@ parser.add_argument('--danbooru_key', '-key', type=str, required=False)
 parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
 parser.add_argument('--posts', '-p', required=False, type=int, default=10000)
 parser.add_argument('--output', '-o', required=False, default='links.json')
+parser.add_argument('--start_page', '-s', required=False, default=0, type=int)
 args = parser.parse_args()

 import re

@@ -49,7 +50,11 @@ class DanbooruScraper():
            print("Error: num_posts must be divisible by batch_size")
            return
        for i in tqdm(range(num_posts//batch_size)):
-           urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i)
+           try:
+               urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i+args.start_page)
+           except Exception as e:
+               print(f'Skipping page {i} - {e}')
+               continue
            if not urls:
                print(f'Empty results at {i}')
                break
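The new --start_page flag lets the scraper resume an interrupted crawl: loop index i now maps to Danbooru page i + start_page, and pages that raise (for example rate limits) are skipped instead of aborting the run. A small sketch of the resulting page window, with illustrative numbers:

# Hedged sketch of the pages requested after this change.
num_posts, batch_size, start_page = 10000, 100, 25

pages = [i + start_page for i in range(num_posts // batch_size)]
print(pages[0], pages[-1])   # 25 124: the crawl resumes at page 25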
@@ -21,7 +21,6 @@ class LocalBase(Dataset):
                 shuffle=False,
                 mode='train',
                 val_split=64,
-
                 ):
        super().__init__()

@@ -1,7 +1,7 @@
 import os
 import numpy as np
 import PIL
-from PIL import Image
+from PIL import Image, ImageOps
 import random

 PIL.Image.MAX_IMAGE_PIXELS = 933120000
@@ -17,8 +17,26 @@ import re
 import json
 import io

+def resize_image(image: Image, max_size=(768,768)):
+    image = ImageOps.contain(image, max_size, Image.LANCZOS)
+    # resize to integer multiple of 64
+    w, h = image.size
+    w, h = map(lambda x: x - x % 64, (w, h))
+
+    ratio = w / h
+    src_ratio = image.width / image.height
+
+    src_w = w if ratio > src_ratio else image.width * h // image.height
+    src_h = h if ratio <= src_ratio else image.height * w // image.width
+
+    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
+    res = Image.new("RGB", (w, h))
+    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
+
+    return res
+
 class CaptionProcessor(object):
-    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms):
+    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize):
         self.copyright_rate = copyright_rate
         self.character_rate = character_rate
         self.general_rate = general_rate
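The same resize_image helper is duplicated here in the dataloader so the fit-and-snap step can also run at training time; snapping to multiples of 64 presumably keeps the downsampled feature maps evenly divisible. A small sketch of the snapping arithmetic, with illustrative dimensions:

# Hedged sketch of the snap-to-64 step inside resize_image.
w, h = 650, 910
w, h = map(lambda x: x - x % 64, (w, h))
print(w, h)   # 640 896: each side rounded down to a multiple of 64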
@@ -26,6 +44,8 @@ class CaptionProcessor(object):
         self.normalize = normalize
         self.caption_shuffle = caption_shuffle
         self.transforms = transforms
+        self.max_size = max_size
+        self.resize = resize

     def clean(self, text: str):
         text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
@@ -59,16 +79,9 @@ class CaptionProcessor(object):

         # preprocess image
         image = sample['image']

         image = Image.open(io.BytesIO(image))
-        img = np.array(image).astype(np.uint8)
-        crop = min(img.shape[0], img.shape[1])
-        h, w, = img.shape[0], img.shape[1]
-        img = img[(h - crop) // 2:(h + crop) // 2,
-                  (w - crop) // 2:(w + crop) // 2]
-        image = Image.fromarray(img)
-
+        if self.resize:
+            image = resize_image(image, max_size=(self.max_size, self.max_size))
         image = self.transforms(image)
         image = np.array(image).astype(np.uint8)
         sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
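This hunk is the core of the change: the old path center-cropped every sample to a square (and the transforms then resized it to 512x512), while the new path keeps the whole image at its native aspect ratio when resize is enabled, only fitting it inside max_size and snapping to multiples of 64. A hedged before/after sketch, assuming the resize_image helper above is in scope and using an illustrative 1024x1536 input:

# Hedged before/after sketch for a 1024x1536 portrait sample.
import numpy as np
from PIL import Image

image = Image.new("RGB", (1024, 1536))

# Old path: center-crop to a square (the transforms then applied Resize(512)).
img = np.array(image)
h, w = img.shape[0], img.shape[1]
crop = min(h, w)
img = img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2]
print(Image.fromarray(img).size)                       # (1024, 1024): a third of the height is cropped away

# New path: aspect-preserving fit inside 768x768, snapped to multiples of 64.
print(resize_image(image, max_size=(768, 768)).size)   # (512, 768): nothing is cropped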
@@ -107,7 +120,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):

 class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, tar_base, batch_size, train=None, validation=None,
-                 test=None, num_workers=4, size=512, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True,
+                 test=None, num_workers=4, max_size=768, resize=False, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True, random_order=True,
                  **kwargs):
         super().__init__(self)
         print(f'Setting tar base to {tar_base}')
@@ -117,7 +130,8 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
         self.train = train
         self.validation = validation
         self.test = test
-        self.size = size
+        self.max_size = max_size
+        self.resize = resize
         self.flip_p = flip_p
         self.image_key = image_key
         self.copyright_rate = copyright_rate
@@ -126,16 +140,17 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
         self.artist_rate = artist_rate
         self.normalize = normalize
         self.caption_shuffle = caption_shuffle
+        self.random_order = random_order

     def make_loader(self, dataset_config, train=True):
         image_transforms = []
-        image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
+        image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
         image_transforms = torchvision.transforms.Compose(image_transforms)

         transform_dict = {}
         transform_dict.update({self.image_key: image_transforms})

-        postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms)
+        postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms, max_size=self.max_size, resize=self.resize)


         tars = os.path.join(self.tar_base)
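Pulling the loader changes together: the data module now accepts max_size, resize, and random_order in place of size, stores them, drops torchvision's Resize from the transform list (sizing is handled by resize_image inside CaptionProcessor), and forwards max_size and resize to the caption processor. A hedged sketch of the per-image transform pipeline before and after, using the flip probability from the config:

# Hedged sketch of the transform pipelines before and after this commit.
import torchvision

old_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),                 # forced a fixed short side
    torchvision.transforms.RandomHorizontalFlip(0.5),
])
new_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(0.5),   # sizing now happens in resize_image()
])

Since the flip accepts PIL images of any size, each sample can now keep its own aspect ratio; tensors of different shapes still cannot be stacked into one batch, which is presumably why the config hunk at the top keeps batch_size: 1.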