dataloader overhaul

This commit is contained in:
harubaru 2022-09-20 22:00:32 -07:00
parent 3546ae9b47
commit 01b9440d50
7 changed files with 396 additions and 50 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ example.png
scores.json
danbooru-aesthetic
logs
*.tar
# =========================================================================== #
# Python-related

View File

@ -0,0 +1,112 @@
model:
base_learning_rate: 1.5e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: ldm.data.localdanbooru.DanbooruWebDataModuleFromConfig
params:
tar_base: "links.tar"
batch_size: 1
num_workers: 1
size: 512
flip_p: 0.5
image_key: "image"
copyright_rate: 0.9
character_rate: 0.9
general_rate: 0.9
artist_rate: 0.9
normalize: true
caption_shuffle: true
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
ddim_steps: 50
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View File

@ -3,50 +3,69 @@ import json
import requests
import multiprocessing
import tqdm
import webdataset
from concurrent import futures
import io
import tarfile
import glob
import uuid
from PIL import Image
# downloads URLs from JSON
import argparse
import shutil
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, required=False)
parser.add_argument('--out_dir', '-o', type=str, required=False)
parser.add_argument('--threads', '-p', required=False, default=32)
parser.add_argument('--file', '-f', type=str, required=False, default='links.json')
parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
parser.add_argument('--threads', '-p', required=False, default=16)
args = parser.parse_args()
class DownloadManager():
def __init__(self, max_threads=32):
def __init__(self, max_threads: int = 32):
self.failed_downloads = []
self.max_threads = max_threads
self.uuid = str(uuid.uuid1())
# args = (link, metadata, out_img_dir, out_text_dir)
# args = (post_id, link, caption_data)
def download(self, args):
try:
r = requests.get(args[0], stream=True)
with open(args[2] + args[0].split('/')[-1], 'wb') as f:
for chunk in r.iter_content(1024):
f.write(chunk)
with open(args[3] + args[0].split('/')[-1].split('.')[0] + '.txt', 'w') as f:
f.write(args[1])
except:
self.failed_downloads.append((args[0], args[1]))
image = Image.open(requests.get(args[1], stream=True).raw).convert('RGB')
image_bytes = io.BytesIO()
image.save(image_bytes, format='PNG')
__key__ = '%07d' % int(args[0])
image = image_bytes.getvalue()
caption = str(json.dumps(args[2]))
with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
f.write(image)
with open(f'{self.uuid}/{__key__}.caption', 'w') as f:
f.write(caption)
except Exception as e:
print(e)
self.failed_downloads.append((args[0], args[1], args[2]))
def download_urls(self, file_path, out_dir):
def download_urls(self, file_path):
with open(file_path) as f:
data = json.load(f)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
os.makedirs(out_dir + '/img')
os.makedirs(out_dir + '/text')
thread_args = []
print(f'Loading {file_path} for download on {self.max_threads} threads...')
delimiter = '\\' if os.name == 'nt' else '/'
self.uuid = (file_path.split(delimiter)[-1]).split('.')[0]
if not os.path.exists(f'./{self.uuid}'):
os.mkdir(f'{self.uuid}')
print(f'Loading {file_path} for downloading on {self.max_threads} threads... Writing to dataset {self.uuid}')
# create initial thread_args
for k, v in tqdm.tqdm(data.items()):
thread_args.append((k, v, out_dir + 'img/', out_dir + 'text/'))
thread_args.append((k, v['file_url'], v))
# divide thread_args into chunks divisible by max_threads
chunks = []
@ -57,7 +76,7 @@ class DownloadManager():
# download chunks synchronously
for chunk in tqdm.tqdm(chunks):
with multiprocessing.Pool(self.max_threads) as p:
with futures.ThreadPoolExecutor(args.threads) as p:
p.map(self.download, chunk)
if len(self.failed_downloads) > 0:
@ -65,16 +84,19 @@ class DownloadManager():
for i in self.failed_downloads:
print(i[0])
print("\n")
"""
# attempt to download any remaining failed downloads
print('\nAttempting to download any failed downloads...')
print('Failed downloads:', len(self.failed_downloads))
if len(self.failed_downloads) > 0:
for url in tqdm.tqdm(self.failed_downloads):
self.download((url[0], url[1], out_dir + 'img/', out_dir + 'text/'))
"""
# put things into tar
print(f'Writing webdataset to {self.uuid}')
archive = tarfile.open(f'{self.uuid}.tar', 'w')
files = glob.glob(f'{self.uuid}/*')
for f in tqdm.tqdm(files):
archive.add(f, f.split(delimiter)[-1])
archive.close()
print('Cleaning up...')
shutil.rmtree(self.uuid)
if __name__ == '__main__':
dm = DownloadManager(max_threads=args.threads)
dm.download_urls(args.file, args.out_dir)
dm.download_urls(args.file)

View File

@ -11,10 +11,31 @@ parser = argparse.ArgumentParser()
parser.add_argument('--danbooru_username', '-user', type=str, required=False)
parser.add_argument('--danbooru_key', '-key', type=str, required=False)
parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
parser.add_argument('--posts', '-p', required=False, default=10000)
parser.add_argument('--posts', '-p', required=False, type=int, default=10000)
parser.add_argument('--output', '-o', required=False, default='links.json')
args = parser.parse_args()
import re
def clean(text: str):
text = re.sub(r'\([^)]*\)', '', text)
text = text.split(' ')
new_text = []
for i in text:
new_text.append(i.lstrip('_').rstrip('_'))
text = set(new_text)
text = ' '.join(text)
text = text.lstrip().rstrip()
return text
def set_val(val_dict, new_dict, key, clean_val = True):
if (key in val_dict) and val_dict[key]:
if clean_val:
new_dict[key] = clean(val_dict[key])
else:
new_dict[key] = val_dict[key]
return new_dict
class DanbooruScraper():
def __init__(self, username, key):
self.username = username
@ -35,10 +56,19 @@ class DanbooruScraper():
for j in urls:
if 'file_url' in j:
if j['file_url'] not in dict:
d_url = j['file_url']
d_tags = j['tag_string_copyright'] + " " + j['tag_string_character'] + " " + j['tag_string_general'] + " " + j['tag_string_artist']
dict[d_url] = d_tags
d_tags = {}
if ('tag_string_copyright' in j) and j['tag_string_copyright']:
d_tags = set_val(j, d_tags, 'tag_string_copyright')
if ('tag_string_artist' in j) and j['tag_string_artist']:
d_tags = set_val(j, d_tags, 'tag_string_artist')
if ('tag_string_character' in j) and j['tag_string_character']:
d_tags = set_val(j, d_tags, 'tag_string_character')
if ('tag_string_general' in j) and j['tag_string_general']:
d_tags = set_val(j, d_tags, 'tag_string_general')
if ('tag_string_meta' in j) and j['tag_string_meta']:
d_tags = set_val(j, d_tags, 'tag_string_meta')
d_tags['file_url'] = j['file_url']
dict[j['id']] = d_tags
else:
print("Error: file_url not found")
with open(file, 'w') as f:
@ -47,4 +77,4 @@ class DanbooruScraper():
# now test
if __name__ == "__main__":
ds = DanbooruScraper(args.danbooru_username, args.danbooru_key)
ds.get_urls(args.tags, args.posts, 100, file=args.output)
ds.get_urls(args.tags, args.posts, 100, file=args.output)

View File

