add tiff threading

This commit is contained in:
Cyberes 2023-11-06 18:28:44 -07:00
parent 664eb1a52f
commit 65953c9bde
2 changed files with 68 additions and 25 deletions

View File

@ -1,7 +1,9 @@
import argparse
import base64
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from queue import Queue
import numpy as np
import rasterio
@ -21,10 +23,11 @@ if __name__ == '__main__':
parser.add_argument('--referer', help='The content of the Referer header to send.')
parser.add_argument('--output', default='wmts-output', help='Output directory path.')
parser.add_argument('--proxy', action='store_true', help='Enable using a proxy.')
parser.add_argument('--tiff-threads', default=None, help='Number of threads to use when building TIFF. Default: auto')
parser.add_argument('--tiff-threads', default=10, type=int, help='Number of threads to use when building TIFF. Default: auto')
parser.add_argument('--output-tiff', help='Path for output GeoTIFF. Default: wmts-output/output.tiff')
parser.add_argument('--bbox', required=True, type=str, metavar='Bounding Box', nargs='+', default=(None, None, None, None), help='Bounding Box of the area to download. Separate each value with a space. (top left lat, top left lon, bottom right lat, bottom right lon)')
parser.add_argument('--extent', default=None, help='Specify an extent to break the output image to. This is the diagonal ')
# parser.add_argument('--extent', default=None, help='Specify an extent to break the output image to. This is the diagonal.')
parser.add_argument('--no-download', action='store_true', help="Don't do any downloading or image checking.")
args = parser.parse_args()
args.base_url = args.base_url.strip('/') + f'/{args.zoom}/'
@ -60,24 +63,29 @@ if __name__ == '__main__':
row_i = row
col_iter = range(min_col, max_col + 1)
col_bar = tqdm(total=len(col_iter), leave=False)
with (ThreadPoolExecutor(args.threads) as executor):
futures = [executor.submit(download_tile, (row, col, args.base_url, r_headers, tiles_output, args.proxy)) for col in col_iter]
for future in as_completed(futures):
result = future.result()
if result:
result_row, result_col, new_image = result
if new_image == 'success':
total_downloaded += 1
tiles.append((result_row, result_col))
elif new_image == 'exist':
tiles.append((result_row, result_col))
elif new_image == 'failure':
retries.append((result_row, result_col))
row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
col_bar.update()
row_bar.refresh()
col_bar.close()
row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
if args.no_download:
for col in col_iter:
tiles.append((row, col))
else:
with (ThreadPoolExecutor(args.threads) as executor):
futures = [executor.submit(download_tile, (row, col, args.base_url, r_headers, tiles_output, args.proxy)) for col in col_iter]
for future in as_completed(futures):
result = future.result()
if result:
result_row, result_col, new_image = result
if new_image == 'success':
total_downloaded += 1
tiles.append((result_row, result_col))
elif new_image == 'exist':
tiles.append((result_row, result_col))
elif new_image == 'failure':
retries.append((result_row, result_col))
row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
col_bar.update()
row_bar.refresh()
col_bar.close()
row_bar.set_postfix({'new_files': total_downloaded, 'failures': len(retries)})
row_bar.update()
row_bar.close()
@ -98,6 +106,8 @@ if __name__ == '__main__':
print(f'Downloaded {total_downloaded} images.')
print('Preparing data...')
tile_size = random_file_width(tiles_output)
# Define the number of rows and columns based on the bounding box
@ -110,8 +120,13 @@ if __name__ == '__main__':
* Affine.scale((bottom_right_lon - top_left_lon) / (num_cols * tile_size),
(bottom_right_lat - top_left_lat) / (num_rows * tile_size)))
with rasterio.open(output_tiff, "w", driver="GTiff", height=num_rows * tile_size, width=num_cols * tile_size, count=3, dtype='uint8', crs='EPSG:4326', transform=transform, compress="DEFLATE", nodata=0) as dst:
for row, col in tqdm(tiles, desc='Building GeoTIFF'):
def worker(pbar):
while True:
row, col = q.get()
if row is None:
break
tile_file = tiles_output / f"{row}_{col}.png"
if not tile_file.is_file():
raise Exception(f'Tile does not exist: {tile_file}')
@ -138,4 +153,32 @@ if __name__ == '__main__':
tile_data = np.transpose(tile_data, (2, 0, 1))
# Write the tile data to the GeoTIFF file
dst.write(tile_data, window=rasterio.windows.Window(col_pos, row_pos, tile_size, tile_size), indexes=[1, 2, 3])
with lock:
dst.write(tile_data, window=rasterio.windows.Window(col_pos, row_pos, tile_size, tile_size), indexes=[1, 2, 3])
q.task_done()
pbar.update()
q = Queue()
lock = threading.Lock()
with rasterio.open(output_tiff, "w", driver="GTiff", height=num_rows * tile_size, width=num_cols * tile_size, count=3, dtype='uint8', crs='EPSG:4326', transform=transform, compress="DEFLATE", nodata=0) as dst:
with tqdm(total=len(tiles), desc='Building GeoTIFF') as pbar:
threads = []
for i in range(args.tiff_threads):
t = threading.Thread(target=worker, args=(pbar,))
t.start()
threads.append(t)
for row, col in tiles:
q.put((row, col))
# block until all tasks are done
q.join()
# stop workers
for i in range(args.tiff_threads):
q.put((None, None))
for t in threads:
t.join()

View File

@ -3,5 +3,5 @@ venv/bin/python3 exfiltrate.py \
--zoom 20 \
--referer https://maps.nlsc.gov.tw/ \
--bbox 25.076387 121.68951 25.068282 121.700175 \
--threads 30 \
--output ~/Downloads/wmts-output/
--threads 30 --tiff-threads 100 \
--output ~/Downloads/wmts-output/ --no-download