Refactor postprocessing to use generator to resolve OOM issues

This commit is contained in:
catboxanon 2023-08-11 11:32:12 -04:00
parent ae6b30907d
commit 7c9c19b2a2
1 changed files with 30 additions and 31 deletions

View File

@ -11,37 +11,32 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
shared.state.begin(job="extras") shared.state.begin(job="extras")
image_data = []
image_names = []
outputs = [] outputs = []
if extras_mode == 1: def get_images(extras_mode, image, image_folder, input_dir):
for img in image_folder: if extras_mode == 1:
if isinstance(img, Image.Image): for img in image_folder:
image = img if isinstance(img, Image.Image):
fn = '' image = img
else: fn = ''
image = Image.open(os.path.abspath(img.name)) else:
fn = os.path.splitext(img.orig_name)[0] image = Image.open(os.path.abspath(img.name))
image_data.append(image) fn = os.path.splitext(img.orig_name)[0]
image_names.append(fn) yield image, fn
elif extras_mode == 2: elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected' assert input_dir, 'input directory not selected'
image_list = shared.listfiles(input_dir) image_list = shared.listfiles(input_dir)
for filename in image_list: for filename in image_list:
try: try:
image = Image.open(filename) image = Image.open(filename)
except Exception: except Exception:
continue continue
image_data.append(image) yield image, filename
image_names.append(filename) else:
else: assert image, 'image not selected'
assert image, 'image not selected' yield image, None
image_data.append(image)
image_names.append(None)
if extras_mode == 2 and output_dir != '': if extras_mode == 2 and output_dir != '':
outpath = output_dir outpath = output_dir
@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
infotext = '' infotext = ''
for image, name in zip(image_data, image_names): for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
image_data: Image.Image
shared.state.textinfo = name shared.state.textinfo = name
parameters, existing_pnginfo = images.read_info_from_image(image) parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters: if parameters:
existing_pnginfo["parameters"] = parameters existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
scripts.scripts_postproc.run(pp, args) scripts.scripts_postproc.run(pp, args)
@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode != 2 or show_extras_results: if extras_mode != 2 or show_extras_results:
outputs.append(pp.image) outputs.append(pp.image)
image_data.close()
devices.torch_gc() devices.torch_gc()
return outputs, ui_common.plaintext_to_html(infotext), '' return outputs, ui_common.plaintext_to_html(infotext), ''