# wmts-exfiltrator/exfiltrate.py
#
# 185 lines
# 8.4 KiB
# Python

import argparse
import base64
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from queue import Queue
import numpy as np
import rasterio
from PIL import Image
from rasterio import Affine
from tqdm import tqdm
from pkg.image import random_file_width
from pkg.spatial import deg2num
from pkg.thread import download_tile
if __name__ == '__main__':
    # Script entry point (first half): parse the command line, download every tile that covers the requested
    # bounding box at the requested zoom level (retrying failed tiles once), and compute the raster geometry
    # (tile size, row/column counts, affine transform) that the GeoTIFF-building code further down relies on.
    parser = argparse.ArgumentParser(description='Exfiltrate data from WMTS servers.')
    parser.add_argument('base_url', help='The base URL for the WMTS server. Example: https://wmts.nlsc.gov.tw/wmts/nURBAN/default/EPSG:3857/')
    parser.add_argument('--zoom', type=int, required=True, help='The zoom level to use.')
    parser.add_argument('--threads', type=int, default=10, help='Number of download threads to use.')
    parser.add_argument('--referer', help='The content of the Referer header to send.')
    parser.add_argument('--output', default='wmts-output', help='Output directory path.')
    parser.add_argument('--proxy', action='store_true', help='Enable using a proxy.')
    parser.add_argument('--tiff-threads', default=10, type=int, help='Number of threads to use when building TIFF. Default: 10')
    parser.add_argument('--output-tiff', help='Path for output GeoTIFF. Default: a file named after the zoom level and bounding box inside the output directory.')
    # Exactly four floats are required; argparse now reports a wrong count or a non-numeric value itself
    # instead of the script failing later while unpacking/converting.
    parser.add_argument('--bbox', required=True, type=float, nargs=4, metavar=('TL_LAT', 'TL_LON', 'BR_LAT', 'BR_LON'), help='Bounding Box of the area to download. Separate each value with a space. (top left lat, top left lon, bottom right lat, bottom right lon)')
    # parser.add_argument('--extent', default=None, help='Specify an extent to break the output image to. This is the diagonal.')
    parser.add_argument('--no-download', action='store_true', help="Don't do any downloading or image checking.")
    args = parser.parse_args()

    # Normalise the URL so it ends in ".../<zoom>/". Only trailing slashes are removed.
    args.base_url = args.base_url.rstrip('/') + f'/{args.zoom}/'

    # Expand "~" BEFORE resolving: once the path has been made absolute a leading "~" can no longer be expanded.
    base_output = Path(args.output).expanduser().resolve()

    # Tiles are cached per base URL in a directory named after the base64 of that URL (padding removed).
    # NOTE(review): standard base64 may contain '/', which yields nested directories; kept unchanged so
    # existing tile caches are still found.
    url_hash = base64.b64encode(args.base_url.encode()).decode('utf-8').strip('=')
    tiles_output = base_output / url_hash / str(args.zoom)

    top_left_lat, top_left_lon, bottom_right_lat, bottom_right_lon = args.bbox
    # deg2num is unpacked as (column, row) for each corner of the bounding box.
    min_col, min_row = deg2num(top_left_lat, top_left_lon, args.zoom)
    max_col, max_row = deg2num(bottom_right_lat, bottom_right_lon, args.zoom)
    if max_col < min_col or max_row < min_row:
        # An inverted box would give empty tile ranges and a negative raster size further down.
        parser.error('--bbox must be given as top-left corner followed by bottom-right corner.')

    tiles_output.mkdir(parents=True, exist_ok=True)

    if args.output_tiff:
        output_tiff = Path(args.output_tiff).expanduser()
        # A user-supplied path may point into a directory that does not exist yet.
        output_tiff.parent.mkdir(parents=True, exist_ok=True)
    else:
        output_tiff = base_output / f'output-z{args.zoom}-{top_left_lat}x{top_left_lon}-{bottom_right_lat}x{bottom_right_lon}.tiff'

    r_headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
        'Accept': 'image/avif,image/webp,*/*',
        'Accept-Language': 'en-US,en;q=0.5'
    }
    if args.referer:
        r_headers['Referer'] = args.referer

    tiles = []    # (row, col) of every tile that is available on disk for the mosaic
    retries = []  # (row, col) of every tile whose first download attempt failed
    total_downloaded = 0

    row_iter = range(min_row, max_row + 1)
    col_iter = range(min_col, max_col + 1)
    row_bar = tqdm(total=len(row_iter), desc=f'Row {min_row}', postfix={'new_files': total_downloaded, 'failures': len(retries)})
    # One pool is shared by all rows; each row is still fully drained before the next one is submitted.
    with ThreadPoolExecutor(args.threads) as executor:
        for row in row_iter:
            row_bar.set_description(f'Row {row}')
            col_bar = tqdm(total=len(col_iter), leave=False)
            futures = [executor.submit(download_tile, (row, col, args.base_url, r_headers, tiles_output, args.proxy, args.no_download)) for col in col_iter]
            for future in as_completed(futures):
                # download_tile yields either a falsy value or a (row, col, status) triple.
                result = future.result()
                if result:
                    result_row, result_col, new_image = result
                    if new_image == 'success':
                        total_downloaded += 1
                        tiles.append((result_row, result_col))
                    elif new_image == 'exist':
                        tiles.append((result_row, result_col))
                    elif new_image == 'failure':
                        retries.append((result_row, result_col))
                    row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
                col_bar.update()
                row_bar.refresh()
            col_bar.close()
            row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
            row_bar.update()
    row_bar.close()

    # Second (and last) attempt for every tile that failed above.
    col_bar = tqdm(total=len(retries), desc='Tile Retries')
    with ThreadPoolExecutor(args.threads) as executor:
        # The task tuple has the same seven fields as in the main download loop above.
        futures = [executor.submit(download_tile, (row, col, args.base_url, r_headers, tiles_output, args.proxy, args.no_download)) for row, col in retries]
        for future in as_completed(futures):
            result = future.result()
            if result:
                result_row, result_col, new_image = result
                if new_image == 'failure':
                    # Do not queue the tile for the mosaic: there is no file on disk to read for it.
                    col_bar.write(f'{(result_row, result_col)} failed!')
                else:
                    tiles.append((result_row, result_col))
                    if new_image == 'success':
                        total_downloaded += 1
            col_bar.update()
    col_bar.close()

    print(f'Downloaded {total_downloaded} images.')

    print('Preparing data...')
    # NOTE(review): presumably the pixel width of a sample tile from the cache directory; tiles are then
    # treated as square and uniformly sized — confirm against pkg.image.
    tile_size = random_file_width(tiles_output)

    # Define the number of rows and columns based on the bounding box
    num_rows = max_row - min_row + 1
    num_cols = max_col - min_col + 1

    # Define the transformation from pixel coordinates to geographic coordinates, which is an Affine transformation that
    # maps pixel coordinates in the image to geographic coordinates on the Earth's surface.
    # NOTE(review): the requested bbox corners are stretched over the full tile mosaic; if the tiles extend past
    # the bbox (or are not in a lat/lon projection) the georeferencing will be approximate — TODO confirm.
    transform = (Affine.translation(top_left_lon, top_left_lat)
                 * Affine.scale((bottom_right_lon - top_left_lon) / (num_cols * tile_size),
                                (bottom_right_lat - top_left_lat) / (num_rows * tile_size)))
