Flax: Fix img2img and align with other pipeline (#1824)
* Flax: Add components function * Flax: Fix img2img and align with other pipeline * Flax: Fix PRNGKey type * Refactor strength to start_timestep * Fix preprocess images * Fix processed_images dimen * latents.shape -> latents_shape * Fix typo * Remove "static" comment * Remove unnecessary optional types in _generate * Apply doc-builder code style. Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
9ea7052f0e
commit
ab0e92fdc8
|
@ -189,7 +189,7 @@ class FlaxModelMixin:
|
|||
```"""
|
||||
return self._cast_floating_to(params, jnp.float16, mask)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
|
||||
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
|
|
@ -806,7 +806,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -475,6 +475,51 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
model = pipeline_class(**init_kwargs, dtype=dtype)
|
||||
return model, params
|
||||
|
||||
@staticmethod
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - set(["self"])
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
r"""
|
||||
|
||||
The `self.components` property can be useful to run different pipelines with the same weights and
|
||||
configurations to not have to re-allocate memory.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import (
|
||||
... FlaxStableDiffusionPipeline,
|
||||
... FlaxStableDiffusionImg2ImgPipeline,
|
||||
... )
|
||||
|
||||
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
|
||||
... )
|
||||
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
|
||||
```
|
||||
|
||||
Returns:
|
||||
A dictionary containing all the modules needed to initialize the pipeline.
|
||||
"""
|
||||
expected_modules, optional_parameters = self._get_signature_keys(self)
|
||||
components = {
|
||||
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
|
||||
}
|
||||
|
||||
if set(components.keys()) != expected_modules:
|
||||
raise ValueError(
|
||||
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
|
||||
f" {expected_modules} to be defined, but {components} are defined."
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
|
|
|
@ -764,7 +764,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
```
|
||||
|
||||
Returns:
|
||||
A dictionaly containing all the modules needed to initialize the pipeline.
|
||||
A dictionary containing all the modules needed to initialize the pipeline.
|
||||
"""
|
||||
expected_modules, optional_parameters = self._get_signature_keys(self)
|
||||
components = {
|
||||
|
|
|
@ -184,18 +184,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
width: int,
|
||||
guidance_scale: float,
|
||||
latents: Optional[jnp.array] = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
neg_prompt_ids: Optional[jnp.array] = None,
|
||||
):
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
|
@ -281,15 +277,15 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
latents: jnp.array = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -41,6 +41,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
|||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
|
||||
DEBUG = False
|
||||
|
||||
|
||||
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
||||
r"""
|
||||
|
@ -106,6 +109,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
|
||||
if not isinstance(prompt, (str, list)):
|
||||
|
@ -116,10 +120,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
|
||||
if isinstance(image, Image.Image):
|
||||
image = [image]
|
||||
processed_image = []
|
||||
for img in image:
|
||||
processed_image.append(preprocess(img, self.dtype))
|
||||
processed_image = jnp.array(processed_image).squeeze()
|
||||
|
||||
processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
|
||||
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
|
@ -128,7 +130,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
return text_input.input_ids, processed_image
|
||||
return text_input.input_ids, processed_images
|
||||
|
||||
def _get_has_nsfw_concepts(self, features, params):
|
||||
has_nsfw_concepts = self.safety_checker(features, params)
|
||||
|
@ -164,12 +166,11 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
def get_timestep_start(self, num_inference_steps, strength, scheduler_state):
|
||||
def get_timestep_start(self, num_inference_steps, strength):
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
return t_start
|
||||
|
||||
|
@ -178,13 +179,14 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
guidance_scale: float = 7.5,
|
||||
debug: bool = False,
|
||||
prng_seed: jax.random.KeyArray,
|
||||
start_timestep: int,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
width: int,
|
||||
guidance_scale: float,
|
||||
noise: Optional[jnp.array] = None,
|
||||
neg_prompt_ids: Optional[jnp.array] = None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
@ -197,18 +199,32 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
batch_size = prompt_ids.shape[0]
|
||||
|
||||
max_length = prompt_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
|
||||
|
||||
if neg_prompt_ids is None:
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
).input_ids
|
||||
else:
|
||||
uncond_input = neg_prompt_ids
|
||||
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if noise is None:
|
||||
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
if noise.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}")
|
||||
|
||||
# Create init_latents
|
||||
init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
|
||||
init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
|
||||
init_latents = 0.18215 * init_latents
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
|
||||
|
||||
def loop_body(step, args):
|
||||
latents, scheduler_state = args
|
||||
|
@ -241,19 +257,19 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
|
||||
)
|
||||
|
||||
t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state)
|
||||
latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size)
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
|
||||
latents = init_latents
|
||||
latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size)
|
||||
|
||||
if debug:
|
||||
latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * params["scheduler"].init_noise_sigma
|
||||
|
||||
if DEBUG:
|
||||
# run with python for loop
|
||||
for i in range(t_start, len(scheduler_state.timesteps)):
|
||||
for i in range(start_timestep, num_inference_steps):
|
||||
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
|
||||
else:
|
||||
latents, _ = jax.lax.fori_loop(
|
||||
t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state)
|
||||
)
|
||||
latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
|
@ -268,14 +284,15 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int = 50,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
guidance_scale: float = 7.5,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
noise: jnp.array = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
debug: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
@ -287,12 +304,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
Array representing an image batch, that will be used as the starting point for the process.
|
||||
params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
|
||||
prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
|
@ -300,18 +322,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
noise (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. tensor will ge generated
|
||||
by sampling using the supplied random `generator`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
|
@ -319,76 +340,109 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
|||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if isinstance(guidance_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
guidance_scale = guidance_scale[:, None]
|
||||
|
||||
start_timestep = self.get_timestep_start(num_inference_steps, strength)
|
||||
|
||||
if jit:
|
||||
image = _p_generate(
|
||||
images = _p_generate(
|
||||
self,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
strength,
|
||||
start_timestep,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
debug,
|
||||
noise,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
else:
|
||||
image = self._generate(
|
||||
images = self._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
strength,
|
||||
start_timestep,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
debug,
|
||||
noise,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_params = params["safety_checker"]
|
||||
image_uint8_casted = (image * 255).round().astype("uint8")
|
||||
num_devices, batch_size = image.shape[:2]
|
||||
images_uint8_casted = (images * 255).round().astype("uint8")
|
||||
num_devices, batch_size = images.shape[:2]
|
||||
|
||||
image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
|
||||
image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit)
|
||||
image = np.asarray(image)
|
||||
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
|
||||
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
|
||||
images = np.asarray(images)
|
||||
|
||||
# block images
|
||||
if any(has_nsfw_concept):
|
||||
for i, is_nsfw in enumerate(has_nsfw_concept):
|
||||
if is_nsfw:
|
||||
image[i] = np.asarray(image_uint8_casted[i])
|
||||
images[i] = np.asarray(images_uint8_casted[i])
|
||||
|
||||
image = image.reshape(num_devices, batch_size, height, width, 3)
|
||||
images = images.reshape(num_devices, batch_size, height, width, 3)
|
||||
else:
|
||||
images = np.asarray(images)
|
||||
has_nsfw_concept = False
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
return (images, has_nsfw_concept)
|
||||
|
||||
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# TODO: maybe use a config dict instead of so many static argnums
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10))
|
||||
# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
|
||||
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
|
||||
@partial(
|
||||
jax.pmap,
|
||||
in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0),
|
||||
static_broadcasted_argnums=(0, 5, 6, 7, 8),
|
||||
)
|
||||
def _p_generate(
|
||||
pipe,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
strength,
|
||||
start_timestep,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
debug,
|
||||
noise,
|
||||
neg_prompt_ids,
|
||||
):
|
||||
return pipe._generate(
|
||||
prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
start_timestep,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
noise,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
|
|||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
clip_input = jax.random.normal(rng, input_shape)
|
||||
|
||||
|
|
Loading…
Reference in New Issue