2022-08-30 15:06:47 -06:00
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import PIL
|
|
|
|
from PIL import Image
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
import glob
|
|
|
|
|
|
|
|
import random
|
|
|
|
|
2022-09-02 22:00:05 -06:00
|
|
|
# Raise Pillow's decompression-bomb pixel limit (Image.MAX_IMAGE_PIXELS) so very
# large dataset images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = 933_120_000
|
2022-08-30 15:06:47 -06:00
|
|
|
|
2022-09-22 15:56:27 -06:00
|
|
|
import torchvision
|
|
|
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import re
|
|
|
|
import json
|
|
|
|
import io
|
|
|
|
|
|
|
|
def resize_image(image: "Image.Image", max_size=(768, 768)):
    """Fit *image* inside *max_size*, then center-crop it to multiples of 64.

    The image is first scaled (aspect ratio preserved) to the largest size that
    fits within ``max_size``. The output canvas is that size rounded down to a
    multiple of 64 on each side, but never below 64. The image is then scaled
    so it covers the whole canvas and is pasted centered, so any overflow is
    cropped.

    Args:
        image: PIL image to resize.
        max_size: ``(width, height)`` bounding box passed to ``ImageOps.contain``.

    Returns:
        A new RGB ``PIL.Image.Image`` whose width and height are multiples of 64.
    """
    # Imported locally: the module's top-level PIL import only brings in `Image`.
    from PIL import ImageOps

    image = ImageOps.contain(image, max_size, Image.LANCZOS)

    # Round each side down to a multiple of 64, but keep at least 64 so very
    # thin images cannot produce a zero-sized canvas (and a division by zero).
    w, h = (max(64, side - side % 64) for side in image.size)

    ratio = w / h
    src_ratio = image.width / image.height

    # Scale so the image covers the w x h canvas on both axes (excess is cropped).
    src_w = w if ratio > src_ratio else image.width * h // image.height
    src_h = h if ratio <= src_ratio else image.height * w // image.width

    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
    res = Image.new("RGB", (w, h))
    # Center the (possibly larger) resized image on the canvas.
    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
    return res
