From ca45ff1ae6fdd5c2dcd754fde95dd29f49bd414b Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 24 Jul 2023 13:52:24 -0400 Subject: [PATCH 1/3] add postprocess_batch_list callback --- modules/processing.py | 24 +++++++++++++++++++++++- modules/scripts.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index a74a53027..c16404f47 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -717,7 +717,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] def infotext(iteration=0, position_in_batch=0, use_main_prompt=False): - return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt) + all_prompts = p.all_prompts[:] + all_seeds = p.all_seeds[:] + all_subseeds = p.all_subseeds[:] + + # apply changes to generation data + all_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.prompts + all_seeds[n * p.batch_size:(n + 1) * p.batch_size] = p.seeds + all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] = p.subseeds + + # update p.all_negative_prompts in case extensions changed the size of the batch + # create_infotext below uses it + old_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] + p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.negative_prompts + + try: + return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt) + finally: + # restore p.all_negative_prompts in case extensions changed the size of the batch + p.all_negative_prompts[n * p.batch_size:n * p.batch_size + len(p.negative_prompts)] = old_negative_prompts if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() @@ -806,6 +824,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) + postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim)) + p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n) + x_samples_ddim = postprocess_batch_list_args.images + for i, x_sample in enumerate(x_samples_ddim): p.batch_index = i diff --git a/modules/scripts.py b/modules/scripts.py index f34240a09..5b4edcac3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -16,6 +16,11 @@ class PostprocessImageArgs: self.image = image +class PostprocessBatchListArgs: + def __init__(self, images): + self.images = images + + class Script: name = None """script's internal name derived from title""" @@ -156,6 +161,25 @@ class Script: pass + def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs): + """ + Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor. + This is useful when you want to update the entire batch instead of individual images. + + You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc. + If the number of images is different from the batch size when returning, + then the script has the responsibility to also update the following attributes in the processing object (p): + - p.prompts + - p.negative_prompts + - p.seeds + - p.subseeds + + **kwargs will have same items as process_batch, and also: + - batch_number - index of current batch, from 0 to number of batches-1 + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -536,6 +560,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) + def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_batch_list(p, pp, *script_args, **kwargs) + except Exception: + errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: From 6b68b590321fcac2ad6d71c5aee1ac02687328d7 Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 24 Jul 2023 15:38:52 -0400 Subject: [PATCH 2/3] use local vars --- modules/processing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c16404f47..7043477f5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -722,20 +722,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: all_subseeds = p.all_subseeds[:] # apply changes to generation data - all_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.prompts - all_seeds[n * p.batch_size:(n + 1) * p.batch_size] = p.seeds - all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] = p.subseeds + all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts + all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds + all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds # update p.all_negative_prompts in case extensions changed the size of the batch # create_infotext below uses it - old_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] - p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.negative_prompts + old_negative_prompts = p.all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] + p.all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts try: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt) finally: # restore p.all_negative_prompts in case extensions changed the size of the batch - p.all_negative_prompts[n * p.batch_size:n * p.batch_size + len(p.negative_prompts)] = old_negative_prompts + p.all_negative_prompts[iteration * p.batch_size:iteration * p.batch_size + len(p.negative_prompts)] = old_negative_prompts if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() From 5b06607476d1ef2c9d16fe8b21c786b2ca13b95c Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 24 Jul 2023 15:43:06 -0400 Subject: [PATCH 3/3] simplify --- modules/processing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 7043477f5..6dc178e16 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -718,24 +718,26 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0, use_main_prompt=False): all_prompts = p.all_prompts[:] + all_negative_prompts = p.all_negative_prompts[:] all_seeds = p.all_seeds[:] all_subseeds = p.all_subseeds[:] # apply changes to generation data all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts + all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds # update p.all_negative_prompts in case extensions changed the size of the batch # create_infotext below uses it - old_negative_prompts = p.all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] - p.all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts + old_negative_prompts = p.all_negative_prompts + p.all_negative_prompts = all_negative_prompts try: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt) finally: # restore p.all_negative_prompts in case extensions changed the size of the batch - p.all_negative_prompts[iteration * p.batch_size:iteration * p.batch_size + len(p.negative_prompts)] = old_negative_prompts + p.all_negative_prompts = old_negative_prompts if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings()