fixing up filenames, bug fixes, concurrency limit
This commit is contained in:
parent
25278241c9
commit
62ddb83042
|
@ -8,7 +8,7 @@ import argparse
|
|||
import glob
|
||||
#import requests_async as requests
|
||||
import asyncio
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from typing import IO
|
||||
import aiofiles
|
||||
import re
|
||||
|
@ -35,6 +35,7 @@ def in_virtualenv():
|
|||
|
||||
def get_parser(**parser_kwargs):
|
||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||
do_not_download = True
|
||||
parser.add_argument(
|
||||
"--laion_dir",
|
||||
type=str,
|
||||
|
@ -106,27 +107,71 @@ def get_parser(**parser_kwargs):
|
|||
const=True,
|
||||
default=0,
|
||||
help="skips the first n parquet files on disk, useful to resume",
|
||||
),
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="additional logging of URL and TEXT prefiltering",
|
||||
),
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action='store_const',
|
||||
const=do_not_download,
|
||||
default=not(do_not_download),
|
||||
help="skips downloading, for checking filters, use with --verbose",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def cleanup_text(file_name: str):
|
||||
# TODO: can be improved
|
||||
file_name = re.sub(r'[^\x00-\x7F]+', '', file_name) # remove non-ascii
|
||||
|
||||
file_name = re.sub("<div.*<\/div>", "", file_name)
|
||||
file_name = re.sub("<span.*<\/span>", "", file_name)
|
||||
file_name = file_name.split("/")[-1]
|
||||
file_name = re.sub("<a.*<\/a>", "", file_name)
|
||||
file_name = file_name.replace('<p>', '').replace("</p>", "")
|
||||
file_name = file_name.replace('<strong>', '').replace("</strong>", "")
|
||||
file_name = file_name.replace('<em>', '').replace("</em>", "")
|
||||
|
||||
# remove forward slash from file_name
|
||||
file_name = file_name.replace("/", "").replace('&', 'and')
|
||||
file_name = re.sub(r'[^\x00-\x7F]+', '', file_name) # remove non-ascii
|
||||
|
||||
file_name = file_name.replace('\t', '').replace('\n', '').replace('\r', '')
|
||||
file_name = file_name.replace(' & ', ' and ').replace(' &', ' and').replace('& ', 'and ') \
|
||||
.replace(" + ", " and ").replace(" +", " and").replace("+ ", "and ")
|
||||
|
||||
file_name = file_name.replace('"', '').replace('\'', '').replace('?', '').replace(':','').replace('|','') \
|
||||
.replace('<', '').replace('>', '').replace('/', '').replace('\\', '').replace('*', '') \
|
||||
.replace('!', '').replace('@', '').replace('#', '').replace('$', '').replace('%', '') \
|
||||
.replace('^', '').replace('(', '').replace(')', '').replace('_', ' ') \
|
||||
.replace('\t', '').replace('\n', '').replace('\r', '')
|
||||
file_name = file_name.replace('\t', ' ').replace('\n', ' ').replace('\r', ' ')
|
||||
|
||||
file_name = file_name.replace('\"t"', ' ')
|
||||
|
||||
file_name = file_name.replace(" ♥ ","love").replace("♥ ","love ").replace(" ♥"," love") \
|
||||
.replace("♥"," love ")
|
||||
|
||||
# remove bad chars
|
||||
file_name = file_name.replace('\"', '').replace('?', '') \
|
||||
.replace('<', '').replace('>', '').replace('/', '').replace('*', '') \
|
||||
.replace('!', '').replace('#', '').replace('$', '').replace('%', '') \
|
||||
.replace('^', '').replace('(', '').replace(')', '')
|
||||
|
||||
# replace with space
|
||||
file_name = file_name.replace(':',' ').replace('|',' ').replace('@', '') \
|
||||
.replace("/", " ").replace("\\'", "\'").replace("\\", " ").replace('\\', ' ') \
|
||||
.replace('_', ' ').replace("=", " ")
|
||||
|
||||
# replace foreign chars
|
||||
file_name = file_name.replace('é', 'e').replace('è', 'e').replace('ê', 'e') \
|
||||
.replace('ë', 'e').replace('à', 'a').replace('â', 'a').replace('ä', 'a') \
|
||||
.replace('ç', 'c').replace('ù', 'u').replace('û', 'u').replace('ü', 'u') \
|
||||
.replace('ô', 'o').replace('ö', 'o').replace('ï', 'i').replace('î', 'i') \
|
||||
.replace('í', 'i').replace('ì', 'i').replace('ñ', 'n').replace('ß', 'ss') \
|
||||
.replace('á', 'a').replace('ã', 'a').replace('å', 'a').replace('æ', 'ae') \
|
||||
.replace('œ', 'oe').replace('ø', 'o').replace('ð', 'd').replace('þ', 'th') \
|
||||
.replace('ý', 'y').replace('ÿ', 'y').replace('ž', 'z').replace('ž', 'z') \
|
||||
.replace('š', 's').replace('đ', 'd').replace('ď', 'd').replace('č', 'c') \
|
||||
.replace('ć', 'c').replace('ř', 'r').replace('ŕ', 'r').replace('ľ', 'l') \
|
||||
.replace('ĺ', 'l').replace('ť', 't').replace('ň', 'n').replace('ņ', 'n') \
|
||||
.replace('ď', 'd').replace('Ď', 'D').replace('Ť', 'T').replace('Ň', 'N')
|
||||
|
||||
_MAX_LENGTH = 240
|
||||
if (len(file_name) > _MAX_LENGTH):
|
||||
|
@ -134,38 +179,12 @@ def cleanup_text(file_name: str):
|
|||
|
||||
return file_name
|
||||
|
||||
async def dummyToConsole_dict(dict):
|
||||
#print("{" + f"\"{dict[0]}\": \"{dict[1]}\"" +"},")
|
||||
print(dict[1])
|
||||
|
||||
def get_file_extension(image_url: str):
|
||||
result = "jpg"
|
||||
|
||||
if '?' in image_url:
|
||||
try:
|
||||
image_url = image_url.split('?')[0]
|
||||
#print(result)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
result = image_url.split(".")[-1]
|
||||
|
||||
if result in ["asp", "aspx", "ashx", "php", "jpeg"]:
|
||||
result = "jpg"
|
||||
|
||||
return result
|
||||
|
||||
async def call_http(image_url: str, out_file_name: str, session: ClientSession):
|
||||
async def call_http(image_url: str, session: aiohttp.ClientSession):
|
||||
#print(f"calling http and save to: {out_file_name}")
|
||||
global downloaded_count
|
||||
global http_timeout
|
||||
global current_parquet_file_downloaded_count
|
||||
try:
|
||||
if os.path.exists(out_file_name):
|
||||
print(f"{Fore.YELLOW} already exists: {Fore.LIGHTWHITE_EX}{out_file_name}{Fore.YELLOW}, skipping{Style.RESET_ALL}")
|
||||
return
|
||||
|
||||
#print(f" attempting to download to: {out_file_name}")
|
||||
res = await session.request(method="GET", url=image_url, timeout=http_timeout)
|
||||
|
||||
if (res.status == 200):
|
||||
|
@ -179,58 +198,86 @@ async def call_http(image_url: str, out_file_name: str, session: ClientSession):
|
|||
pass
|
||||
return None
|
||||
|
||||
|
||||
async def save_img(image: Image, text: str, out_dir: str):
|
||||
async def save_img(buffer: io.BytesIO, full_outpath: str):
|
||||
try:
|
||||
buffer = io.BytesIO(image)
|
||||
image = Image.open(buffer)
|
||||
format = image.format.lower()
|
||||
|
||||
if (format == "jpeg"):
|
||||
format = "jpg"
|
||||
|
||||
out_file_name = f"{out_dir}{text}.{format}"
|
||||
|
||||
async with aiofiles.open(out_file_name, "wb") as f:
|
||||
async with aiofiles.open(full_outpath, "wb") as f:
|
||||
await f.write(buffer.getbuffer())
|
||||
|
||||
except Exception as e:
|
||||
print(f"{Fore.YELLOW} *** Possible corrupt image for text: {Fore.LIGHTWHITE_EX}{text}{Style.RESET_ALL}")
|
||||
print(f"{Fore.RED} *** Unable to write to disk: {Fore.LIGHTWHITE_EX}{full_outpath}{Style.RESET_ALL}")
|
||||
print(f"{Fore.RED} *** ex: {Fore.LIGHTWHITE_EX}{str(e)}{Style.RESET_ALL}")
|
||||
pass
|
||||
|
||||
def get_outpath_filename(data: any, full_outpath_noext: str, clean_text: str):
|
||||
ext = "jpg"
|
||||
full_outpath = None
|
||||
buffer = None
|
||||
try:
|
||||
buffer = io.BytesIO(data)
|
||||
image = Image.open(buffer)
|
||||
ext = image.format.lower()
|
||||
|
||||
if (ext == "jpeg"):
|
||||
ext = "jpg"
|
||||
|
||||
full_outpath = f"{full_outpath_noext}.{ext}"
|
||||
except Exception as e:
|
||||
print(f"{Fore.YELLOW} *** Possible corrupt image for text: {Fore.LIGHTWHITE_EX}{clean_text}{Style.RESET_ALL}")
|
||||
print(f"{Fore.YELLOW} *** ex: {Fore.LIGHTWHITE_EX}{str(e)}{Style.RESET_ALL}")
|
||||
pass
|
||||
#del image
|
||||
#del img_bytes
|
||||
return full_outpath, buffer
|
||||
|
||||
async def download_image(image_url: str, text: str, outpath: IO, session: ClientSession):
|
||||
if outpath[-1] != "/":
|
||||
outpath += "/"
|
||||
async def download_image(image_url: str, clean_text: str, full_outpath_noext: IO, session: aiohttp.ClientSession):
|
||||
http_content = await call_http(image_url=image_url, session=session)
|
||||
|
||||
text = cleanup_text(text)
|
||||
file_extension = get_file_extension(image_url)
|
||||
out_file_name = f"{outpath}{text}.{file_extension}"
|
||||
buffer = None
|
||||
|
||||
#await dummyToConsole_dict(tuple([image_url, out_file_name]))
|
||||
img = await call_http(image_url, out_file_name, session)
|
||||
if (http_content is not None):
|
||||
full_outpath, buffer = get_outpath_filename(data=http_content, full_outpath_noext=full_outpath_noext, clean_text=clean_text)
|
||||
|
||||
if img is not None:
|
||||
if buffer is not None:
|
||||
global downloaded_count
|
||||
downloaded_count += 1
|
||||
await save_img(img, text, outpath)
|
||||
await save_img(buffer, full_outpath)
|
||||
|
||||
async def download_set_dict(opt, matches_dict: dict):
|
||||
async with ClientSession() as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
global downloaded_count
|
||||
current_parquet_file_downloaded_count = 0
|
||||
tasks = []
|
||||
for row in matches_dict:
|
||||
if downloaded_count < opt.limit:
|
||||
current_parquet_file_downloaded_count += 1
|
||||
pre_text=row["TEXT"]
|
||||
image_url=row["URL"]
|
||||
|
||||
clean_text = cleanup_text(pre_text)
|
||||
|
||||
full_outpath_noext = os.path.join(opt.out_dir, clean_text)
|
||||
|
||||
if (opt.verbose):
|
||||
print(f"{Fore.LIGHTGREEN_EX}***** Verbose log: ***** {Style.RESET_ALL}")
|
||||
print(f"{Fore.LIGHTGREEN_EX} url: {image_url}{Style.RESET_ALL}")
|
||||
print(f"{Fore.LIGHTGREEN_EX} text: {pre_text}{Style.RESET_ALL}")
|
||||
print(f"{Fore.LIGHTGREEN_EX} captn: {clean_text}{Style.RESET_ALL}")
|
||||
|
||||
if any(glob.glob(full_outpath_noext + ".*")):
|
||||
print(f"{Fore.YELLOW} already exists: {Fore.LIGHTWHITE_EX}{full_outpath_noext}{Fore.YELLOW}, skipping{Style.RESET_ALL}")
|
||||
return
|
||||
|
||||
if not opt.test:
|
||||
tasks.append(
|
||||
download_image(image_url=row["URL"], text=row["TEXT"], outpath=opt.out_dir, session=session)
|
||||
download_image(image_url=image_url, clean_text=clean_text, full_outpath_noext=full_outpath_noext, session=session)
|
||||
)
|
||||
else:
|
||||
current_parquet_file_downloaded_count += 1
|
||||
downloaded_count += 1
|
||||
if len(tasks) > 63:
|
||||
await asyncio.gather(*tasks)
|
||||
tasks = []
|
||||
else:
|
||||
print(f"{Fore.YELLOW} Limit reached: {opt.limit}, exiting...{Style.RESET_ALL}")
|
||||
break
|
||||
|
||||
if not opt.test & len(tasks) > 0:
|
||||
await asyncio.gather(*tasks)
|
||||
print(f"{Fore.LIGHTBLUE_EX} Downloaded chunk of {current_parquet_file_downloaded_count} images{Style.RESET_ALL}")
|
||||
|
||||
|
@ -238,10 +285,13 @@ def query_parquet(df: pd.DataFrame, opt):
|
|||
# TODO: efficiency, expression tree?
|
||||
matches = df
|
||||
|
||||
matches = matches[(matches.HEIGHT > opt.min_hw) & \
|
||||
(matches.WIDTH > opt.min_hw) & \
|
||||
(matches.punsafe < unsafe_threshhold) & \
|
||||
(matches.aesthetic > aesthetic_threshhold)]
|
||||
matches = matches[(matches.HEIGHT > opt.min_hw) & (matches.WIDTH > opt.min_hw)]
|
||||
|
||||
if 'punsafe' in matches.columns:
|
||||
matches = matches[(matches.punsafe > unsafe_threshhold)]
|
||||
|
||||
if ('aesthetic' in matches):
|
||||
matches = matches[(matches.aesthetic > aesthetic_threshhold)]
|
||||
|
||||
if opt.search_text:
|
||||
for word in opt.search_text.split(","):
|
||||
|
@ -261,13 +311,12 @@ def query_parquet(df: pd.DataFrame, opt):
|
|||
async def download_laion_matches(opt):
|
||||
print(f"{Fore.LIGHTBLUE_EX} Searching for {opt.search_text} in column: {opt.column} in {opt.laion_dir}/*.parquet{Style.RESET_ALL}")
|
||||
|
||||
for idx, file in enumerate(glob.iglob(f"{opt.laion_dir}/*.parquet")):
|
||||
for idx, file in enumerate(glob.iglob(os.path.join(opt.laion_dir, "*.parquet"))):
|
||||
if idx < opt.parquet_skip:
|
||||
print(f"{Fore.YELLOW} Skipping file {idx+1}/{opt.parquet_skip}: {file}{Style.RESET_ALL}")
|
||||
continue
|
||||
|
||||
global downloaded_count
|
||||
global aesthetic_threshhold
|
||||
if downloaded_count < opt.limit:
|
||||
print(f"{Fore.CYAN} reading file: {file}{Style.RESET_ALL}")
|
||||
|
||||
|
@ -300,6 +349,8 @@ if __name__ == '__main__':
|
|||
parser = get_parser()
|
||||
opt = parser.parse_args()
|
||||
|
||||
print(f"Test only mode: {opt.test}")
|
||||
|
||||
if(opt.search_text is None and opt.force is False):
|
||||
print(f"{Fore.YELLOW}** No search terms provided, exiting...")
|
||||
print(f"** Use --force to bypass safety to dump entire DB{Style.RESET_ALL}")
|
||||
|
|
Loading…
Reference in New Issue