[Flax] Fix unet and ddim scheduler (#594)
* [Flax] Fix unet and ddim scheduler
* correct
* finish
This commit is contained in:
parent
d934d3d795
commit
2345481c0e
|
@@ -19,7 +19,7 @@ import jax.numpy as jnp
|
||||||
|
|
||||||
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
|
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
|
||||||
# less general (only handles the case we currently need).
|
# less general (only handles the case we currently need).
|
||||||
def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
|
||||||
"""
|
"""
|
||||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||||
|
|
||||||
|
@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
||||||
embeddings. :return: an [N x dim] tensor of positional embeddings.
|
embeddings. :return: an [N x dim] tensor of positional embeddings.
|
||||||
"""
|
"""
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - freq_shift)
|
||||||
emb = jnp.exp(jnp.arange(half_dim) * -emb)
|
emb = jnp.exp(jnp.arange(half_dim) * -emb)
|
||||||
emb = timesteps[:, None] * emb[None, :]
|
emb = timesteps[:, None] * emb[None, :]
|
||||||
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
|
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
|
||||||
|
@@ -50,7 +50,8 @@ class FlaxTimestepEmbedding(nn.Module):
|
||||||
|
|
||||||
class FlaxTimesteps(nn.Module):
|
class FlaxTimesteps(nn.Module):
|
||||||
dim: int = 32
|
dim: int = 32
|
||||||
|
freq_shift: float = 1
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, timesteps):
|
def __call__(self, timesteps):
|
||||||
return get_sinusoidal_embeddings(timesteps, self.dim)
|
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
|
||||||
|
|
|
@@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||||
cross_attention_dim: int = 1280
|
cross_attention_dim: int = 1280
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
freq_shift: int = 0
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
|
@@ -100,7 +101,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
# time
|
# time
|
||||||
self.time_proj = FlaxTimesteps(block_out_channels[0])
|
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
|
||||||
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||||
|
|
||||||
# down
|
# down
|
||||||
|
|
|
@@ -354,7 +354,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
# TODO(Patrick, Suraj) - delete later
|
# TODO(Patrick, Suraj) - delete later
|
||||||
if class_name == "DummyChecker":
|
if class_name == "DummyChecker":
|
||||||
library_name = "stable_diffusion"
|
library_name = "stable_diffusion"
|
||||||
class_name = "StableDiffusionSafetyChecker"
|
class_name = "FlaxStableDiffusionSafetyChecker"
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
loaded_sub_model = None
|
loaded_sub_model = None
|
||||||
|
@@ -421,16 +421,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
loaded_sub_model = cached_folder
|
loaded_sub_model = cached_folder
|
||||||
|
|
||||||
if issubclass(class_obj, FlaxModelMixin):
|
if issubclass(class_obj, FlaxModelMixin):
|
||||||
# TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
|
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
||||||
|
params[name] = loaded_params
|
||||||
|
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||||
|
# make sure we don't initialize the weights to save time
|
||||||
if name == "safety_checker":
|
if name == "safety_checker":
|
||||||
loaded_sub_model = DummyChecker()
|
loaded_sub_model = DummyChecker()
|
||||||
loaded_params = DummyChecker()
|
loaded_params = DummyChecker()
|
||||||
else:
|
elif from_pt:
|
||||||
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
|
||||||
params[name] = loaded_params
|
|
||||||
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
|
||||||
# make sure we don't initialize the weights to save time
|
|
||||||
if from_pt:
|
|
||||||
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
||||||
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
||||||
loaded_params = loaded_sub_model.params
|
loaded_params = loaded_sub_model.params
|
||||||
|
|
|
@@ -341,6 +341,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
# 3. Load each module in the pipeline
|
# 3. Load each module in the pipeline
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
|
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||||
|
if class_name.startswith("Flax"):
|
||||||
|
class_name = class_name[4:]
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
loaded_sub_model = None
|
loaded_sub_model = None
|
||||||
|
|
||||||
|
|
|
@@ -178,7 +178,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||||
jnp.array(latents_input),
|
jnp.array(latents_input),
|
||||||
jnp.array(timestep, dtype=jnp.int32),
|
jnp.array(timestep, dtype=jnp.int32),
|
||||||
encoder_hidden_states=context,
|
encoder_hidden_states=context,
|
||||||
rngs={},
|
|
||||||
).sample
|
).sample
|
||||||
# perform guidance
|
# perform guidance
|
||||||
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
||||||
|
|
|
@@ -222,6 +222,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
# 2. compute alphas, betas
|
# 2. compute alphas, betas
|
||||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||||
|
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
beta_prod_t = 1 - alpha_prod_t
|
||||||
|
|
||||||
# 3. compute predicted original sample from predicted noise also called
|
# 3. compute predicted original sample from predicted noise also called
|
||||||
|
|
|
@@ -216,6 +216,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
# - pred_sample_direction -> "direction pointing to x_t"
|
# - pred_sample_direction -> "direction pointing to x_t"
|
||||||
# - pred_prev_sample -> "x_t-1"
|
# - pred_prev_sample -> "x_t-1"
|
||||||
|
|
||||||
|
# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
|
||||||
|
eta = 0.0
|
||||||
|
|
||||||
# 1. get previous step value (=t-1)
|
# 1. get previous step value (=t-1)
|
||||||
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
|
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
|
||||||
|
|
||||||
|
@@ -224,6 +227,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
# 2. compute alphas, betas
|
# 2. compute alphas, betas
|
||||||
alpha_prod_t = alphas_cumprod[timestep]
|
alpha_prod_t = alphas_cumprod[timestep]
|
||||||
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
|
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
|
||||||
|
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
beta_prod_t = 1 - alpha_prod_t
|
||||||
|
|
||||||
# 3. compute predicted original sample from predicted noise also called
|
# 3. compute predicted original sample from predicted noise also called
|
||||||
|
@@ -233,7 +237,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
||||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||||
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
|
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
|
||||||
std_dev_t = variance ** (0.5)
|
std_dev_t = eta * variance ** (0.5)
|
||||||
|
|
||||||
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
||||||
|
|
Loading…
Reference in New Issue