From d0ec5be718e48b0072c2d28cfbf39c1bb8630763 Mon Sep 17 00:00:00 2001 From: harubaru Date: Thu, 22 Sep 2022 19:45:36 -0700 Subject: [PATCH] add centercropping --- danbooru_data/download.py | 42 +++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/danbooru_data/download.py b/danbooru_data/download.py index af4b1d6..1abce8e 100644 --- a/danbooru_data/download.py +++ b/danbooru_data/download.py @@ -22,25 +22,37 @@ parser = argparse.ArgumentParser() parser.add_argument('--file', '-f', type=str, required=False, default='links.json') parser.add_argument('--out_file', '-o', type=str, required=False, default='dataset-%06d.tar') parser.add_argument('--max_size', '-m', type=int, required=False, default=4294967296) -parser.add_argument('--threads', '-p', required=False, default=16) -parser.add_argument('--resize', '-r', required=False, default=768) +parser.add_argument('--threads', '-p', required=False, default=16, type=int) +parser.add_argument('--resize', '-r', required=False, default=768, type=int) args = parser.parse_args() -def resize_image(image: Image, max_size=(768,768)): - image = ImageOps.contain(image, max_size, Image.LANCZOS) - # resize to integer multiple of 64 - w, h = image.size - w, h = map(lambda x: x - x % 64, (w, h)) +def resize_image(image: Image, max_size=(512,512), center_crop=True): + if not center_crop: + image = ImageOps.contain(image, max_size, Image.LANCZOS) + # resize to integer multiple of 64 + w, h = image.size + w, h = map(lambda x: x - x % 64, (w, h)) - ratio = w / h - src_ratio = image.width / image.height + ratio = w / h + src_ratio = image.width / image.height - src_w = w if ratio > src_ratio else image.width * h // image.height - src_h = h if ratio <= src_ratio else image.height * w // image.width + src_w = w if ratio > src_ratio else image.width * h // image.height + src_h = h if ratio <= src_ratio else image.height * w // image.width - resized = image.resize((src_w, src_h), resample=Image.LANCZOS) - res = Image.new("RGB", (w, h)) - res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2)) + resized = image.resize((src_w, src_h), resample=Image.LANCZOS) + res = Image.new("RGB", (w, h)) + res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2)) + else: + if not image.mode == "RGB": + image = image.convert("RGB") + if center_crop: + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + res = Image.fromarray(img) + res = res.resize(max_size, resample=Image.LANCZOS) return res @@ -122,4 +134,4 @@ class DownloadManager(): if __name__ == '__main__': dm = DownloadManager(max_threads=args.threads) - dm.download_urls(args.file) \ No newline at end of file + dm.download_urls(args.file)