|
|
|
|
|
|
|
|
class CaptionProcessor(object):
    """Turn a raw sample into a tag caption and a normalized pixel array.

    ``__call__`` expects ``sample['caption']`` to be a JSON string with
    ``tag_string_character`` / ``tag_string_copyright`` / ``tag_string_artist``
    / ``tag_string_general`` fields, and ``sample['image']`` to be encoded
    image bytes. Both entries are replaced in place.
    """

    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order):
        """Store the caption/image preprocessing options.

        Args:
            copyright_rate, character_rate, general_rate, artist_rate:
                Probability of keeping each tag category in the caption.
            normalize: Replace underscores with spaces and join tags with ", ".
            caption_shuffle: Randomly shuffle the tags of each cleaned string.
            transforms: Callable applied to the PIL image before it is turned
                into an array (this parameter shadows the module-level
                torchvision ``transforms`` inside ``__init__`` only).
            max_size: Side of the square bounding box given to ``resize_image``.
            resize: Whether to run ``resize_image`` before ``transforms``.
            random_order: Concatenate all categories and clean them together
                instead of emitting them category by category.
        """
        self.copyright_rate = copyright_rate
        self.character_rate = character_rate
        self.general_rate = general_rate
        self.artist_rate = artist_rate
        self.normalize = normalize
        self.caption_shuffle = caption_shuffle
        self.transforms = transforms
        self.max_size = max_size
        self.resize = resize
        self.random_order = random_order

    def clean(self, text: str):
        """Clean a whitespace-separated tag string.

        Parenthesized qualifiers (e.g. the ``(series)`` in ``alice_(series)``)
        are removed, leading/trailing underscores are stripped, and duplicate
        or empty tags are dropped. The tags are then optionally shuffled. With
        ``normalize``, underscores become spaces and tags are joined with
        ", "; otherwise they are joined with a single space.
        """
        without_qualifiers = re.sub(r'\([^)]*\)', '', text)
        # split() (no argument) never yields empty tokens for repeated spaces.
        unique_tags = {tag.strip('_') for tag in without_qualifiers.split()}
        # Tags such as "_" become empty after stripping; joining them would
        # produce stray ", , " separators.
        unique_tags.discard('')
        # NOTE(review): set iteration order is arbitrary and can differ between
        # processes (str hashing is randomized); confirm that is intended.
        tags = list(unique_tags)

        if self.caption_shuffle:
            random.shuffle(tags)

        if self.normalize:
            return ', '.join(tag.replace('_', ' ') for tag in tags).strip(', ')
        return ' '.join(tags)

    def get_key(self, val_dict, key, clean_val=True, cond_drop=0.0, prepend_space=False, append_comma=False):
        """Return ``val_dict[key]`` formatted for the caption, or '' if skipped.

        Despite its name, ``cond_drop`` is the probability of KEEPING the
        value: it is used only when ``random.random() < cond_drop``, so the
        default of 0.0 always returns ''. A missing or empty value, or one
        that is empty after cleaning, also returns ''.

        Args:
            val_dict: Parsed caption JSON.
            key: Field to read.
            clean_val: Pass the value through ``clean`` first.
            cond_drop: Keep probability (see above).
            prepend_space: Prefix the result with a space.
            append_comma: Suffix the result with a comma.
        """
        # The RNG is drawn first, exactly once per call, whether or not the key exists.
        if random.random() >= cond_drop:
            return ''
        value = val_dict.get(key)
        if not value:
            return ''
        if clean_val:
            value = self.clean(value)
            if not value:
                # Avoid emitting a dangling ' ,' for a category with no usable tags.
                return ''
        space = ' ' if prepend_space else ''
        comma = ',' if append_comma else ''
        return space + value + comma

    def _build_caption(self, caption_data):
        """Assemble the caption string from the tag categories in *caption_data*."""
        if not self.random_order:
            # Clean each category on its own and emit them in fixed order:
            # character, copyright, artist, general, separated by commas.
            char_part = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
            copy_part = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
            artist_part = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
            general_part = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
            return f'{char_part}{copy_part}{artist_part}{general_part}'.lstrip().rstrip(',')

        # Concatenate the raw categories and clean them as one tag string, so
        # shuffling (when enabled) mixes tags across categories.
        char_part = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False)
        copy_part = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False)
        artist_part = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False)
        general_part = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False)
        return self.clean(f'{char_part}{copy_part}{artist_part}{general_part}').lstrip().rstrip(' ')

    def _load_image(self, raw):
        """Decode image bytes and return a float32 array scaled to [-1, 1]."""
        image = Image.open(io.BytesIO(raw))
        # np.array of an RGBA/L/P image is not 3-channel RGB; convert up front
        # so the output layout does not depend on the source file's mode.
        if image.mode != "RGB":
            image = image.convert("RGB")
        if self.resize:
            image = resize_image(image, max_size=(self.max_size, self.max_size))
        image = self.transforms(image)
        pixels = np.array(image).astype(np.uint8)
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __call__(self, sample):
        """Preprocess *sample* in place and return it.

        ``sample['caption']`` (JSON string) becomes the tag caption, and
        ``sample['image']`` (encoded bytes) becomes a float32 array in [-1, 1].
        """
        sample['caption'] = self._build_caption(json.loads(sample['caption']))
        sample['image'] = self._load_image(sample['image'])
        return sample
