Add alternate dataloader
parent 2e69358d50
commit c5f2775beb
@ -0,0 +1,115 @@
model:
  base_learning_rate: 1.5e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming; use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.local.LocalDanbooruBase
      params:
        data_root: "./dataset"
        size: 768
        mode: "train"
    validation:
      target: ldm.data.local.LocalDanbooruBase
      params:
        data_root: "./dataset"
        size: 768
        mode: "val"
        val_split: 64

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          ddim_steps: 50

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
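
This is the style of config the latent-diffusion training entrypoint consumes (typically via something like python main.py --base <this file> -t). A minimal sketch of building the model from it by hand, assuming the standard repo layout and using "config.yaml" as a placeholder path for this file:

# minimal sketch; "config.yaml" is a placeholder path for the YAML above
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("config.yaml")
model = instantiate_from_config(config.model)         # LatentDiffusion built from model.target / model.params
model.learning_rate = config.model.base_learning_rate
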
@ -11,10 +11,99 @@ import random
PIL.Image.MAX_IMAGE_PIXELS = 933120000

import torchvision

import pytorch_lightning as pl

import torch

import re
import json
import io

def resize_image(image: Image, max_size=(768,768)):
    image = ImageOps.contain(image, max_size, Image.LANCZOS)
    # resize to integer multiple of 64
    w, h = image.size
    w, h = map(lambda x: x - x % 64, (w, h))

    ratio = w / h
    src_ratio = image.width / image.height

    src_w = w if ratio > src_ratio else image.width * h // image.height
    src_h = h if ratio <= src_ratio else image.height * w // image.width

    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
    res = Image.new("RGB", (w, h))
    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))

    return res

class CaptionProcessor(object):
    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order):
        self.copyright_rate = copyright_rate
        self.character_rate = character_rate
        self.general_rate = general_rate
        self.artist_rate = artist_rate
        self.normalize = normalize
        self.caption_shuffle = caption_shuffle
        self.transforms = transforms
        self.max_size = max_size
        self.resize = resize
        self.random_order = random_order

    def clean(self, text: str):
        text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
        if self.caption_shuffle:
            text = text.split(' ')
            random.shuffle(text)
            text = ' '.join(text)
        if self.normalize:
            text = ', '.join([i.replace('_', ' ') for i in text.split(' ')]).lstrip(', ').rstrip(', ')
        return text

    def get_key(self, val_dict, key, clean_val = True, cond_drop = 0.0, prepend_space = False, append_comma = False):
        space = ' ' if prepend_space else ''
        comma = ',' if append_comma else ''
        if random.random() < cond_drop:
            if (key in val_dict) and val_dict[key]:
                if clean_val:
                    return space + self.clean(val_dict[key]) + comma
                else:
                    return space + val_dict[key] + comma
        return ''

    def __call__(self, sample):
        # preprocess caption
        caption_data = json.loads(sample['caption'])
        if not self.random_order:
            character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
            copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
            artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
            general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
            tag_str = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
        else:
            character = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False)
            copyright = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False)
            artist = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False)
            general = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False)
            tag_str = self.clean(f'{character}{copyright}{artist}{general}').lstrip().rstrip(' ')
        sample['caption'] = tag_str

        # preprocess image
        image = sample['image']
        image = Image.open(io.BytesIO(image))
        if self.resize:
            image = resize_image(image, max_size=(self.max_size, self.max_size))
        image = self.transforms(image)
        image = np.array(image).astype(np.uint8)
        sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
        return sample

class LocalBase(Dataset):
    def __init__(self,
                 data_root='./danbooru-aesthetic',
                 size=768,  # changed from 512 in this commit
                 interpolation="bicubic",
                 flip_p=0.5,
                 crop=True,
@ -0,0 +1,213 @@
import os
import numpy as np
import PIL
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms

import glob

import random

PIL.Image.MAX_IMAGE_PIXELS = 933120000
import torchvision

import pytorch_lightning as pl

import torch

import re
import json
import io

def resize_image(image: Image, max_size=(768,768)):
    image = ImageOps.contain(image, max_size, Image.LANCZOS)
    # resize to integer multiple of 64
    w, h = image.size
    w, h = map(lambda x: x - x % 64, (w, h))

    ratio = w / h
    src_ratio = image.width / image.height

    src_w = w if ratio > src_ratio else image.width * h // image.height
    src_h = h if ratio <= src_ratio else image.height * w // image.width

    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
    res = Image.new("RGB", (w, h))
    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))

    return res
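
# Quick illustration of resize_image (a hypothetical sketch, kept commented out like the
# harnesses at the bottom of this file): the image is contained within max_size, then
# center-pasted onto a canvas whose sides are rounded down to multiples of 64.
"""
img = Image.new('RGB', (1000, 600))           # placeholder input
out = resize_image(img, max_size=(768, 768))
print(out.size)                               # (768, 448): both sides are multiples of 64
"""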

class CaptionProcessor(object):
    def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms, max_size, resize, random_order):
        self.copyright_rate = copyright_rate
        self.character_rate = character_rate
        self.general_rate = general_rate
        self.artist_rate = artist_rate
        self.normalize = normalize
        self.caption_shuffle = caption_shuffle
        self.transforms = transforms
        self.max_size = max_size
        self.resize = resize
        self.random_order = random_order

    def clean(self, text: str):
        text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
        if self.caption_shuffle:
            text = text.split(' ')
            random.shuffle(text)
            text = ' '.join(text)
        if self.normalize:
            text = ', '.join([i.replace('_', ' ') for i in text.split(' ')]).lstrip(', ').rstrip(', ')
        return text

    def get_key(self, val_dict, key, clean_val = True, cond_drop = 0.0, prepend_space = False, append_comma = False):
        space = ' ' if prepend_space else ''
        comma = ',' if append_comma else ''
        if random.random() < cond_drop:
            if (key in val_dict) and val_dict[key]:
                if clean_val:
                    return space + self.clean(val_dict[key]) + comma
                else:
                    return space + val_dict[key] + comma
        return ''

    def __call__(self, sample):
        # preprocess caption
        caption_data = json.loads(sample['caption'])
        if not self.random_order:
            character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
            copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
            artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
            general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
            tag_str = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
        else:
            character = self.get_key(caption_data, 'tag_string_character', False, self.character_rate, False)
            copyright = self.get_key(caption_data, 'tag_string_copyright', False, self.copyright_rate, True, False)
            artist = self.get_key(caption_data, 'tag_string_artist', False, self.artist_rate, True, False)
            general = self.get_key(caption_data, 'tag_string_general', False, self.general_rate, True, False)
            tag_str = self.clean(f'{character}{copyright}{artist}{general}').lstrip().rstrip(' ')
        sample['caption'] = tag_str

        # preprocess image
        image = sample['image']
        image = Image.open(io.BytesIO(image))
        if self.resize:
            image = resize_image(image, max_size=(self.max_size, self.max_size))
        image = self.transforms(image)
        image = np.array(image).astype(np.uint8)
        sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
        return sample
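
# Note on get_key: a tag group is included only when random.random() < cond_drop, so the
# *_rate arguments behave as inclusion probabilities (the 1.0 rates used below keep every
# group). A hypothetical, commented-out illustration of __call__ on made-up tag data:
"""
buf = io.BytesIO()
Image.new('RGB', (64, 64)).save(buf, format='PNG')
processor = CaptionProcessor(1.0, 1.0, 1.0, 1.0, True, False, lambda x: x, 768, False, False)
sample = {
    'caption': json.dumps({
        'tag_string_character': 'hatsune_miku',
        'tag_string_copyright': 'vocaloid',
        'tag_string_artist': 'some_artist',
        'tag_string_general': 'long_hair twintails smile',
    }),
    'image': buf.getvalue(),
}
out = processor(sample)
print(out['caption'])  # e.g. 'hatsune miku, vocaloid, some artist, long hair, twintails, smile'
# (general-tag order can vary because clean() de-duplicates through a set)
"""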

class LocalDanbooruBase(Dataset):
    def __init__(self,
                 data_root='./danbooru-aesthetic',
                 size=768,
                 interpolation="bicubic",
                 flip_p=0.5,
                 crop=True,
                 shuffle=False,
                 mode='train',
                 val_split=64,
                 ):
        super().__init__()

        self.shuffle=shuffle
        self.crop = crop

        print('Fetching data.')

        ext = ['image']
        self.image_files = []
        [self.image_files.extend(glob.glob(f'{data_root}' + '/*.' + e)) for e in ext]
        if mode == 'val':
            self.image_files = self.image_files[:len(self.image_files)//val_split]

        print(f'Constructing image-caption map. Found {len(self.image_files)} images')

        self.examples = {}
        self.hashes = []
        for i in self.image_files:
            hash = i[len(f'{data_root}/'):].split('.')[0]
            self.examples[hash] = {
                'image': i,
                'text': f'{data_root}/{hash}.caption'
            }
            self.hashes.append(hash)

        print(f'image-caption map has {len(self.examples.keys())} examples')

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        image_transforms = []
        image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(flip_p)])
        image_transforms = torchvision.transforms.Compose(image_transforms)

        self.captionprocessor = CaptionProcessor(1.0, 1.0, 1.0, 1.0, True, True, image_transforms, 768, False, True)

    def random_sample(self):
        return self.__getitem__(random.randint(0, self.__len__() - 1))

    def sequential_sample(self, i):
        if i >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(i + 1)

    def skip_sample(self, i):
        return None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, i):
        return self.get_image(i)

    def get_image(self, i):
        image = {}
        try:
            image_file = self.examples[self.hashes[i]]['image']
            with open(image_file, 'rb') as f:
                image['image'] = f.read()
            text_file = self.examples[self.hashes[i]]['text']
            with open(text_file, 'rb') as f:
                image['caption'] = f.read()
            image = self.captionprocessor(image)
        except Exception as e:
            print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
            return self.skip_sample(i)

        return image
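
# Hypothetical sketch (commented out like the harnesses below) of driving this dataset
# with a plain PyTorch DataLoader, outside the Lightning DataModuleFromConfig path;
# batch_size > 1 assumes the images under data_root are already a uniform size, since
# the CaptionProcessor above is built with resize=False.
"""
from torch.utils.data import DataLoader

dataset = LocalDanbooruBase('./danbooru-aesthetic', size=768, mode='train')
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch['caption'][0], batch['image'].shape)   # images come out HWC float32 in [-1, 1]
"""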

"""
if __name__ == "__main__":
    dataset = LocalBase('./danbooru-aesthetic', size=512, crop=False, mode='val')
    print(dataset.__len__())
    example = dataset.__getitem__(0)
    print(dataset.hashes[0])
    print(example['caption'])
    image = example['image']
    image = ((image + 1) * 127.5).astype(np.uint8)
    image = Image.fromarray(image)
    image.save('example.png')
"""

"""
from tqdm import tqdm
if __name__ == "__main__":
    dataset = LocalDanbooruBase('./links', size=768)
    import time
    a = time.process_time()
    for i in range(8):
        example = dataset.get_image(i)
        image = example['image']
        image = ((image + 1) * 127.5).astype(np.uint8)
        image = Image.fromarray(image)
        image.save(f'example-{i}.png')
        print(example['caption'])
    print('time:', time.process_time()-a)
"""