lpw_stable_diffusion: Add is_cancelled_callback (#1053)

* [Community Pipelines] lpw_stable_diffusion: Add is_cancelled_callback

* [Community pipelines] lpw_stable_diffusion_onnx: Add is_cancelled_callback
This commit is contained in:
rafael 2022-11-02 06:51:18 -04:00 committed by GitHub
parent 8ee21915bf
commit bdbcaa9852
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 4 deletions

View File

@ -498,6 +498,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
@ -560,11 +561,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
@ -757,8 +762,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

View File

@ -435,6 +435,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
@ -496,11 +497,15 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
@ -668,8 +673,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]