dataloader overhaul
This commit is contained in:
parent
3546ae9b47
commit
01b9440d50
|
@ -10,6 +10,7 @@ example.png
|
|||
scores.json
|
||||
danbooru-aesthetic
|
||||
logs
|
||||
*.tar
|
||||
|
||||
# =========================================================================== #
|
||||
# Python-related
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
model:
|
||||
base_learning_rate: 1.5e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 512
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
|
||||
data:
|
||||
target: ldm.data.localdanbooru.DanbooruWebDataModuleFromConfig
|
||||
params:
|
||||
tar_base: "links.tar"
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
size: 512
|
||||
flip_p: 0.5
|
||||
image_key: "image"
|
||||
copyright_rate: 0.9
|
||||
character_rate: 0.9
|
||||
general_rate: 0.9
|
||||
artist_rate: 0.9
|
||||
normalize: true
|
||||
caption_shuffle: true
|
||||
|
||||
|
||||
lightning:
|
||||
modelcheckpoint:
|
||||
params:
|
||||
every_n_train_steps: 500
|
||||
callbacks:
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 500
|
||||
max_images: 4
|
||||
increase_log_steps: False
|
||||
log_first_step: False
|
||||
log_images_kwargs:
|
||||
use_ema_scope: False
|
||||
inpaint: False
|
||||
plot_progressive_rows: False
|
||||
plot_diffusion_rows: False
|
||||
N: 4
|
||||
ddim_steps: 50
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
||||
val_check_interval: 5000000
|
||||
num_sanity_val_steps: 0
|
||||
accumulate_grad_batches: 1
|
|
@ -3,50 +3,69 @@ import json
|
|||
import requests
|
||||
import multiprocessing
|
||||
import tqdm
|
||||
import webdataset
|
||||
from concurrent import futures
|
||||
import io
|
||||
import tarfile
|
||||
import glob
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# downloads URLs from JSON
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--file', '-f', type=str, required=False)
|
||||
parser.add_argument('--out_dir', '-o', type=str, required=False)
|
||||
parser.add_argument('--threads', '-p', required=False, default=32)
|
||||
parser.add_argument('--file', '-f', type=str, required=False, default='links.json')
|
||||
parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
|
||||
parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
|
||||
parser.add_argument('--threads', '-p', required=False, default=16)
|
||||
args = parser.parse_args()
|
||||
|
||||
class DownloadManager():
|
||||
def __init__(self, max_threads=32):
|
||||
def __init__(self, max_threads: int = 32):
|
||||
self.failed_downloads = []
|
||||
self.max_threads = max_threads
|
||||
self.uuid = str(uuid.uuid1())
|
||||
|
||||
# args = (link, metadata, out_img_dir, out_text_dir)
|
||||
# args = (post_id, link, caption_data)
|
||||
def download(self, args):
|
||||
try:
|
||||
r = requests.get(args[0], stream=True)
|
||||
with open(args[2] + args[0].split('/')[-1], 'wb') as f:
|
||||
for chunk in r.iter_content(1024):
|
||||
f.write(chunk)
|
||||
with open(args[3] + args[0].split('/')[-1].split('.')[0] + '.txt', 'w') as f:
|
||||
f.write(args[1])
|
||||
except:
|
||||
self.failed_downloads.append((args[0], args[1]))
|
||||
image = Image.open(requests.get(args[1], stream=True).raw).convert('RGB')
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, format='PNG')
|
||||
__key__ = '%07d' % int(args[0])
|
||||
image = image_bytes.getvalue()
|
||||
caption = str(json.dumps(args[2]))
|
||||
|
||||
with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
|
||||
f.write(image)
|
||||
with open(f'{self.uuid}/{__key__}.caption', 'w') as f:
|
||||
f.write(caption)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
self.failed_downloads.append((args[0], args[1], args[2]))
|
||||
|
||||
def download_urls(self, file_path, out_dir):
|
||||
def download_urls(self, file_path):
|
||||
with open(file_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
os.makedirs(out_dir + '/img')
|
||||
os.makedirs(out_dir + '/text')
|
||||
|
||||
thread_args = []
|
||||
|
||||
print(f'Loading {file_path} for download on {self.max_threads} threads...')
|
||||
delimiter = '\\' if os.name == 'nt' else '/'
|
||||
|
||||
self.uuid = (file_path.split(delimiter)[-1]).split('.')[0]
|
||||
|
||||
if not os.path.exists(f'./{self.uuid}'):
|
||||
os.mkdir(f'{self.uuid}')
|
||||
|
||||
print(f'Loading {file_path} for downloading on {self.max_threads} threads... Writing to dataset {self.uuid}')
|
||||
|
||||
# create initial thread_args
|
||||
for k, v in tqdm.tqdm(data.items()):
|
||||
thread_args.append((k, v, out_dir + 'img/', out_dir + 'text/'))
|
||||
thread_args.append((k, v['file_url'], v))
|
||||
|
||||
# divide thread_args into chunks divisible by max_threads
|
||||
chunks = []
|
||||
|
@ -57,7 +76,7 @@ class DownloadManager():
|
|||
|
||||
# download chunks synchronously
|
||||
for chunk in tqdm.tqdm(chunks):
|
||||
with multiprocessing.Pool(self.max_threads) as p:
|
||||
with futures.ThreadPoolExecutor(args.threads) as p:
|
||||
p.map(self.download, chunk)
|
||||
|
||||
if len(self.failed_downloads) > 0:
|
||||
|
@ -65,16 +84,19 @@ class DownloadManager():
|
|||
for i in self.failed_downloads:
|
||||
print(i[0])
|
||||
print("\n")
|
||||
"""
|
||||
# attempt to download any remaining failed downloads
|
||||
print('\nAttempting to download any failed downloads...')
|
||||
print('Failed downloads:', len(self.failed_downloads))
|
||||
if len(self.failed_downloads) > 0:
|
||||
for url in tqdm.tqdm(self.failed_downloads):
|
||||
self.download((url[0], url[1], out_dir + 'img/', out_dir + 'text/'))
|
||||
"""
|
||||
|
||||
|
||||
# put things into tar
|
||||
print(f'Writing webdataset to {self.uuid}')
|
||||
archive = tarfile.open(f'{self.uuid}.tar', 'w')
|
||||
files = glob.glob(f'{self.uuid}/*')
|
||||
for f in tqdm.tqdm(files):
|
||||
archive.add(f, f.split(delimiter)[-1])
|
||||
|
||||
archive.close()
|
||||
|
||||
print('Cleaning up...')
|
||||
shutil.rmtree(self.uuid)
|
||||
|
||||
if __name__ == '__main__':
|
||||
dm = DownloadManager(max_threads=args.threads)
|
||||
dm.download_urls(args.file, args.out_dir)
|
||||
dm.download_urls(args.file)
|
|
@ -11,10 +11,31 @@ parser = argparse.ArgumentParser()
|
|||
parser.add_argument('--danbooru_username', '-user', type=str, required=False)
|
||||
parser.add_argument('--danbooru_key', '-key', type=str, required=False)
|
||||
parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
|
||||
parser.add_argument('--posts', '-p', required=False, default=10000)
|
||||
parser.add_argument('--posts', '-p', required=False, type=int, default=10000)
|
||||
parser.add_argument('--output', '-o', required=False, default='links.json')
|
||||
args = parser.parse_args()
|
||||
|
||||
import re
|
||||
|
||||
def clean(text: str):
|
||||
text = re.sub(r'\([^)]*\)', '', text)
|
||||
text = text.split(' ')
|
||||
new_text = []
|
||||
for i in text:
|
||||
new_text.append(i.lstrip('_').rstrip('_'))
|
||||
text = set(new_text)
|
||||
text = ' '.join(text)
|
||||
text = text.lstrip().rstrip()
|
||||
return text
|
||||
|
||||
def set_val(val_dict, new_dict, key, clean_val = True):
|
||||
if (key in val_dict) and val_dict[key]:
|
||||
if clean_val:
|
||||
new_dict[key] = clean(val_dict[key])
|
||||
else:
|
||||
new_dict[key] = val_dict[key]
|
||||
return new_dict
|
||||
|
||||
class DanbooruScraper():
|
||||
def __init__(self, username, key):
|
||||
self.username = username
|
||||
|
@ -35,10 +56,19 @@ class DanbooruScraper():
|
|||
for j in urls:
|
||||
if 'file_url' in j:
|
||||
if j['file_url'] not in dict:
|
||||
d_url = j['file_url']
|
||||
d_tags = j['tag_string_copyright'] + " " + j['tag_string_character'] + " " + j['tag_string_general'] + " " + j['tag_string_artist']
|
||||
|
||||
dict[d_url] = d_tags
|
||||
d_tags = {}
|
||||
if ('tag_string_copyright' in j) and j['tag_string_copyright']:
|
||||
d_tags = set_val(j, d_tags, 'tag_string_copyright')
|
||||
if ('tag_string_artist' in j) and j['tag_string_artist']:
|
||||
d_tags = set_val(j, d_tags, 'tag_string_artist')
|
||||
if ('tag_string_character' in j) and j['tag_string_character']:
|
||||
d_tags = set_val(j, d_tags, 'tag_string_character')
|
||||
if ('tag_string_general' in j) and j['tag_string_general']:
|
||||
d_tags = set_val(j, d_tags, 'tag_string_general')
|
||||
if ('tag_string_meta' in j) and j['tag_string_meta']:
|
||||
d_tags = set_val(j, d_tags, 'tag_string_meta')
|
||||
d_tags['file_url'] = j['file_url']
|
||||
dict[j['id']] = d_tags
|
||||
else:
|
||||
print("Error: file_url not found")
|
||||
with open(file, 'w') as f:
|
||||
|
@ -47,4 +77,4 @@ class DanbooruScraper():
|
|||
# now test
|
||||
if __name__ == "__main__":
|
||||
ds = DanbooruScraper(args.danbooru_username, args.danbooru_key)
|
||||
ds.get_urls(args.tags, args.posts, 100, file=args.output)
|
||||
ds.get_urls(args.tags, args.posts, 100, file=args.output)
|
||||
|
|
|
@ -21,6 +21,7 @@ class LocalBase(Dataset):
|
|||
shuffle=False,
|
||||
mode='train',
|
||||
val_split=64,
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -153,16 +154,11 @@ if __name__ == "__main__":
|
|||
image.save('example.png')
|
||||
"""
|
||||
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
if __name__ == "__main__":
|
||||
dataset = LocalBase('../glide-finetune/touhou-portrait-aesthetic', size=512)
|
||||
for i in tqdm(range(dataset.__len__())):
|
||||
image = dataset.get_image(i)
|
||||
if image == None:
|
||||
continue
|
||||
image.save(f'./danbooru-aesthetic/img/{dataset.hashes[i]}.png')
|
||||
with open(f'./danbooru-aesthetic/txt/{dataset.hashes[i]}.txt', 'w') as f:
|
||||
f.write(dataset.get_caption(i))
|
||||
|
||||
"""
|
||||
dataset = LocalBase('./danbooru-aesthetic', size=512)
|
||||
import time
|
||||
a = time.process_time()
|
||||
for i in range(8):
|
||||
dataset.get_image(i)
|
||||
print('time:', time.process_time()-a)
|
|
@ -0,0 +1,184 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import random
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 933120000
|
||||
|
||||
import webdataset as wds
|
||||
import torchvision
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
import torch
|
||||
|
||||
import re
|
||||
import json
|
||||
import io
|
||||
|
||||
class CaptionProcessor(object):
|
||||
def __init__(self, copyright_rate, character_rate, general_rate, artist_rate, normalize, caption_shuffle, transforms):
|
||||
self.copyright_rate = copyright_rate
|
||||
self.character_rate = character_rate
|
||||
self.general_rate = general_rate
|
||||
self.artist_rate = artist_rate
|
||||
self.normalize = normalize
|
||||
self.caption_shuffle = caption_shuffle
|
||||
self.transforms = transforms
|
||||
|
||||
def clean(self, text: str):
|
||||
text = ' '.join(set([i.lstrip('_').rstrip('_') for i in re.sub(r'\([^)]*\)', '', text).split(' ')])).lstrip().rstrip()
|
||||
if self.caption_shuffle:
|
||||
text = text.split(' ')
|
||||
random.shuffle(text)
|
||||
text = ' '.join(text)
|
||||
if self.normalize:
|
||||
text = ', '.join([i.replace('_', ' ') for i in text.split(' ')]).lstrip(', ').rstrip(', ')
|
||||
return text
|
||||
|
||||
def get_key(self, val_dict, key, clean_val = True, cond_drop = 0.0, prepend_space = False, append_comma = False):
|
||||
space = ' ' if prepend_space else ''
|
||||
comma = ',' if append_comma else ''
|
||||
if random.random() < cond_drop:
|
||||
if (key in val_dict) and val_dict[key]:
|
||||
if clean_val:
|
||||
return space + self.clean(val_dict[key]) + comma
|
||||
else:
|
||||
return space + val_dict[key] + comma
|
||||
return ''
|
||||
|
||||
def __call__(self, sample):
|
||||
# preprocess caption
|
||||
caption_data = json.loads(sample['caption'])
|
||||
character = self.get_key(caption_data, 'tag_string_character', True, self.character_rate, False, True)
|
||||
copyright = self.get_key(caption_data, 'tag_string_copyright', True, self.copyright_rate, True, True)
|
||||
artist = self.get_key(caption_data, 'tag_string_artist', True, self.artist_rate, True, True)
|
||||
general = self.get_key(caption_data, 'tag_string_general', True, self.general_rate, True, False)
|
||||
sample['caption'] = f'{character}{copyright}{artist}{general}'.lstrip().rstrip(',')
|
||||
|
||||
# preprocess image
|
||||
image = sample['image']
|
||||
image = self.transforms(Image.open(io.BytesIO(image)))
|
||||
image = np.array(image).astype(np.uint8)
|
||||
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
return sample
|
||||
|
||||
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
|
||||
"""Take a list of samples (as dictionary) and create a batch, preserving the keys.
|
||||
If `tensors` is True, `ndarray` objects are combined into
|
||||
tensor batches.
|
||||
:param dict samples: list of samples
|
||||
:param bool tensors: whether to turn lists of ndarrays into a single ndarray
|
||||
:returns: single sample consisting of a batch
|
||||
:rtype: dict
|
||||
"""
|
||||
keys = set.intersection(*[set(sample.keys()) for sample in samples])
|
||||
batched = {key: [] for key in keys}
|
||||
|
||||
for s in samples:
|
||||
[batched[key].append(s[key]) for key in batched]
|
||||
|
||||
result = {}
|
||||
for key in batched:
|
||||
if isinstance(batched[key][0], (int, float)):
|
||||
if combine_scalars:
|
||||
result[key] = np.array(list(batched[key]))
|
||||
elif isinstance(batched[key][0], torch.Tensor):
|
||||
if combine_tensors:
|
||||
result[key] = torch.stack(list(batched[key]))
|
||||
elif isinstance(batched[key][0], np.ndarray):
|
||||
if combine_tensors:
|
||||
result[key] = np.array(list(batched[key]))
|
||||
else:
|
||||
result[key] = list(batched[key])
|
||||
return result
|
||||
|
||||
|
||||
class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(self, tar_base, batch_size, train=None, validation=None,
|
||||
test=None, num_workers=4, size=512, flip_p=0.5, image_key='image', copyright_rate=0.9, character_rate=0.9, general_rate=0.9, artist_rate=0.9, normalize=True, caption_shuffle=True,
|
||||
**kwargs):
|
||||
super().__init__(self)
|
||||
print(f'Setting tar base to {tar_base}')
|
||||
self.tar_base = tar_base
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.train = train
|
||||
self.validation = validation
|
||||
self.test = test
|
||||
self.size = size
|
||||
self.flip_p = flip_p
|
||||
self.image_key = image_key
|
||||
self.copyright_rate = copyright_rate
|
||||
self.character_rate = character_rate
|
||||
self.general_rate = general_rate
|
||||
self.artist_rate = artist_rate
|
||||
self.normalize = normalize
|
||||
self.caption_shuffle = caption_shuffle
|
||||
|
||||
def make_loader(self, dataset_config, train=True):
|
||||
image_transforms = []
|
||||
image_transforms.extend([torchvision.transforms.CenterCrop(self.size),
|
||||
torchvision.transforms.Resize(self.size),
|
||||
torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
||||
image_transforms = torchvision.transforms.Compose(image_transforms)
|
||||
|
||||
transform_dict = {}
|
||||
transform_dict.update({self.image_key: image_transforms})
|
||||
|
||||
postprocess = CaptionProcessor(copyright_rate=self.copyright_rate, character_rate=self.character_rate, general_rate=self.general_rate, artist_rate=self.artist_rate, normalize=self.normalize, caption_shuffle=self.caption_shuffle, transforms=image_transforms)
|
||||
|
||||
|
||||
tars = os.path.join(self.tar_base)
|
||||
|
||||
dset = wds.WebDataset(
|
||||
tars,
|
||||
handler=wds.warn_and_continue).repeat().shuffle(1.0)
|
||||
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
|
||||
dset = (dset
|
||||
.select(self.filter_keys)
|
||||
)
|
||||
if postprocess is not None:
|
||||
dset = dset.map(postprocess)
|
||||
dset = (dset
|
||||
.batched(self.batch_size, partial=False,
|
||||
collation_fn=dict_collation_fn)
|
||||
)
|
||||
|
||||
loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
|
||||
num_workers=self.num_workers)
|
||||
|
||||
return loader
|
||||
|
||||
def filter_keys(self, x):
|
||||
return True
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.make_loader(self.train)
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.make_loader(self.validation, train=False)
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.make_loader(self.test, train=False)
|
||||
|
||||
def example():
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
|
||||
|
||||
config = OmegaConf.load("configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml")
|
||||
datamod = DanbooruWebDataModuleFromConfig(**config["data"]["params"])
|
||||
dataloader = datamod.train_dataloader()
|
||||
|
||||
for batch in dataloader:
|
||||
print(batch["image"].shape)
|
||||
print(batch['caption'])
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
#example()
|
||||
pass
|
|
@ -17,3 +17,4 @@ gradio
|
|||
git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
|
||||
git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
|
||||
webdataset
|
Loading…
Reference in New Issue