callback postprocess_image_after_composite

This commit is contained in:
w-e-w 2024-01-16 14:18:20 +09:00
parent cb5b335acd
commit c1e04c63b3
2 changed files with 22 additions and 0 deletions

View File

@ -1029,6 +1029,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image = apply_overlay(image, p.paste_to, overlay_image) image = apply_overlay(image, p.paste_to, overlay_image)
if p.scripts is not None:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image_after_composite(p, pp)
image = pp.image
if save_samples: if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

View File

@ -262,6 +262,15 @@ class Script:
pass pass
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
Same as postprocess_image but after inpaint_full_res composite
So that it operates on the full image instead of the inpaint_full_res crop region.
"""
pass
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
""" """
This function is called after processing ends for AlwaysVisible scripts. This function is called after processing ends for AlwaysVisible scripts.
@ -856,6 +865,14 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image_after_composite(p, pp, *script_args)
except Exception:
errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs): def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try: try: