Refactor postprocessing to use generator to resolve OOM issues
This commit is contained in:
parent
ae6b30907d
commit
7c9c19b2a2
|
@ -11,10 +11,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||||
|
|
||||||
shared.state.begin(job="extras")
|
shared.state.begin(job="extras")
|
||||||
|
|
||||||
image_data = []
|
|
||||||
image_names = []
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
|
def get_images(extras_mode, image, image_folder, input_dir):
|
||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
|
@ -23,8 +22,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||||
else:
|
else:
|
||||||
image = Image.open(os.path.abspath(img.name))
|
image = Image.open(os.path.abspath(img.name))
|
||||||
fn = os.path.splitext(img.orig_name)[0]
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
image_data.append(image)
|
yield image, fn
|
||||||
image_names.append(fn)
|
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
assert input_dir, 'input directory not selected'
|
assert input_dir, 'input directory not selected'
|
||||||
|
@ -35,13 +33,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||||
image = Image.open(filename)
|
image = Image.open(filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
image_data.append(image)
|
yield image, filename
|
||||||
image_names.append(filename)
|
|
||||||
else:
|
else:
|
||||||
assert image, 'image not selected'
|
assert image, 'image not selected'
|
||||||
|
yield image, None
|
||||||
image_data.append(image)
|
|
||||||
image_names.append(None)
|
|
||||||
|
|
||||||
if extras_mode == 2 and output_dir != '':
|
if extras_mode == 2 and output_dir != '':
|
||||||
outpath = output_dir
|
outpath = output_dir
|
||||||
|
@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||||
|
|
||||||
infotext = ''
|
infotext = ''
|
||||||
|
|
||||||
for image, name in zip(image_data, image_names):
|
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
|
||||||
|
image_data: Image.Image
|
||||||
|
|
||||||
shared.state.textinfo = name
|
shared.state.textinfo = name
|
||||||
|
|
||||||
parameters, existing_pnginfo = images.read_info_from_image(image)
|
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||||
if parameters:
|
if parameters:
|
||||||
existing_pnginfo["parameters"] = parameters
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
||||||
|
|
||||||
scripts.scripts_postproc.run(pp, args)
|
scripts.scripts_postproc.run(pp, args)
|
||||||
|
|
||||||
|
@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||||
if extras_mode != 2 or show_extras_results:
|
if extras_mode != 2 or show_extras_results:
|
||||||
outputs.append(pp.image)
|
outputs.append(pp.image)
|
||||||
|
|
||||||
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|
Loading…
Reference in New Issue