Merge pull request #11957 from ljleb/pp-batch-list
Add postprocess_batch_list script callback
This commit is contained in:
commit
f7c0a963f1
|
@ -717,7 +717,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
|
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
|
all_prompts = p.all_prompts[:]
|
||||||
|
all_negative_prompts = p.all_negative_prompts[:]
|
||||||
|
all_seeds = p.all_seeds[:]
|
||||||
|
all_subseeds = p.all_subseeds[:]
|
||||||
|
|
||||||
|
# apply changes to generation data
|
||||||
|
all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts
|
||||||
|
all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts
|
||||||
|
all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds
|
||||||
|
all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds
|
||||||
|
|
||||||
|
# update p.all_negative_prompts in case extensions changed the size of the batch
|
||||||
|
# create_infotext below uses it
|
||||||
|
old_negative_prompts = p.all_negative_prompts
|
||||||
|
p.all_negative_prompts = all_negative_prompts
|
||||||
|
|
||||||
|
try:
|
||||||
|
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
|
||||||
|
finally:
|
||||||
|
# restore p.all_negative_prompts in case extensions changed the size of the batch
|
||||||
|
p.all_negative_prompts = old_negative_prompts
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
@ -806,6 +826,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||||
|
|
||||||
|
postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
|
||||||
|
p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
|
||||||
|
x_samples_ddim = postprocess_batch_list_args.images
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
p.batch_index = i
|
p.batch_index = i
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,11 @@ class PostprocessImageArgs:
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessBatchListArgs:
|
||||||
|
def __init__(self, images):
|
||||||
|
self.images = images
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
name = None
|
name = None
|
||||||
"""script's internal name derived from title"""
|
"""script's internal name derived from title"""
|
||||||
|
@ -156,6 +161,25 @@ class Script:
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
|
||||||
|
This is useful when you want to update the entire batch instead of individual images.
|
||||||
|
|
||||||
|
You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
|
||||||
|
If the number of images is different from the batch size when returning,
|
||||||
|
then the script has the responsibility to also update the following attributes in the processing object (p):
|
||||||
|
- p.prompts
|
||||||
|
- p.negative_prompts
|
||||||
|
- p.seeds
|
||||||
|
- p.subseeds
|
||||||
|
|
||||||
|
**kwargs will have same items as process_batch, and also:
|
||||||
|
- batch_number - index of current batch, from 0 to number of batches-1
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||||
"""
|
"""
|
||||||
Called for every image after it has been generated.
|
Called for every image after it has been generated.
|
||||||
|
@ -536,6 +560,14 @@ class ScriptRunner:
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue