From 2345481c0e21f1bd84c0d85b866b57d34506d836 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Sep 2022 23:29:09 +0200 Subject: [PATCH] [Flax] Fix unet and ddim scheduler (#594) * [Flax] Fix unet and ddim scheduler * correct * finish --- src/diffusers/models/embeddings_flax.py | 7 ++++--- src/diffusers/models/unet_2d_condition_flax.py | 3 ++- src/diffusers/pipeline_flax_utils.py | 14 ++++++-------- src/diffusers/pipeline_utils.py | 4 ++++ .../pipeline_flax_stable_diffusion.py | 1 - src/diffusers/schedulers/scheduling_ddim.py | 1 + src/diffusers/schedulers/scheduling_ddim_flax.py | 6 +++++- 7 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 63442ab9..50ccc238 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -19,7 +19,7 @@ import jax.numpy as jnp # This is like models.embeddings.get_timestep_embedding (PyTorch) but # less general (only handles the case we currently need). -def get_sinusoidal_embeddings(timesteps, embedding_dim): +def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. @@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim): embeddings. :return: an [N x dim] tensor of positional embeddings. """ half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - freq_shift) emb = jnp.exp(jnp.arange(half_dim) * -emb) emb = timesteps[:, None] * emb[None, :] emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) @@ -50,7 +50,8 @@ class FlaxTimestepEmbedding(nn.Module): class FlaxTimesteps(nn.Module): dim: int = 32 + freq_shift: float = 1 @nn.compact def __call__(self, timesteps): - return get_sinusoidal_embeddings(timesteps, self.dim) + return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index d0fcd9f6..cab229f2 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): cross_attention_dim: int = 1280 dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 + freq_shift: int = 0 def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors @@ -100,7 +101,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ) # time - self.time_proj = FlaxTimesteps(block_out_channels[0]) + self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) # down diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index fca793c0..e65d95a3 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -354,7 +354,7 @@ class FlaxDiffusionPipeline(ConfigMixin): # TODO(Patrick, Suraj) - delete later if class_name == "DummyChecker": library_name = "stable_diffusion" - class_name = "StableDiffusionSafetyChecker" + class_name = "FlaxStableDiffusionSafetyChecker" is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None @@ -421,16 +421,14 @@ class FlaxDiffusionPipeline(ConfigMixin): loaded_sub_model = cached_folder if issubclass(class_obj, FlaxModelMixin): - # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here + loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) + params[name] = loaded_params + elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): + # make sure we don't initialize the weights to save time if name == "safety_checker": loaded_sub_model = DummyChecker() loaded_params = DummyChecker() - else: - loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) - params[name] = loaded_params - elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): - # make sure we don't initialize the weights to save time - if from_pt: + elif from_pt: # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) loaded_params = loaded_sub_model.params diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 847513bf..15334e24 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -341,6 +341,10 @@ class DiffusionPipeline(ConfigMixin): # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + if class_name.startswith("Flax"): + class_name = class_name[4:] + is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 6cca3767..7f068d71 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -178,7 +178,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, - rngs={}, ).sample # perform guidance noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 32be871f..a5369b16 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -222,6 +222,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index dd5a87df..30c873b4 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -216,6 +216,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" + # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function + eta = 0.0 + # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps @@ -224,6 +227,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): # 2. compute alphas, betas alpha_prod_t = alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called @@ -233,7 +237,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep, alphas_cumprod) - std_dev_t = variance ** (0.5) + std_dev_t = eta * variance ** (0.5) # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output