# wmts-exfiltrator/exfiltrate.py
import argparse
import base64
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import rasterio
from PIL import Image
from rasterio import Affine
from tqdm import tqdm
from pkg.image import random_file_width
from pkg.spatial import deg2num, lonlat_to_meters
# Task function submitted to the download thread pools below.
from pkg.thread import download_tile
def build_parser():
    """Create the command-line parser for the tile download / GeoTIFF script.

    Returns:
        argparse.ArgumentParser: parser exposing ``base_url``, ``--zoom``,
        ``--threads``, ``--referer``, ``--output``, ``--proxy``,
        ``--tiff-threads``, ``--output-tiff`` and ``--bbox``.
    """
    parser = argparse.ArgumentParser(description='Exfiltrate data from WMS servers.')
    parser.add_argument('base_url', help='The base URL for the WMS server. Example: https://wmts.nlsc.gov.tw/wmts/nURBAN/default/EPSG:3857/')
    parser.add_argument('--zoom', type=int, required=True, help='The zoom level to use.')
    parser.add_argument('--threads', type=int, default=10, help='Number of download threads to use.')
    parser.add_argument('--referer', help='The content of the Referer header to send.')
    parser.add_argument('--output', default='wmts-output', help='Output directory path.')
    parser.add_argument('--proxy', action='store_true', help='Enable using a proxy.')
    # type=int is required: the value is handed to ThreadPoolExecutor(max_workers=...),
    # which rejects the raw string argparse would otherwise pass through.
    parser.add_argument('--tiff-threads', type=int, default=None, help='Number of threads to use when building TIFF. Default: auto')
    parser.add_argument('--output-tiff', help='Path for output GeoTIFF. Default: <output>/output-z<zoom>-<bbox>.tiff')
    # Exactly four floats; argparse reports a usage error for anything else.
    parser.add_argument(
        '--bbox', required=True, type=float, nargs=4,
        metavar=('TOP_LEFT_LAT', 'TOP_LEFT_LON', 'BOTTOM_RIGHT_LAT', 'BOTTOM_RIGHT_LON'),
        help='Bounding Box of the area to download. Separate each value with a space. (top left lat, top left lon, bottom right lat, bottom right lon)',
    )
    return parser


def clean_tile_pixels(tile_data):
    """Prepare one decoded tile for insertion into the output raster.

    Args:
        tile_data: ``uint8`` array of shape ``(height, width, channels)`` with
            at least three channels. Only the first three are kept.

    Returns:
        A ``(height, width, 3)`` view of ``tile_data`` (the input is modified
        in place) in which:

        * pure white pixels ``(255, 255, 255)`` are set to ``(0, 0, 0)``, the
          value written as NODATA in the output file;
        * every other pixel that has some, but not all, bands equal to zero
          has those zero bands raised to one, e.g. ``(255, 0, 0)`` becomes
          ``(255, 1, 1)``.
    """
    # Drop the alpha channel (and anything after it).
    rgb = tile_data[:, :, :3]
    # Replace white pixels with NODATA.
    rgb[np.all(rgb == 255, axis=-1)] = 0
    # ArcGIS does not like pixels that have zeros in them, eg. (255, 0, 0), so
    # zero bands of otherwise non-black pixels become one. The value must be an
    # integer: a fractional value such as 0.1 truncates back to 0 in uint8.
    partly_zero = np.any(rgb == 0, axis=-1) & np.any(rgb != 0, axis=-1)
    rgb[partly_zero[..., np.newaxis] & (rgb == 0)] = 1
    return rgb


