diff --git a/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx index 48283241..09bfb853 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/img2img.mdx @@ -29,4 +29,8 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan - enable_attention_slicing - disable_attention_slicing - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention \ No newline at end of file + - disable_xformers_memory_efficient_attention + +[[autodoc]] FlaxStableDiffusionImg2ImgPipeline + - all + - __call__ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx b/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx index ce880491..33e84a63 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx @@ -30,4 +30,8 @@ Available checkpoints are: - enable_attention_slicing - disable_attention_slicing - enable_xformers_memory_efficient_attention - - disable_xformers_memory_efficient_attention \ No newline at end of file + - disable_xformers_memory_efficient_attention + +[[autodoc]] FlaxStableDiffusionInpaintPipeline + - all + - __call__ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx index 59061763..6b8d53bf 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/text2img.mdx @@ -39,3 +39,7 @@ Available Checkpoints are: - disable_xformers_memory_efficient_attention - enable_vae_tiling - disable_vae_tiling + +[[autodoc]] FlaxStableDiffusionPipeline + - all + - __call__ diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py 
b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 5895c6ec..6ace9e4b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -24,6 +24,7 @@ from flax.jax_utils import unreplicate from flax.training.common_utils import shard from packaging import version from PIL import Image + from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel @@ -33,7 +34,7 @@ from ...schedulers import ( FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) -from ...utils import deprecate, logging +from ...utils import deprecate, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -44,6 +45,39 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + + >>> from diffusers import FlaxStableDiffusionPipeline + + >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16 + ... 
) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> prompt_ids = pipeline.prepare_inputs(prompt) + >>> # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + + >>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): r""" @@ -272,6 +306,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt_ids: jnp.array, @@ -316,6 +351,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. 
+ Examples: + Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 3424f68b..d87339dc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -23,6 +23,7 @@ from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image + from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel @@ -32,7 +33,7 @@ from ...schedulers import ( FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) -from ...utils import PIL_INTERPOLATION, logging +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -43,6 +44,64 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline + + + >>> def create_key(seed=0): + ... 
return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_img = init_img.resize((768, 512)) + + >>> prompts = "A fantasy landscape, trending on artstation" + + >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... revision="flax", + ... dtype=jnp.bfloat16, + ... ) + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + >>> prompt_ids, processed_image = pipeline.prepare_inputs( + ... prompt=[prompts] * num_samples, image=[init_img] * num_samples + ... ) + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipeline( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... strength=0.75, + ... num_inference_steps=50, + ... jit=True, + ... height=512, + ... width=768, + ... ).images + + >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + ``` +""" + class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" @@ -277,6 +336,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt_ids: jnp.array, @@ -332,6 +392,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. 
+ Examples: + Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index 2be6b082..973b8ea5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -24,6 +24,7 @@ from flax.jax_utils import unreplicate from flax.training.common_utils import shard from packaging import version from PIL import Image + from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel @@ -33,7 +34,7 @@ from ...schedulers import ( FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) -from ...utils import PIL_INTERPOLATION, deprecate, logging +from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -44,6 +45,60 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import PIL + >>> import requests + >>> from io import BytesIO + >>> from diffusers import FlaxStableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... 
return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained( + ... "xvjiarui/stable-diffusion-2-inpainting" + ... ) + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> init_image = num_samples * [init_image] + >>> mask_image = num_samples * [mask_image] + >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs( + ... prompt, init_image, mask_image + ... ) + >>> # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + >>> processed_masked_images = shard(processed_masked_images) + >>> processed_masks = shard(processed_masks) + + >>> images = pipeline( + ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True + ... 
).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): r""" @@ -332,6 +387,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt_ids: jnp.array, @@ -378,6 +434,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. + Examples: + Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a