import os
import io
import json
import glob
import uuid
import shutil
import tarfile
import argparse
import traceback

import requests
import tqdm
import numpy as np
from concurrent import futures
from PIL import Image, ImageOps

# Downloads the image URLs listed in a JSON file and packs the results into a
# WebDataset-compatible tar of .image/.caption pairs.
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, required=False, default='links.json')
# NOTE: --out_file and --max_size are parsed but not currently used; the output
# tar name is derived from the input file name in download_urls() below.
parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar')
parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296)
parser.add_argument('--threads', '-p', required=False, default=16, type=int)
parser.add_argument('--resize', '-r', required=False, default=512, type=int)
args = parser.parse_args()

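# Resize helper with two modes: center_crop=True takes a square center crop and
# resizes it to max_size; center_crop=False fits the whole image inside
# max_size, snaps the canvas down to a multiple of 64, and pads with black
# where the aspect ratios differ.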
def resize_image(image: Image.Image, max_size=(512, 512), center_crop=True):
    if not center_crop:
        image = ImageOps.contain(image, max_size, Image.LANCZOS)
        # resize to an integer multiple of 64
        w, h = image.size
        w, h = map(lambda x: x - x % 64, (w, h))

        ratio = w / h
        src_ratio = image.width / image.height

        src_w = w if ratio > src_ratio else image.width * h // image.height
        src_h = h if ratio <= src_ratio else image.height * w // image.width

        resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
        res = Image.new("RGB", (w, h))
        res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
    else:
        if image.mode != "RGB":
            image = image.convert("RGB")
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]
        res = Image.fromarray(img)
        res = res.resize(max_size, resample=Image.LANCZOS)

    return res

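# For example, a 640x480 source with center_crop=True becomes a 480x480 center
# crop resized to (512, 512); with center_crop=False it is contained within
# 512x512 and written out at 512x384 (both dimensions multiples of 64).
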
class DownloadManager:
    def __init__(self, max_threads: int = 32):
        # failed_downloads is shared across worker threads; list.append is
        # atomic under CPython's GIL, so no lock is needed.
        self.failed_downloads = []
        self.max_threads = max_threads
        self.uuid = str(uuid.uuid1())

    # args_thread = (post_id, link, caption_data)
    def download(self, args_thread):
        try:
            post_id, link, caption_data = args_thread
            # timeout added so a stalled connection cannot hang a worker thread
            response = requests.get(link, stream=True, timeout=60)
            image = Image.open(response.raw).convert('RGB')
            if args.resize:
                image = resize_image(image, max_size=(args.resize, args.resize))
            image_bytes = io.BytesIO()
            image.save(image_bytes, format='PNG')
            __key__ = '%07d' % int(post_id)
            image = image_bytes.getvalue()
            caption = json.dumps(caption_data)

            with open(f'{self.uuid}/{__key__}.image', 'wb') as f:
                f.write(image)
            with open(f'{self.uuid}/{__key__}.caption', 'w') as f:
                f.write(caption)

        except Exception:
            traceback.print_exc()
            self.failed_downloads.append(args_thread)

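    # End-to-end pipeline: read the JSON mapping of post_id -> metadata, stage
    # every image/caption pair into a working directory named after the input
    # file, then pack that directory into <name>.tar and delete it.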
    def download_urls(self, file_path):
        with open(file_path) as f:
            data = json.load(f)
        thread_args = []

        # name the staging directory (and output tar) after the input file
        self.uuid = os.path.basename(file_path).split('.')[0]

        os.makedirs(self.uuid, exist_ok=True)

        print(f'Loading {file_path} for downloading on {self.max_threads} threads... Writing to dataset {self.uuid}')

        # build the (post_id, link, caption_data) work items
        for k, v in tqdm.tqdm(data.items()):
            thread_args.append((k, v['file_url'], v))

        # split the work items into chunks of max_threads
        chunks = []
        for i in range(0, len(thread_args), self.max_threads):
            chunks.append(thread_args[i:i + self.max_threads])

        print(f'Downloading {len(thread_args)} images...')

        # process chunks sequentially; the downloads within a chunk run in
        # parallel on the thread pool
        for chunk in tqdm.tqdm(chunks):
            with futures.ThreadPoolExecutor(self.max_threads) as p:
                p.map(self.download, chunk)

        if len(self.failed_downloads) > 0:
            print("Failed downloads:")
            for i in self.failed_downloads:
                print(i[0])
            print("\n")

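        # WebDataset groups tar members by the basename up to the first dot,
        # so 0000123.image and 0000123.caption form a single sample.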
        # pack the staged files into a tar archive
        print(f'Writing webdataset to {self.uuid}')
        archive = tarfile.open(f'{self.uuid}.tar', 'w')
        files = glob.glob(f'{self.uuid}/*')
        for f in tqdm.tqdm(files):
            # store each file at the archive root so the keys stay flat
            archive.add(f, os.path.basename(f))

        archive.close()

        print('Cleaning up...')
        shutil.rmtree(self.uuid)

if __name__ == '__main__':
    dm = DownloadManager(max_threads=args.threads)
    dm.download_urls(args.file)
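
# Usage sketch (assuming this file is saved as download_links.py, a hypothetical
# name, and links.json maps post_id -> metadata containing a 'file_url' field):
#   python download_links.py -f links.json -p 16 -r 512
# which stages files into ./links/ and produces links.tar.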