add data download files
This commit is contained in:
parent 41c864dce2
commit 3a84dc1973
@@ -0,0 +1,80 @@
import os
import json
import argparse
import multiprocessing

import requests
import tqdm

# downloads URLs from JSON: the input file maps image URLs to the text that
# should be saved alongside each image

parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, required=False)
parser.add_argument('--out_dir', '-o', type=str, required=False)
parser.add_argument('--threads', '-p', type=int, required=False, default=32)
args = parser.parse_args()


class DownloadManager():
    def __init__(self, max_threads=32):
        self.failed_downloads = []
        self.max_threads = max_threads

    # args = (link, metadata, out_img_dir, out_text_dir)
    # returns None on success or (link, metadata) on failure so the parent
    # process can collect failures (worker processes do not share self.failed_downloads)
    def download(self, args):
        try:
            r = requests.get(args[0], stream=True)
            # write the image bytes
            with open(args[2] + args[0].split('/')[-1], 'wb') as f:
                for chunk in r.iter_content(1024):
                    f.write(chunk)
            # write the matching metadata text file
            with open(args[3] + args[0].split('/')[-1].split('.')[0] + '.txt', 'w') as f:
                f.write(args[1])
            return None
        except Exception:
            return (args[0], args[1])

    def download_urls(self, file_path, out_dir):
        with open(file_path) as f:
            data = json.load(f)

        # normalise out_dir so the 'img/' and 'text/' suffixes resolve correctly
        if not out_dir.endswith('/'):
            out_dir += '/'
        os.makedirs(out_dir + 'img', exist_ok=True)
        os.makedirs(out_dir + 'text', exist_ok=True)

        thread_args = []

        print(f'Loading {file_path} for download on {self.max_threads} threads...')

        # create initial thread_args
        for k, v in tqdm.tqdm(data.items()):
            thread_args.append((k, v, out_dir + 'img/', out_dir + 'text/'))

        # divide thread_args into chunks of max_threads items
        chunks = []
        for i in range(0, len(thread_args), self.max_threads):
            chunks.append(thread_args[i:i + self.max_threads])

        print(f'Downloading {len(thread_args)} images...')

        # download chunks synchronously, one worker pool per chunk
        for chunk in tqdm.tqdm(chunks):
            with multiprocessing.Pool(self.max_threads) as p:
                results = p.map(self.download, chunk)
            self.failed_downloads.extend(r for r in results if r is not None)

        if len(self.failed_downloads) > 0:
            print("Failed downloads:")
            for i in self.failed_downloads:
                print(i[0])
            print("\n")

        """
        # attempt to download any remaining failed downloads
        print('\nAttempting to download any failed downloads...')
        print('Failed downloads:', len(self.failed_downloads))
        if len(self.failed_downloads) > 0:
            for url in tqdm.tqdm(self.failed_downloads):
                self.download((url[0], url[1], out_dir + 'img/', out_dir + 'text/'))
        """


if __name__ == '__main__':
    dm = DownloadManager(max_threads=args.threads)
    dm.download_urls(args.file, args.out_dir)
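A short usage sketch of this downloader, assuming the script is saved as download.py (the commit does not show the filename) and bypassing argparse; the example URL, tags, and paths are invented for illustration only:

# hypothetical usage sketch; the 'download' module name and example data are assumptions
import json
from download import DownloadManager

# the expected input: a single JSON object mapping image URLs to tag/caption strings
with open('links.json', 'w') as f:
    json.dump({'https://example.com/images/abcd1234.jpg': '1girl solo long_hair artist_name'}, f)

dm = DownloadManager(max_threads=8)
dm.download_urls('links.json', 'dataset/')  # writes dataset/img/abcd1234.jpg and dataset/text/abcd1234.txt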
@@ -0,0 +1,50 @@
import json
import argparse

from pybooru import Danbooru
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--danbooru_username', '-user', type=str, required=False)
parser.add_argument('--danbooru_key', '-key', type=str, required=False)
parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
parser.add_argument('--posts', '-p', type=int, required=False, default=10000)
parser.add_argument('--output', '-o', required=False, default='links.json')
args = parser.parse_args()


class DanbooruScraper():
    def __init__(self, username, key):
        self.username = username
        self.key = key
        self.dbclient = Danbooru('danbooru', username=self.username, api_key=self.key)

    # This will get danbooru urls and tags, put them in a dict, then write it as a json file
    def get_urls(self, tags, num_posts, batch_size, file="data_urls.json"):
        url_dict = {}
        if num_posts % batch_size != 0:
            print("Error: num_posts must be divisible by batch_size")
            return
        for i in tqdm(range(num_posts // batch_size)):
            urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i)
            if not urls:
                print(f'Empty results at {i}')
                break
            for j in urls:
                if 'file_url' in j:
                    if j['file_url'] not in url_dict:
                        d_url = j['file_url']
                        d_tags = j['tag_string_copyright'] + " " + j['tag_string_character'] + " " + j['tag_string_general'] + " " + j['tag_string_artist']

                        url_dict[d_url] = d_tags
                else:
                    print("Error: file_url not found")
        with open(file, 'w') as f:
            json.dump(url_dict, f)


# now test
if __name__ == "__main__":
    ds = DanbooruScraper(args.danbooru_username, args.danbooru_key)
    ds.get_urls(args.tags, args.posts, 100, file=args.output)
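Taken together, the two files form a small pipeline: get_urls writes a JSON object mapping each post's file_url to its concatenated tag string, which is exactly the {url: tags} mapping the downloader above consumes. A minimal end-to-end sketch, assuming the scripts are saved as scrape.py and download.py (module names are not shown in this commit) and that valid Danbooru credentials are available:

# hypothetical pipeline sketch; module names, credentials, and paths are assumptions
from scrape import DanbooruScraper
from download import DownloadManager

ds = DanbooruScraper('my_username', 'my_api_key')
ds.get_urls('solo -comic -animated', 1000, 100, file='links.json')  # writes {file_url: tag_string, ...}

dm = DownloadManager(max_threads=32)
dm.download_urls('links.json', 'dataset/')  # images go to dataset/img/, tag files to dataset/text/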