def main():
    """Download every tile covering the bounding box and stitch them into a GeoTIFF.

    Tiles are fetched with ``download_tile`` into ``<output>/<url hash>/<zoom>``,
    failed tiles are retried once, and all tiles that are available on disk are
    mosaicked into a single three-band raster. Tiles that could not be obtained
    or decoded are left as NODATA (0) in the result.
    """
    parser = build_parser()
    args = parser.parse_args()
    args.base_url = args.base_url.strip('/') + f'/{args.zoom}/'

    # Expand "~" before resolving; once the path is absolute there is no
    # leading "~" left for expanduser() to act on.
    base_output = Path(args.output).expanduser().resolve()
    url_hash = base64.b64encode(args.base_url.encode()).decode('utf-8').strip('==')
    tiles_output = base_output / url_hash / str(args.zoom)
    tiles_output.mkdir(parents=True, exist_ok=True)

    top_left_lat, top_left_lon, bottom_right_lat, bottom_right_lon = args.bbox
    # deg2num's result is unpacked as (column, row).
    min_col, min_row = deg2num(top_left_lat, top_left_lon, args.zoom)
    max_col, max_row = deg2num(bottom_right_lat, bottom_right_lon, args.zoom)
    if max_row < min_row or max_col < min_col:
        parser.error('--bbox must list the top-left corner first and the bottom-right corner second.')

    if args.output_tiff:
        output_tiff = Path(args.output_tiff)
    else:
        output_tiff = base_output / f'output-z{args.zoom}-{top_left_lat}x{top_left_lon}-{bottom_right_lat}x{bottom_right_lon}.tiff'

    r_headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
        'Accept': 'image/avif,image/webp,*/*',
        'Accept-Language': 'en-US,en;q=0.5'
    }
    if args.referer:
        r_headers['Referer'] = args.referer

    tiles = []  # (row, col) of tiles present on disk
    retries = []  # (row, col) of tiles whose first download failed
    failed = []  # (row, col) of tiles that also failed the retry
    total_downloaded = 0

    # First pass: one batch of downloads per tile row, sharing a single pool.
    row_iter = range(min_row, max_row + 1)
    col_iter = range(min_col, max_col + 1)
    row_bar = tqdm(total=len(row_iter), desc=f'Row {min_row}', postfix={'new_files': total_downloaded, 'failures': len(retries)})
    with ThreadPoolExecutor(args.threads) as executor:
        for row in row_iter:
            row_bar.set_description(f'Row {row}')
            col_bar = tqdm(total=len(col_iter), leave=False)
            futures = [executor.submit(download_tile, (row, col, args.base_url, r_headers, tiles_output, args.proxy)) for col in col_iter]
            for future in as_completed(futures):
                result = future.result()
                if result:
                    result_row, result_col, status = result
                    if status == 'success':
                        total_downloaded += 1
                        tiles.append((result_row, result_col))
                    elif status == 'exist':
                        tiles.append((result_row, result_col))
                    elif status == 'failure':
                        retries.append((result_row, result_col))
                row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
                col_bar.update()
                row_bar.refresh()
            col_bar.close()
            row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
            row_bar.update()
    row_bar.close()

    # Second pass: retry every failed tile once.
    retry_bar = tqdm(total=len(retries), desc='Tile Retries')
    with ThreadPoolExecutor(args.threads) as executor:
        futures = [executor.submit(download_tile, (row, col, args.base_url, r_headers, tiles_output, args.proxy)) for row, col in retries]
        for future in as_completed(futures):
            result = future.result()
            if result:
                result_row, result_col, status = result
                if status == 'failure':
                    # Do not queue the tile for stitching: there is no file to read.
                    failed.append((result_row, result_col))
                    retry_bar.write(f'{(result_row, result_col)} failed!')
                else:
                    tiles.append((result_row, result_col))
                    if status == 'success':
                        total_downloaded += 1
            retry_bar.update()
    retry_bar.close()

    print(f'Downloaded {total_downloaded} images.')
    if failed:
        print(f'{len(failed)} tiles could not be downloaded and will be NODATA in the output.')
    if not tiles:
        raise SystemExit('No tiles are available; nothing to build.')

    tile_size = random_file_width(tiles_output)

    # Define the number of rows and columns based on the bounding box
    num_rows = max_row - min_row + 1
    num_cols = max_col - min_col + 1

    # Zero-filled (not np.empty) so that any tile which is missing or fails to
    # load shows up as NODATA instead of uninitialised memory.
    image_data = np.zeros((num_rows * tile_size, num_cols * tile_size, 3), dtype=np.uint8)

    def build_tiff_data(task):
        """Load tile ``(row, col)`` from disk and copy it into ``image_data``."""
        row, col = task
        tile_file = tiles_output / f"{row}_{col}.png"
        if not tile_file.is_file():
            raise Exception(f'Tile does not exist: {tile_file}')
        with Image.open(tile_file) as img:
            # convert('RGB') also covers palette and greyscale PNGs, which decode
            # to 2-D arrays and cannot be sliced per channel.
            tile_data = clean_tile_pixels(np.array(img.convert('RGB')))

        # Calculate the position of the tile in the image data array.
        row_pos = (row - min_row) * tile_size
        col_pos = (col - min_col) * tile_size

        # Insert the tile data into the image data array at the correct spot.
        image_data[row_pos:row_pos + tile_size, col_pos:col_pos + tile_size] = tile_data

    with ThreadPoolExecutor(max_workers=args.tiff_threads) as executor:
        futures = {executor.submit(build_tiff_data, task): task for task in tiles}
        build_bar = tqdm(as_completed(futures), total=len(futures), desc='Building TIFF')
        for future in build_bar:
            # Surface per-tile errors instead of dropping them; the affected
            # tile simply stays NODATA.
            error = future.exception()
            if error is not None:
                build_bar.write(f'{futures[future]} skipped: {error}')

    # Transpose the image data array to the format (bands, rows, cols).
    image_data = np.transpose(image_data, (2, 0, 1))

    # Define the transformation from pixel coordinates to geographic coordinates:
    # the origin is the top-left corner of the bounding box and the scale spreads
    # the image evenly across the bounding box in both directions.
    # NOTE(review): the mosaic covers whole tiles, which normally extend past the
    # requested bounding box, and the file is tagged EPSG:4326 although the
    # example base URL serves EPSG:3857 tiles -- the georeferencing is probably
    # only approximate; confirm against the tile matrix before relying on it.
    transform = (Affine.translation(top_left_lon, top_left_lat)
                 * Affine.scale((bottom_right_lon - top_left_lon) / image_data.shape[2],
                                (bottom_right_lat - top_left_lat) / image_data.shape[1]))

    # Write the image data to a GeoTIFF file
    output_tiff.parent.mkdir(parents=True, exist_ok=True)
    print('Saving to:', output_tiff)
    start = time.time()
    with rasterio.open(output_tiff, "w", driver="GTiff", height=num_rows * tile_size, width=num_cols * tile_size, count=3, dtype=str(image_data.dtype), crs='EPSG:4326', transform=transform, compress="DEFLATE", nodata=0) as dst:
        dst.write(image_data, indexes=[1, 2, 3])
    print(f'Saved in {int(time.time() - start)} seconds.')


if __name__ == '__main__':
    main()