def worker(pbar):
    """Consume ``(row, col)`` jobs from the shared queue ``q`` and write each tile into the open GeoTIFF ``dst``.

    Runs until a job whose row is ``None`` is received. Relies on the module-level names ``q``, ``lock``,
    ``dst``, ``tiles_output``, ``min_row``, ``min_col`` and ``tile_size`` being set before the thread starts.

    A tile that is missing or cannot be processed is reported through ``pbar.write`` and skipped, so one bad
    tile neither kills the thread nor leaves the queue's unfinished-task count stuck (which would make a
    ``q.join()`` in the main thread block forever).

    :param pbar: tqdm progress bar, advanced by one for every tile job handled (written or skipped).
    """
    while True:
        row, col = q.get()
        if row is None:
            # Shutdown sentinel.
            break
        try:
            tile_file = tiles_output / f"{row}_{col}.png"
            if not tile_file.is_file():
                raise FileNotFoundError(f'Tile does not exist: {tile_file}')

            with Image.open(tile_file) as img:
                # Force three 8-bit bands: drops an alpha channel and also expands palette/grayscale images,
                # which would otherwise load as 2-D arrays.
                tile_data = np.array(img.convert('RGB'), dtype=np.uint8)

            # Replace white pixels with NODATA
            tile_data[np.all(tile_data == [255, 255, 255], axis=-1)] = 0

            # ArcGIS does not like pixels that have zeros in them, eg. (255, 0, 0). We need to convert the zeros to ones, eg. (255, 1, 1).
            has_zero = tile_data == 0
            # Pixels with at least one zero band but not all bands zero (all-zero pixels stay NODATA).
            mixed = has_zero.any(axis=-1) & ~has_zero.all(axis=-1)
            tile_data[has_zero & mixed[:, :, np.newaxis]] = 1

            # Calculate the position of the tile in the image data array.
            row_pos = (row - min_row) * tile_size
            col_pos = (col - min_col) * tile_size

            # Reorder (height, width, band) to (band, height, width) for the band-indexed write below.
            tile_data = np.transpose(tile_data, (2, 0, 1))

            # Write the tile data to the GeoTIFF file; the lock serialises access to the shared dataset.
            with lock:
                dst.write(tile_data, window=rasterio.windows.Window(col_pos, row_pos, tile_size, tile_size), indexes=[1, 2, 3])
        except Exception as exc:
            # Thread boundary: report and carry on so the remaining tiles are still written.
            pbar.write(f'Skipping tile {(row, col)}: {exc}')
        finally:
            q.task_done()
            pbar.update()
if __name__ == '__main__':
    # Script entry point (second half): mosaic the downloaded tiles into one GeoTIFF using a pool of
    # worker threads fed through a queue. Kept under the main guard like the rest of the script body.
    q = Queue()
    lock = threading.Lock()
    # At least one worker is needed, otherwise nothing would ever drain the queue.
    tiff_worker_count = max(1, args.tiff_threads)
    with rasterio.open(output_tiff, "w", driver="GTiff", height=num_rows * tile_size, width=num_cols * tile_size, count=3, dtype='uint8', crs='EPSG:4326', transform=transform, compress="DEFLATE", nodata=0) as dst:
        with tqdm(total=len(tiles), desc='Building GeoTIFF') as pbar:
            threads = [threading.Thread(target=worker, args=(pbar,)) for _ in range(tiff_worker_count)]
            for t in threads:
                t.start()

            for tile in tiles:
                q.put(tile)

            # One stop sentinel per worker, queued behind all the tiles. Waiting on the threads themselves
            # (rather than on q.join()) cannot deadlock if a worker exits early because of an error.
            for _ in threads:
                q.put((None, None))
            for t in threads:
                t.join()