|
|
|
|
|
2022-08-30 15:06:47 -06:00
|
|
|
class LocalBase(Dataset):
    """Image/caption dataset laid out as ``<data_root>/img`` + ``<data_root>/txt``.

    Each image ``<data_root>/img/<name>.<ext>`` (png, jpg, jpeg or bmp) is
    paired with the caption file ``<data_root>/txt/<name>.txt``. Indexing
    returns ``{'caption': str, 'image': float32 array in [-1, 1]}``, or
    ``None`` when the image or its caption cannot be read.
    """

    def __init__(self,
                 data_root='./danbooru-aesthetic',
                 size=768,
                 interpolation="bicubic",
                 flip_p=0.5,
                 crop=True,
                 shuffle=False,
                 mode='train',
                 val_split=64,
                 ):
        """Scan *data_root* and build the image/caption index.

        Args:
            data_root: Directory containing the ``img`` and ``txt`` folders.
            size: Output side length (square); ``None`` disables resizing.
            interpolation: ``'linear'``, ``'bilinear'``, ``'bicubic'`` or
                ``'lanczos'``.
            flip_p: Probability of a random horizontal flip.
            crop: Center-crop to a square before resizing.
            shuffle: Stored as ``self.shuffle``; not otherwise used here.
            mode: ``'val'`` keeps only the first ``len // val_split`` files;
                any other value keeps every file (the val subset included).
            val_split: Divisor that sets the validation subset size.

        Raises:
            KeyError: If *interpolation* is not one of the supported names.
        """
        super().__init__()

        self.shuffle = shuffle
        self.crop = crop

        print('Fetching data.')
        self.image_files = []
        for ext in ('png', 'jpg', 'jpeg', 'bmp'):
            self.image_files.extend(glob.glob(f'{data_root}/img/' + '*.' + ext))
        # glob order depends on the filesystem; sort it so the index (and thus
        # the val subset) is reproducible across machines.
        self.image_files.sort()
        if mode == 'val':
            self.image_files = self.image_files[:len(self.image_files) // val_split]

        print('Constructing image-caption map.')
        self.examples = {}
        self.hashes = []
        for path in self.image_files:
            # Drop only the final extension so names containing dots keep
            # their full stem when locating the caption file.
            name = os.path.splitext(os.path.basename(path))[0]
            self.examples[name] = {
                'image': path,
                'text': f'{data_root}/txt/{name}.txt',
            }
            self.hashes.append(name)

        print(f'image-caption map has {len(self.examples.keys())} examples')

        self.size = size
        # Image.LINEAR was only an alias of BILINEAR and was removed in
        # Pillow 10, so "linear" maps to BILINEAR directly.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def random_sample(self):
        """Return the item at a uniformly random index."""
        return self.__getitem__(random.randint(0, self.__len__() - 1))

    def sequential_sample(self, i):
        """Return the item after index *i*, wrapping around to index 0."""
        if i >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(i + 1)

    def skip_sample(self, i):
        """Return the placeholder used for an unreadable item (``None``)."""
        return None

    def get_caption(self, i):
        """Read the caption of item *i*, collapsing every whitespace run to one space."""
        with open(self.examples[self.hashes[i]]['text'], 'r', encoding='utf-8') as f:
            return ' '.join(f.read().split())

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _open_rgb(image_file):
        """Open and fully decode *image_file*, returning an RGB PIL image."""
        image = Image.open(image_file)
        # Image.open is lazy: load() decodes now, so truncated or corrupt files
        # raise inside the caller's try-block instead of during preprocessing.
        image.load()
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def _preprocess(self, image):
        """Center-crop to a square (if enabled), resize to size x size (if set), then random-flip."""
        if self.crop:
            side = min(image.width, image.height)
            left = (image.width - side) // 2
            top = (image.height - side) // 2
            image = image.crop((left, top, left + side, top + side))
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)
        return self.flip(image)

    def __getitem__(self, i):
        """Return ``{'caption', 'image'}`` for item *i*, or ``None`` if it can't be read."""
        # Looked up outside the try so the error message below can always use it
        # (and a bad index raises IndexError instead of being swallowed).
        image_file = self.examples[self.hashes[i]]['image']
        try:
            image = self._open_rgb(image_file)
        except (OSError, ValueError):
            print(f'Error with {image_file} -- skipping {i}')
            return self.skip_sample(i)

        try:
            caption = self.get_caption(i)
        except (OSError, ValueError):
            print(f'Error with caption of {image_file} -- skipping {i}')
            return self.skip_sample(i)

        pixels = np.array(self._preprocess(image)).astype(np.uint8)
        return {
            'caption': caption,
            'image': (pixels / 127.5 - 1.0).astype(np.float32),
        }

    def get_image(self, i):
        """Return the preprocessed PIL image for item *i*, or ``None`` if unreadable."""
        image_file = self.examples[self.hashes[i]]['image']
        try:
            image = self._open_rgb(image_file)
        except Exception:
            # Broad on purpose: any failure to read the image is reported and skipped.
            print(f'Error with {image_file} -- skipping {i}')
            return self.skip_sample(i)
        return self._preprocess(image)
|
2022-09-07 00:33:12 -06:00
|
|
|
|
def _save_val_example(data_root='./danbooru-aesthetic', out_path='example.png'):
    """Save the first validation example to *out_path* for a visual check.

    Debug helper; nothing in this module calls it. It builds a ``mode='val'``
    ``LocalBase`` at 512px without cropping, prints the dataset length and the
    first item's hash and caption, and writes the de-normalized image.
    """
    dataset = LocalBase(data_root, size=512, crop=False, mode='val')
    print(len(dataset))
    example = dataset[0]
    print(dataset.hashes[0])
    if example is None:
        # LocalBase.__getitem__ returns None when the image or caption can't be read.
        return
    print(example['caption'])
    # Invert the `x / 127.5 - 1` scaling. Rounding keeps float error from
    # truncating 255 to 254; clipping guards against out-of-range values.
    image = np.clip(np.rint((example['image'] + 1) * 127.5), 0, 255).astype(np.uint8)
    Image.fromarray(image).save(out_path)
|
|
|
|
|
2022-09-07 00:33:12 -06:00
|
|
|
from tqdm import tqdm
|
2022-09-02 22:00:05 -06:00
|
|
|
if __name__ == "__main__":
    import time

    # Rough benchmark of LocalBase.get_image on the first few images.
    # perf_counter measures wall-clock time, so disk reads are included;
    # process_time would count only this process's CPU time.
    dataset = LocalBase('./danbooru-aesthetic', size=512)
    start = time.perf_counter()
    # min() avoids an IndexError on datasets with fewer than 8 images.
    for i in range(min(8, len(dataset))):
        dataset.get_image(i)
    print('time:', time.perf_counter() - start)
|