@ -21,6 +21,7 @@ class LocalBase(Dataset):
shuffle=False,
mode='train',
val_split=64,
):
super().__init__()
@ -153,16 +154,11 @@ if __name__ == "__main__":
image.save('example.png')
"""
"""
from tqdm import tqdm
if __name__ == "__main__":
dataset = LocalBase('../glide-finetune/touhou-portrait-aesthetic', size=512)
for i in tqdm(range(dataset.__len__())):
image = dataset.get_image(i)
if image == None:
continue
image.save(f'./danbooru-aesthetic/img/{dataset.hashes[i]}.png')
with open(f'./danbooru-aesthetic/txt/{dataset.hashes[i]}.txt', 'w') as f:
f.write(dataset.get_caption(i))
"""
dataset = LocalBase('./danbooru-aesthetic', size=512)
import time
a = time.process_time()
for i in range(8):
dataset.get_image(i)
print('time:', time.process_time()-a)

184
ldm/data/localdanbooru.py Normal file
View File

@ -0,0 +1,184 @@
import os
import numpy as np
import PIL
from PIL import Image
import random
PIL.Image.MAX_IMAGE_PIXELS = 933120000
import webdataset as wds
import torchvision
import pytorch_lightning as pl
import torch
import re
import json
import io
class CaptionProcessor(object):
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms):
self.copyright_rate = copyright_rate
self.character_rate = character_rate
self.general_rate = general_rate
self.artist_rate = artist_rate
self.normalize = normalize
self.caption_shuffle = caption_shuffle
self.transforms = transforms
def clean(self, text: str):
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
if self.caption_shuffle:
text = text.split(' ')
random.shuffle(text)
text = ' '.join(text)
if self.normalize:
text = ', '.join([i.replace('_', ' ') for i in text.split(' ')]).lstrip(', ').rstrip(', ')
return text
def get_key(self, val_dict, key, clean_val = True, cond_drop = 0.0, prepend_space = False, append_comma = False):
space = ' ' if prepend_space else ''
comma = ',' if append_comma else ''
if random.random() < cond_drop:
if (key in val_dict) and val_dict[key]:
if clean_val:
return space + self.clean(val_dict[key]) + comma
else:
return space + val_dict[key] + comma
return ''
def __call__(self, sample):
# preprocess caption
caption_data = json.loads(sample['caption'])
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
sample['caption'] = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
# preprocess image
image = sample['image']
image = self.transforms(Image.open(io.BytesIO(image)))
image = np.array(image).astype(np.uint8)
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
return sample
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
"""Take a list of samples (as dictionary) and create a batch, preserving the keys.
If `tensors` is True, `ndarray` objects are combined into
tensor batches.
:param dict samples: list of samples
:param bool tensors: whether to turn lists of ndarrays into a single ndarray
:returns: single sample consisting of a batch
:rtype: dict
"""
keys = set.intersection(*[set(sample.keys()) for sample in samples])
batched = {key: [] for key in keys}
for s in samples:
[batched[key].append(s[key]) for key in batched]
result = {}
for key in batched:
if isinstance(batched[key][0], (int, float)):
if combine_scalars:
result[key] = np.array(list(batched[key]))
elif isinstance(batched[key][0], torch.Tensor):
if combine_tensors:
result[key] = torch.stack(list(batched[key]))
elif isinstance(batched[key][0], np.ndarray):
if combine_tensors:
result[key] = np.array(list(batched[key]))
else:
result[key] = list(batched[key])
return result
class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
def __init__(self, tar_base, batch_size, train=None, validation=None,
test=None, num_workers=4, size=512, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True,
**kwargs):
super().__init__(self)
print(f'Setting tar base to {tar_base}')
self.tar_base = tar_base
self.batch_size = batch_size
self.num_workers = num_workers
self.train = train
self.validation = validation
self.test = test
self.size = size
self.flip_p = flip_p
self.image_key = image_key
self.copyright_rate = copyright_rate
self.character_rate = character_rate
self.general_rate = general_rate
self.artist_rate = artist_rate
self.normalize = normalize
self.caption_shuffle = caption_shuffle
def make_loader(self, dataset_config, train=True):
image_transforms = []
image_transforms.extend([torchvision.transforms.CenterCrop(self.size),
torchvision.transforms.Resize(self.size),
torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
image_transforms = torchvision.transforms.Compose(image_transforms)
transform_dict = {}
transform_dict.update({self.image_key: image_transforms})
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms)
tars = os.path.join(self.tar_base)
dset = wds.WebDataset(
tars,
handler=wds.warn_and_continue).repeat().shuffle(1.0)
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
dset = (dset
.select(self.filter_keys)
)
if postprocess is not None:
dset = dset.map(postprocess)
dset = (dset
.batched(self.batch_size, partial=False,
collation_fn=dict_collation_fn)
)
loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
num_workers=self.num_workers)
return loader
def filter_keys(self, x):
return True
def train_dataloader(self):
return self.make_loader(self.train)
def val_dataloader(self):
return self.make_loader(self.validation, train=False)
def test_dataloader(self):
return self.make_loader(self.test, train=False)
def example():
from omegaconf import OmegaConf
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
config = OmegaConf.load("configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml")
datamod = DanbooruWebDataModuleFromConfig(**config["data"]["params"])
dataloader = datamod.train_dataloader()
for batch in dataloader:
print(batch["image"].shape)
print(batch['caption'])
break
if __name__ == '__main__':
#example()
pass

View File

@ -17,3 +17,4 @@ gradio
git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
webdataset