[Flax] Fix unet and ddim scheduler (#594)

* [Flax] Fix unet and ddim scheduler

* correct

* finish
Author: Patrick von Platen, 2022-09-20 23:29:09 +02:00 (committed by GitHub)
Parent: d934d3d795
Commit: 2345481c0e
7 changed files with 22 additions and 14 deletions

src/diffusers/models/embeddings_flax.py

@@ -19,7 +19,7 @@ import jax.numpy as jnp
 # This is like models.embeddings.get_timestep_embedding (PyTorch) but
 # less general (only handles the case we currently need).
-def get_sinusoidal_embeddings(timesteps, embedding_dim):
+def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
     embeddings. :return: an [N x dim] tensor of positional embeddings.
     """
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - freq_shift)
     emb = jnp.exp(jnp.arange(half_dim) * -emb)
     emb = timesteps[:, None] * emb[None, :]
     emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
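
A quick sketch of what freq_shift actually changes: only the decay rate of the geometric frequency ladder. The numbers below are illustrative, not taken from the diff; the formula is the patched line above.

import math

import jax.numpy as jnp

half_dim = 16  # embedding_dim // 2 for a 32-dim embedding

for freq_shift in (1.0, 0.0):
    # Same formula as the patched line above.
    scale = math.log(10000) / (half_dim - freq_shift)
    freqs = jnp.exp(jnp.arange(half_dim) * -scale)
    # freq_shift=1 makes the lowest frequency land exactly on 1/10000;
    # freq_shift=0 keeps it at 10000 ** (-(half_dim - 1) / half_dim).
    print(freq_shift, float(freqs[0]), float(freqs[-1]))
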
@@ -50,7 +50,8 @@ class FlaxTimestepEmbedding(nn.Module):
 class FlaxTimesteps(nn.Module):
     dim: int = 32
+    freq_shift: float = 1

     @nn.compact
     def __call__(self, timesteps):
-        return get_sinusoidal_embeddings(timesteps, self.dim)
+        return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
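And a hedged usage sketch of the patched module; the function and class are mirrors of the hunks above, repeated here only so the snippet is self-contained. The module carries no parameters, so init returns an empty variable dict.

import math

import flax.linen as nn
import jax
import jax.numpy as jnp

# Mirror of the patched code above, repeated so this sketch is self-contained.
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - freq_shift)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    return jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)

class FlaxTimesteps(nn.Module):
    dim: int = 32
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)

# No parameters, so init returns an empty variable dict.
proj = FlaxTimesteps(dim=320, freq_shift=0)
variables = proj.init(jax.random.PRNGKey(0), jnp.array([1, 2]))
emb = proj.apply(variables, jnp.array([1, 2]))
print(emb.shape)  # (2, 320)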

src/diffusers/models/unet_2d_condition_flax.py

@@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    freq_shift: int = 0

     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
         # init input tensors
@@ -100,7 +101,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         )

         # time
-        self.time_proj = FlaxTimesteps(block_out_channels[0])
+        self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

         # down
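
To sanity-check the plumbing, a toy module below forwards a freq_shift field into FlaxTimesteps the same way the patched setup does. ToyTimeProj is a made-up stand-in, not the library API: the real model reads the value via self.config (from its config mixin), which this sketch simplifies to a plain attribute, and it reuses the FlaxTimesteps mirror from the sketch above.

import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyTimeProj(nn.Module):
    # Toy stand-in for the UNet; only the freq_shift plumbing is modeled.
    block_out_channels: tuple = (320, 640, 1280, 1280)
    freq_shift: int = 0  # the diff defaults the real UNet to 0

    def setup(self):
        # Same wiring as the patched line: the config value is forwarded
        # into the time projection instead of being silently dropped.
        self.time_proj = FlaxTimesteps(self.block_out_channels[0], freq_shift=self.freq_shift)

    def __call__(self, timesteps):
        return self.time_proj(timesteps)

model = ToyTimeProj()
variables = model.init(jax.random.PRNGKey(0), jnp.array([1]))
print(model.apply(variables, jnp.array([1])).shape)  # (1, 320)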

src/diffusers/pipelines/pipeline_flax_utils.py

@@ -354,7 +354,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
             # TODO(Patrick, Suraj) - delete later
             if class_name == "DummyChecker":
                 library_name = "stable_diffusion"
-                class_name = "StableDiffusionSafetyChecker"
+                class_name = "FlaxStableDiffusionSafetyChecker"

             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None
@@ -421,16 +421,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
                 loaded_sub_model = cached_folder

             if issubclass(class_obj, FlaxModelMixin):
-                # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
+                loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                params[name] = loaded_params
+            elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+                # make sure we don't initialize the weights to save time
                 if name == "safety_checker":
                     loaded_sub_model = DummyChecker()
                     loaded_params = DummyChecker()
-                else:
-                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
-                params[name] = loaded_params
-            elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
-                # make sure we don't initialize the weights to save time
-                if from_pt:
+                elif from_pt:
                     # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
                     loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
                     loaded_params = loaded_sub_model.params
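The interleaved hunk is easier to follow as plain control flow. Below is a runnable miniature of the resulting dispatch: the stub classes and bodies are placeholders, only the branch order mirrors the patch (Flax-native diffusers models now always load for real, while the transformers-backed safety checker keeps its dummy placeholder).

# Stub types standing in for the real diffusers/transformers classes.
class FlaxModelMixin: ...
class FlaxPreTrainedModel: ...
class DummyChecker: ...
class SafetyChecker(FlaxPreTrainedModel): ...
class UNet(FlaxModelMixin): ...

def load_sub_model(name, class_obj):
    if issubclass(class_obj, FlaxModelMixin):
        # diffusers-native Flax models: always loaded for real now
        return f"load_method({name})"
    elif issubclass(class_obj, FlaxPreTrainedModel):
        # transformers models: the safety checker stays stubbed out
        if name == "safety_checker":
            return DummyChecker()
        return f"load_method({name})"

print(load_sub_model("unet", UNet))  # load_method(unet)
print(type(load_sub_model("safety_checker", SafetyChecker)).__name__)  # DummyChecker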

src/diffusers/pipelines/pipeline_utils.py

@@ -341,6 +341,10 @@ class DiffusionPipeline(ConfigMixin):
         # 3. Load each module in the pipeline
         for name, (library_name, class_name) in init_dict.items():
+            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+            if class_name.startswith("Flax"):
+                class_name = class_name[4:]
+
             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None
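The added branch is plain string slicing, so PyTorch pipelines can resolve components recorded under Flax names in a pipeline's model_index.json. A quick check with one Flax-prefixed and one already-plain class name (the names are examples; the slicing is exactly the patched code):

for class_name in ("FlaxStableDiffusionSafetyChecker", "UNet2DConditionModel"):
    # Drop the leading "Flax" so the PyTorch class can be looked up.
    if class_name.startswith("Flax"):
        class_name = class_name[4:]  # len("Flax") == 4
    print(class_name)
# StableDiffusionSafetyChecker
# UNet2DConditionModel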

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

@@ -178,7 +178,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
                 jnp.array(latents_input),
                 jnp.array(timestep, dtype=jnp.int32),
                 encoder_hidden_states=context,
-                rngs={},
             ).sample
             # perform guidance
             noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
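For context, the two halves produced by this jnp.split feed standard classifier-free guidance; the combination step that follows in the pipeline looks roughly like this (the placeholder tensor and the guidance_scale value are made up):

import jax.numpy as jnp

noise_pred = jnp.ones((2, 4, 64, 64))  # stand-in for the UNet output on a [uncond, text] batch
guidance_scale = 7.5

noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
# Push the prediction away from the unconditional branch, toward the text one.
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
print(noise_pred.shape)  # (1, 4, 64, 64)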

src/diffusers/schedulers/scheduling_ddim.py

@@ -222,6 +222,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
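The prev_timestep >= 0 guard matters on the final denoising step, where the index goes negative and would otherwise silently wrap around. A small numeric check with made-up alphas; final_alpha_cumprod is 1.0 here, as with diffusers' set_alpha_to_one=True:

import jax.numpy as jnp

num_train_timesteps, num_inference_steps = 1000, 50
alphas_cumprod = jnp.linspace(0.9999, 0.005, num_train_timesteps)  # illustrative only
final_alpha_cumprod = jnp.array(1.0)

for timestep in (980, 0):  # a middle step and the final denoising step
    prev_timestep = timestep - num_train_timesteps // num_inference_steps
    # At timestep 0, prev_timestep is -20; without the guard, the negative
    # index would wrap to alphas_cumprod[-20] instead of using alpha-bar_0.
    alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
    print(timestep, prev_timestep, float(alpha_prod_t_prev))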

src/diffusers/schedulers/scheduling_ddim_flax.py

@@ -216,6 +216,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"

+        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
+        eta = 0.0
+
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
@@ -224,6 +227,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute alphas, betas
         alpha_prod_t = alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
@@ -233,7 +237,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1) / (1 − α_t)) * sqrt(1 − α_t / α_t−1)
         variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
-        std_dev_t = variance ** (0.5)
+        std_dev_t = eta * variance ** (0.5)

         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
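
Putting the hunks together: with η fixed to 0, the update implemented here is the deterministic DDIM step, Eq. (12) of https://arxiv.org/pdf/2010.02502.pdf, and the eta factor on std_dev_t is what switches the stochastic term off. A self-contained numeric sketch with made-up scalars in place of real latents:

import jax.numpy as jnp

# Made-up scalar stand-ins for one latent value at one step.
alpha_prod_t, alpha_prod_t_prev = jnp.float32(0.5), jnp.float32(0.7)
beta_prod_t = 1 - alpha_prod_t
sample, model_output = jnp.float32(0.8), jnp.float32(0.1)  # x_t and predicted noise
eta = 0.0  # fixed for now, per the TODO above

# predicted x_0 (step 3 in the scheduler)
pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

# sigma_t**2 from formula (16); eta = 0 makes the step fully deterministic
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
std_dev_t = eta * variance ** 0.5

# "direction pointing to x_t" and the final x_{t-1} of formula (12)
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * model_output
prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
print(float(prev_sample))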