diff --git a/scripts/download_laion.py b/scripts/download_laion.py index a2f75fa..d384484 100644 --- a/scripts/download_laion.py +++ b/scripts/download_laion.py @@ -8,7 +8,7 @@ import argparse import glob #import requests_async as requests import asyncio -from aiohttp import ClientSession +import aiohttp from typing import IO import aiofiles import re @@ -35,6 +35,7 @@ def in_virtualenv(): def get_parser(**parser_kwargs): parser = argparse.ArgumentParser(**parser_kwargs) + do_not_download = True parser.add_argument( "--laion_dir", type=str, @@ -106,27 +107,71 @@ def get_parser(**parser_kwargs): const=True, default=0, help="skips the first n parquet files on disk, useful to resume", + ), + parser.add_argument( + "--verbose", + type=bool, + nargs="?", + const=True, + default=False, + help="additional logging of URL and TEXT prefiltering", + ), + parser.add_argument( + "--test", + action='store_const', + const=do_not_download, + default=not(do_not_download), + help="skips downloading, for checking filters, use with --verbose", ) - + return parser def cleanup_text(file_name: str): - # TODO: can be improved - file_name = re.sub(r'[^\x00-\x7F]+', '', file_name) # remove non-ascii + # TODO: can be improved + file_name = re.sub("
', '').replace("
", "") + file_name = file_name.replace('', '').replace("", "") + file_name = file_name.replace('', '').replace("", "") - # remove forward slash from file_name - file_name = file_name.replace("/", "").replace('&', 'and') + file_name = re.sub(r'[^\x00-\x7F]+', '', file_name) # remove non-ascii - file_name = file_name.replace('\t', '').replace('\n', '').replace('\r', '') + file_name = file_name.replace(' & ', ' and ').replace(' &', ' and').replace('& ', 'and ') \ + .replace(" + ", " and ").replace(" +", " and").replace("+ ", "and ") - file_name = file_name.replace('"', '').replace('\'', '').replace('?', '').replace(':','').replace('|','') \ - .replace('<', '').replace('>', '').replace('/', '').replace('\\', '').replace('*', '') \ - .replace('!', '').replace('@', '').replace('#', '').replace('$', '').replace('%', '') \ - .replace('^', '').replace('(', '').replace(')', '').replace('_', ' ') \ - .replace('\t', '').replace('\n', '').replace('\r', '') + file_name = file_name.replace('\t', ' ').replace('\n', ' ').replace('\r', ' ') + + file_name = file_name.replace('\"t"', ' ') + + file_name = file_name.replace(" ♥ ","love").replace("♥ ","love ").replace(" ♥"," love") \ + .replace("♥"," love ") + + # remove bad chars + file_name = file_name.replace('\"', '').replace('?', '') \ + .replace('<', '').replace('>', '').replace('/', '').replace('*', '') \ + .replace('!', '').replace('#', '').replace('$', '').replace('%', '') \ + .replace('^', '').replace('(', '').replace(')', '') + + # replace with space + file_name = file_name.replace(':',' ').replace('|',' ').replace('@', '') \ + .replace("/", " ").replace("\\'", "\'").replace("\\", " ").replace('\\', ' ') \ + .replace('_', ' ').replace("=", " ") + + # replace foreign chars + file_name = file_name.replace('é', 'e').replace('è', 'e').replace('ê', 'e') \ + .replace('ë', 'e').replace('à', 'a').replace('â', 'a').replace('ä', 'a') \ + .replace('ç', 'c').replace('ù', 'u').replace('û', 'u').replace('ü', 'u') \ + .replace('ô', 'o').replace('ö', 'o').replace('ï', 'i').replace('î', 'i') \ + .replace('í', 'i').replace('ì', 'i').replace('ñ', 'n').replace('ß', 'ss') \ + .replace('á', 'a').replace('ã', 'a').replace('å', 'a').replace('æ', 'ae') \ + .replace('œ', 'oe').replace('ø', 'o').replace('ð', 'd').replace('þ', 'th') \ + .replace('ý', 'y').replace('ÿ', 'y').replace('ž', 'z').replace('ž', 'z') \ + .replace('š', 's').replace('đ', 'd').replace('ď', 'd').replace('č', 'c') \ + .replace('ć', 'c').replace('ř', 'r').replace('ŕ', 'r').replace('ľ', 'l') \ + .replace('ĺ', 'l').replace('ť', 't').replace('ň', 'n').replace('ņ', 'n') \ + .replace('ď', 'd').replace('Ď', 'D').replace('Ť', 'T').replace('Ň', 'N') _MAX_LENGTH = 240 if (len(file_name) > _MAX_LENGTH): @@ -134,38 +179,12 @@ def cleanup_text(file_name: str): return file_name -async def dummyToConsole_dict(dict): - #print("{" + f"\"{dict[0]}\": \"{dict[1]}\"" +"},") - print(dict[1]) - -def get_file_extension(image_url: str): - result = "jpg" - - if '?' in image_url: - try: - image_url = image_url.split('?')[0] - #print(result) - except ValueError: - pass - - result = image_url.split(".")[-1] - - if result in ["asp", "aspx", "ashx", "php", "jpeg"]: - result = "jpg" - - return result - -async def call_http(image_url: str, out_file_name: str, session: ClientSession): +async def call_http(image_url: str, session: aiohttp.ClientSession): #print(f"calling http and save to: {out_file_name}") global downloaded_count global http_timeout global current_parquet_file_downloaded_count try: - if os.path.exists(out_file_name): - print(f"{Fore.YELLOW} already exists: {Fore.LIGHTWHITE_EX}{out_file_name}{Fore.YELLOW}, skipping{Style.RESET_ALL}") - return - - #print(f" attempting to download to: {out_file_name}") res = await session.request(method="GET", url=image_url, timeout=http_timeout) if (res.status == 200): @@ -179,69 +198,100 @@ async def call_http(image_url: str, out_file_name: str, session: ClientSession): pass return None - -async def save_img(image: Image, text: str, out_dir: str): +async def save_img(buffer: io.BytesIO, full_outpath: str): try: - buffer = io.BytesIO(image) - image = Image.open(buffer) - format = image.format.lower() - - if (format == "jpeg"): - format = "jpg" - - out_file_name = f"{out_dir}{text}.{format}" - - async with aiofiles.open(out_file_name, "wb") as f: + async with aiofiles.open(full_outpath, "wb") as f: await f.write(buffer.getbuffer()) - except Exception as e: - print(f"{Fore.YELLOW} *** Possible corrupt image for text: {Fore.LIGHTWHITE_EX}{text}{Style.RESET_ALL}") + print(f"{Fore.RED} *** Unable to write to disk: {Fore.LIGHTWHITE_EX}{full_outpath}{Style.RESET_ALL}") + print(f"{Fore.RED} *** ex: {Fore.LIGHTWHITE_EX}{str(e)}{Style.RESET_ALL}") + pass + +def get_outpath_filename(data: any, full_outpath_noext: str, clean_text: str): + ext = "jpg" + full_outpath = None + buffer = None + try: + buffer = io.BytesIO(data) + image = Image.open(buffer) + ext = image.format.lower() + + if (ext == "jpeg"): + ext = "jpg" + + full_outpath = f"{full_outpath_noext}.{ext}" + except Exception as e: + print(f"{Fore.YELLOW} *** Possible corrupt image for text: {Fore.LIGHTWHITE_EX}{clean_text}{Style.RESET_ALL}") print(f"{Fore.YELLOW} *** ex: {Fore.LIGHTWHITE_EX}{str(e)}{Style.RESET_ALL}") pass - #del image - #del img_bytes + return full_outpath, buffer -async def download_image(image_url: str, text: str, outpath: IO, session: ClientSession): - if outpath[-1] != "/": - outpath += "/" +async def download_image(image_url: str, clean_text: str, full_outpath_noext: IO, session: aiohttp.ClientSession): + http_content = await call_http(image_url=image_url, session=session) - text = cleanup_text(text) - file_extension = get_file_extension(image_url) - out_file_name = f"{outpath}{text}.{file_extension}" + buffer = None - #await dummyToConsole_dict(tuple([image_url, out_file_name])) - img = await call_http(image_url, out_file_name, session) + if (http_content is not None): + full_outpath, buffer = get_outpath_filename(data=http_content, full_outpath_noext=full_outpath_noext, clean_text=clean_text) - if img is not None: + if buffer is not None: global downloaded_count downloaded_count += 1 - await save_img(img, text, outpath) + await save_img(buffer, full_outpath) async def download_set_dict(opt, matches_dict: dict): - async with ClientSession() as session: + async with aiohttp.ClientSession() as session: + global downloaded_count current_parquet_file_downloaded_count = 0 tasks = [] for row in matches_dict: if downloaded_count < opt.limit: current_parquet_file_downloaded_count += 1 - tasks.append( - download_image(image_url=row["URL"], text=row["TEXT"], outpath=opt.out_dir, session=session) - ) + pre_text=row["TEXT"] + image_url=row["URL"] + + clean_text = cleanup_text(pre_text) + + full_outpath_noext = os.path.join(opt.out_dir, clean_text) + + if (opt.verbose): + print(f"{Fore.LIGHTGREEN_EX}***** Verbose log: ***** {Style.RESET_ALL}") + print(f"{Fore.LIGHTGREEN_EX} url: {image_url}{Style.RESET_ALL}") + print(f"{Fore.LIGHTGREEN_EX} text: {pre_text}{Style.RESET_ALL}") + print(f"{Fore.LIGHTGREEN_EX} captn: {clean_text}{Style.RESET_ALL}") + + if any(glob.glob(full_outpath_noext + ".*")): + print(f"{Fore.YELLOW} already exists: {Fore.LIGHTWHITE_EX}{full_outpath_noext}{Fore.YELLOW}, skipping{Style.RESET_ALL}") + return + + if not opt.test: + tasks.append( + download_image(image_url=image_url, clean_text=clean_text, full_outpath_noext=full_outpath_noext, session=session) + ) + else: + current_parquet_file_downloaded_count += 1 + downloaded_count += 1 + if len(tasks) > 63: + await asyncio.gather(*tasks) + tasks = [] else: print(f"{Fore.YELLOW} Limit reached: {opt.limit}, exiting...{Style.RESET_ALL}") break - - await asyncio.gather(*tasks) + if not opt.test & len(tasks) > 0: + await asyncio.gather(*tasks) print(f"{Fore.LIGHTBLUE_EX} Downloaded chunk of {current_parquet_file_downloaded_count} images{Style.RESET_ALL}") def query_parquet(df: pd.DataFrame, opt): # TODO: efficiency, expression tree? matches = df - matches = matches[(matches.HEIGHT > opt.min_hw) & \ - (matches.WIDTH > opt.min_hw) & \ - (matches.punsafe < unsafe_threshhold) & \ - (matches.aesthetic > aesthetic_threshhold)] + matches = matches[(matches.HEIGHT > opt.min_hw) & (matches.WIDTH > opt.min_hw)] + + if 'punsafe' in matches.columns: + matches = matches[(matches.punsafe > unsafe_threshhold)] + + if ('aesthetic' in matches): + matches = matches[(matches.aesthetic > aesthetic_threshhold)] if opt.search_text: for word in opt.search_text.split(","): @@ -261,13 +311,12 @@ def query_parquet(df: pd.DataFrame, opt): async def download_laion_matches(opt): print(f"{Fore.LIGHTBLUE_EX} Searching for {opt.search_text} in column: {opt.column} in {opt.laion_dir}/*.parquet{Style.RESET_ALL}") - for idx, file in enumerate(glob.iglob(f"{opt.laion_dir}/*.parquet")): + for idx, file in enumerate(glob.iglob(os.path.join(opt.laion_dir, "*.parquet"))): if idx < opt.parquet_skip: print(f"{Fore.YELLOW} Skipping file {idx+1}/{opt.parquet_skip}: {file}{Style.RESET_ALL}") continue global downloaded_count - global aesthetic_threshhold if downloaded_count < opt.limit: print(f"{Fore.CYAN} reading file: {file}{Style.RESET_ALL}") @@ -300,6 +349,8 @@ if __name__ == '__main__': parser = get_parser() opt = parser.parse_args() + print(f"Test only mode: {opt.test}") + if(opt.search_text is None and opt.force is False): print(f"{Fore.YELLOW}** No search terms provided, exiting...") print(f"** Use --force to bypass safety to dump entire DB{Style.RESET_ALL}")