Flax controlnet (#2727)
* add contronet flax --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
parent
aa0531fa8d
commit
df91c44712
|
@ -99,3 +99,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
|
|||
|
||||
## FlaxAutoencoderKL
|
||||
[[autodoc]] FlaxAutoencoderKL
|
||||
|
||||
## FlaxControlNetOutput
|
||||
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
|
||||
|
||||
## FlaxControlNetModel
|
||||
[[autodoc]] FlaxControlNetModel
|
||||
|
|
|
@ -272,3 +272,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
|
|||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
## FlaxStableDiffusionControlNetPipeline
|
||||
[[autodoc]] FlaxStableDiffusionControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
|
|
@ -188,6 +188,7 @@ try:
|
|||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
|
@ -211,6 +212,7 @@ except OptionalDependencyNotAvailable:
|
|||
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
FlaxStableDiffusionControlNetPipeline,
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
|
|
|
@ -30,5 +30,6 @@ if is_torch_available():
|
|||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
|
|
|
@ -0,0 +1,383 @@
|
|||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxDownBlock2D,
|
||||
FlaxUNetMidBlock2DCrossAttn,
|
||||
)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxControlNetOutput(BaseOutput):
|
||||
down_block_res_samples: jnp.ndarray
|
||||
mid_block_res_sample: jnp.ndarray
|
||||
|
||||
|
||||
class FlaxControlNetConditioningEmbedding(nn.Module):
|
||||
conditioning_embedding_channels: int
|
||||
block_out_channels: Tuple[int] = (16, 32, 96, 256)
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.conv_in = nn.Conv(
|
||||
self.block_out_channels[0],
|
||||
kernel_size=(3, 3),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
blocks = []
|
||||
for i in range(len(self.block_out_channels) - 1):
|
||||
channel_in = self.block_out_channels[i]
|
||||
channel_out = self.block_out_channels[i + 1]
|
||||
conv1 = nn.Conv(
|
||||
channel_in,
|
||||
kernel_size=(3, 3),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
blocks.append(conv1)
|
||||
conv2 = nn.Conv(
|
||||
channel_out,
|
||||
kernel_size=(3, 3),
|
||||
strides=(2, 2),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
blocks.append(conv2)
|
||||
self.blocks = blocks
|
||||
|
||||
self.conv_out = nn.Conv(
|
||||
self.conditioning_embedding_channels,
|
||||
kernel_size=(3, 3),
|
||||
padding=((1, 1), (1, 1)),
|
||||
kernel_init=nn.initializers.zeros_init(),
|
||||
bias_init=nn.initializers.zeros_init(),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, conditioning):
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = nn.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
embedding = block(embedding)
|
||||
embedding = nn.silu(embedding)
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
@flax_register_to_config
|
||||
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
r"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
|
||||
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*):
|
||||
The size of the input sample.
|
||||
in_channels (`int`, *optional*, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
|
||||
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the cross attention features.
|
||||
dropout (`float`, *optional*, defaults to 0):
|
||||
Dropout probability for down, up and bottleneck blocks.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
|
||||
The channel order of conditional image. Will convert it to `rgb` if it's `bgr`
|
||||
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in conditioning_embedding layer
|
||||
|
||||
|
||||
"""
|
||||
sample_size: int = 32
|
||||
in_channels: int = 4
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
)
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||
layers_per_block: int = 2
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8
|
||||
cross_attention_dim: int = 1280
|
||||
dropout: float = 0.0
|
||||
use_linear_projection: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
controlnet_conditioning_channel_order: str = "rgb"
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
timesteps = jnp.ones((1,), dtype=jnp.int32)
|
||||
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
|
||||
controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
|
||||
controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
|
||||
|
||||
def setup(self):
|
||||
block_out_channels = self.block_out_channels
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv(
|
||||
block_out_channels[0],
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# time
|
||||
self.time_proj = FlaxTimesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
|
||||
)
|
||||
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||
|
||||
self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=self.conditioning_embedding_out_channels,
|
||||
)
|
||||
|
||||
only_cross_attention = self.only_cross_attention
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
|
||||
|
||||
attention_head_dim = self.attention_head_dim
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
|
||||
|
||||
# down
|
||||
down_blocks = []
|
||||
controlnet_down_blocks = []
|
||||
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv(
|
||||
output_channel,
|
||||
kernel_size=(1, 1),
|
||||
padding="VALID",
|
||||
kernel_init=nn.initializers.zeros_init(),
|
||||
bias_init=nn.initializers.zeros_init(),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(self.down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
if down_block_type == "CrossAttnDownBlock2D":
|
||||
down_block = FlaxCrossAttnDownBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=self.dropout,
|
||||
num_layers=self.layers_per_block,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
add_downsample=not is_final_block,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
down_block = FlaxDownBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=self.dropout,
|
||||
num_layers=self.layers_per_block,
|
||||
add_downsample=not is_final_block,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
down_blocks.append(down_block)
|
||||
|
||||
for _ in range(self.layers_per_block):
|
||||
controlnet_block = nn.Conv(
|
||||
output_channel,
|
||||
kernel_size=(1, 1),
|
||||
padding="VALID",
|
||||
kernel_init=nn.initializers.zeros_init(),
|
||||
bias_init=nn.initializers.zeros_init(),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv(
|
||||
output_channel,
|
||||
kernel_size=(1, 1),
|
||||
padding="VALID",
|
||||
kernel_init=nn.initializers.zeros_init(),
|
||||
bias_init=nn.initializers.zeros_init(),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
self.down_blocks = down_blocks
|
||||
self.controlnet_down_blocks = controlnet_down_blocks
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
||||
in_channels=mid_block_channel,
|
||||
dropout=self.dropout,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.controlnet_mid_block = nn.Conv(
|
||||
mid_block_channel,
|
||||
kernel_size=(1, 1),
|
||||
padding="VALID",
|
||||
kernel_init=nn.initializers.zeros_init(),
|
||||
bias_init=nn.initializers.zeros_init(),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
controlnet_cond,
|
||||
conditioning_scale: float = 1.0,
|
||||
return_dict: bool = True,
|
||||
train: bool = False,
|
||||
) -> Union[FlaxControlNetOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
||||
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
|
||||
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
|
||||
conditioning_scale: (`float`) the scale factor for controlnet outputs
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
channel_order = self.controlnet_conditioning_channel_order
|
||||
if channel_order == "bgr":
|
||||
controlnet_cond = jnp.flip(controlnet_cond, axis=1)
|
||||
|
||||
# 1. time
|
||||
if not isinstance(timesteps, jnp.ndarray):
|
||||
timesteps = jnp.array([timesteps], dtype=jnp.int32)
|
||||
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps.astype(dtype=jnp.float32)
|
||||
timesteps = jnp.expand_dims(timesteps, 0)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
sample += controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for down_block in self.down_blocks:
|
||||
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
|
||||
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
else:
|
||||
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
|
||||
# 5. contronet blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples += (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample *= conditioning_scale
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return FlaxControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
|
@ -249,6 +249,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
sample,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
down_block_additional_residuals=None,
|
||||
mid_block_additional_residual=None,
|
||||
return_dict: bool = True,
|
||||
train: bool = False,
|
||||
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
|
||||
|
@ -291,9 +293,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if down_block_additional_residuals is not None:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals
|
||||
):
|
||||
down_block_res_sample += down_block_additional_residual
|
||||
new_down_block_res_samples += (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
|
||||
if mid_block_additional_residual is not None:
|
||||
sample += mid_block_additional_residual
|
||||
|
||||
# 5. up
|
||||
for up_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
|
||||
|
|
|
@ -124,6 +124,7 @@ except OptionalDependencyNotAvailable:
|
|||
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .stable_diffusion import (
|
||||
FlaxStableDiffusionControlNetPipeline,
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
|
|
|
@ -278,7 +278,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
>>> from diffusers import FlaxDPMSolverMultistepScheduler
|
||||
|
||||
>>> model_id = "runwayml/stable-diffusion-v1-5"
|
||||
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
|
||||
>>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
|
||||
... model_id,
|
||||
... subfolder="scheduler",
|
||||
... )
|
||||
|
@ -365,7 +365,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
# extract them here
|
||||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
|
||||
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
@ -470,6 +470,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
# 4. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
model = pipeline_class(**init_kwargs, dtype=dtype)
|
||||
return model, params
|
||||
|
||||
|
|
|
@ -127,6 +127,7 @@ if is_transformers_available() and is_flax_available():
|
|||
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
|
||||
from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
|
||||
from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
|
|
@ -0,0 +1,537 @@
|
|||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
)
|
||||
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
|
||||
from ..pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
from . import FlaxStableDiffusionPipelineOutput
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
|
||||
DEBUG = False
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import jax
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from flax.jax_utils import replicate
|
||||
>>> from flax.training.common_utils import shard
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
|
||||
|
||||
|
||||
>>> def image_grid(imgs, rows, cols):
|
||||
... w, h = imgs[0].size
|
||||
... grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
... for i, img in enumerate(imgs):
|
||||
... grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
... return grid
|
||||
|
||||
|
||||
>>> def create_key(seed=0):
|
||||
... return jax.random.PRNGKey(seed)
|
||||
|
||||
|
||||
>>> rng = create_key(0)
|
||||
|
||||
>>> # get canny image
|
||||
>>> canny_image = load_image(
|
||||
... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
|
||||
... )
|
||||
|
||||
>>> prompts = "best quality, extremely detailed"
|
||||
>>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
>>> # load control net and stable diffusion v1-5
|
||||
>>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
||||
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
|
||||
... )
|
||||
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
||||
... )
|
||||
>>> params["controlnet"] = controlnet_params
|
||||
|
||||
>>> num_samples = jax.device_count()
|
||||
>>> rng = jax.random.split(rng, jax.device_count())
|
||||
|
||||
>>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
||||
>>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
||||
>>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
||||
|
||||
>>> p_params = replicate(params)
|
||||
>>> prompt_ids = shard(prompt_ids)
|
||||
>>> negative_prompt_ids = shard(negative_prompt_ids)
|
||||
>>> processed_image = shard(processed_image)
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt_ids=prompt_ids,
|
||||
... image=processed_image,
|
||||
... params=p_params,
|
||||
... prng_seed=rng,
|
||||
... num_inference_steps=50,
|
||||
... neg_prompt_ids=negative_prompt_ids,
|
||||
... jit=True,
|
||||
... ).images
|
||||
|
||||
>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
|
||||
>>> output_images = image_grid(output_images, num_samples // 4, 4)
|
||||
>>> output_images.save("generated_image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.
|
||||
|
||||
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`FlaxAutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`FlaxCLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
controlnet ([`FlaxControlNetModel`]:
|
||||
Provides additional conditioning to the unet during the denoising process.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
|
||||
[`FlaxDPMSolverMultistepScheduler`].
|
||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: FlaxAutoencoderKL,
|
||||
text_encoder: FlaxCLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: FlaxUNet2DConditionModel,
|
||||
controlnet: FlaxControlNetModel,
|
||||
scheduler: Union[
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
safety_checker: FlaxStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def prepare_text_inputs(self, prompt: Union[str, List[str]]):
|
||||
if not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
return text_input.input_ids
|
||||
|
||||
def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
|
||||
if not isinstance(image, (Image.Image, list)):
|
||||
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
image = [image]
|
||||
|
||||
processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
|
||||
|
||||
return processed_images
|
||||
|
||||
def _get_has_nsfw_concepts(self, features, params):
|
||||
has_nsfw_concepts = self.safety_checker(features, params)
|
||||
return has_nsfw_concepts
|
||||
|
||||
def _run_safety_checker(self, images, safety_model_params, jit=False):
|
||||
# safety_model_params should already be replicated when jit is True
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
|
||||
|
||||
if jit:
|
||||
features = shard(features)
|
||||
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
|
||||
has_nsfw_concepts = unshard(has_nsfw_concepts)
|
||||
safety_model_params = unreplicate(safety_model_params)
|
||||
else:
|
||||
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
|
||||
|
||||
images_was_copied = False
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
if not images_was_copied:
|
||||
images_was_copied = True
|
||||
images = images.copy()
|
||||
|
||||
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
warnings.warn(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned"
|
||||
" instead. Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
latents: Optional[jnp.array] = None,
|
||||
neg_prompt_ids: Optional[jnp.array] = None,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
):
|
||||
height, width = image.shape[-2:]
|
||||
if height % 64 != 0 or width % 64 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
|
||||
|
||||
# get prompt text embeddings
|
||||
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
|
||||
|
||||
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
|
||||
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
|
||||
batch_size = prompt_ids.shape[0]
|
||||
|
||||
max_length = prompt_ids.shape[-1]
|
||||
|
||||
if neg_prompt_ids is None:
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
).input_ids
|
||||
else:
|
||||
uncond_input = neg_prompt_ids
|
||||
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
image = jnp.concatenate([image] * 2)
|
||||
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
def loop_body(step, args):
|
||||
latents, scheduler_state = args
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
latents_input = jnp.concatenate([latents] * 2)
|
||||
|
||||
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
|
||||
timestep = jnp.broadcast_to(t, latents_input.shape[0])
|
||||
|
||||
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
|
||||
{"params": params["controlnet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet.apply(
|
||||
{"params": params["unet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
|
||||
return latents, scheduler_state
|
||||
|
||||
scheduler_state = self.scheduler.set_timesteps(
|
||||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
|
||||
)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * params["scheduler"].init_noise_sigma
|
||||
|
||||
if DEBUG:
|
||||
# run with python for loop
|
||||
for i in range(num_inference_steps):
|
||||
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
|
||||
else:
|
||||
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
||||
|
||||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
latents: jnp.array = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt_ids (`jnp.array`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
image (`jnp.array`):
|
||||
Array representing the ControlNet input condition. ControlNet use this input condition to generate
|
||||
guidance to Unet.
|
||||
params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
|
||||
prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
if isinstance(guidance_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
guidance_scale = guidance_scale[:, None]
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
|
||||
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
else:
|
||||
images = self._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_params = params["safety_checker"]
|
||||
images_uint8_casted = (images * 255).round().astype("uint8")
|
||||
num_devices, batch_size = images.shape[:2]
|
||||
|
||||
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
|
||||
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
|
||||
images = np.asarray(images)
|
||||
|
||||
# block images
|
||||
if any(has_nsfw_concept):
|
||||
for i, is_nsfw in enumerate(has_nsfw_concept):
|
||||
if is_nsfw:
|
||||
images[i] = np.asarray(images_uint8_casted[i])
|
||||
|
||||
images = images.reshape(num_devices, batch_size, height, width, 3)
|
||||
else:
|
||||
images = np.asarray(images)
|
||||
has_nsfw_concept = False
|
||||
|
||||
if not return_dict:
|
||||
return (images, has_nsfw_concept)
|
||||
|
||||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
|
||||
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
|
||||
@partial(
|
||||
jax.pmap,
|
||||
in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
|
||||
static_broadcasted_argnums=(0, 5),
|
||||
)
|
||||
def _p_generate(
|
||||
pipe,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
):
|
||||
return pipe._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0,))
|
||||
def _p_get_has_nsfw_concepts(pipe, features, params):
|
||||
return pipe._get_has_nsfw_concepts(features, params)
|
||||
|
||||
|
||||
def unshard(x: jnp.ndarray):
|
||||
# einops.rearrange(x, 'd b ... -> (d b) ...')
|
||||
num_devices, batch_size = x.shape[:2]
|
||||
rest = x.shape[2:]
|
||||
return x.reshape(num_devices * batch_size, *rest)
|
||||
|
||||
|
||||
def preprocess(image, dtype):
|
||||
image = image.convert("RGB")
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image = jnp.array(image).astype(dtype) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
return image
|
|
@ -2,6 +2,21 @@
|
|||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax", "transformers"])
|
||||
|
||||
|
||||
class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax", "transformers"]
|
||||
|
||||
|
|
|
@ -2,6 +2,21 @@
|
|||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class FlaxControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxModelMixin(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
|
||||
from diffusers.utils import is_flax_available, load_image, slow
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
||||
"lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
|
||||
)
|
||||
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
|
||||
)
|
||||
params["controlnet"] = controlnet_params
|
||||
|
||||
prompts = "bird"
|
||||
num_samples = jax.device_count()
|
||||
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
||||
|
||||
canny_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
||||
|
||||
rng = jax.random.PRNGKey(0)
|
||||
rng = jax.random.split(rng, jax.device_count())
|
||||
|
||||
p_params = replicate(params)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
processed_image = shard(processed_image)
|
||||
|
||||
images = pipe(
|
||||
prompt_ids=prompt_ids,
|
||||
image=processed_image,
|
||||
params=p_params,
|
||||
prng_seed=rng,
|
||||
num_inference_steps=50,
|
||||
jit=True,
|
||||
).images
|
||||
assert images.shape == (jax.device_count(), 1, 768, 512, 3)
|
||||
|
||||
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
||||
image_slice = images[0, 253:256, 253:256, -1]
|
||||
|
||||
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
|
||||
expected_slice = jnp.array(
|
||||
[0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
|
||||
)
|
||||
print(f"output_slice: {output_slice}")
|
||||
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
|
||||
|
||||
def test_pose(self):
|
||||
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
||||
"lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
|
||||
)
|
||||
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
|
||||
)
|
||||
params["controlnet"] = controlnet_params
|
||||
|
||||
prompts = "Chef in the kitchen"
|
||||
num_samples = jax.device_count()
|
||||
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
||||
|
||||
pose_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
|
||||
)
|
||||
processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)
|
||||
|
||||
rng = jax.random.PRNGKey(0)
|
||||
rng = jax.random.split(rng, jax.device_count())
|
||||
|
||||
p_params = replicate(params)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
processed_image = shard(processed_image)
|
||||
|
||||
images = pipe(
|
||||
prompt_ids=prompt_ids,
|
||||
image=processed_image,
|
||||
params=p_params,
|
||||
prng_seed=rng,
|
||||
num_inference_steps=50,
|
||||
jit=True,
|
||||
).images
|
||||
assert images.shape == (jax.device_count(), 1, 768, 512, 3)
|
||||
|
||||
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
||||
image_slice = images[0, 253:256, 253:256, -1]
|
||||
|
||||
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
|
||||
expected_slice = jnp.array(
|
||||
[[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
|
||||
)
|
||||
print(f"output_slice: {output_slice}")
|
||||
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
|
Loading…
Reference in New Issue