[Flax] Fix unet and ddim scheduler (#594)
* [Flax] Fix unet and ddim scheduler * correct * finish
This commit is contained in:
parent
d934d3d795
commit
2345481c0e
|
@ -19,7 +19,7 @@ import jax.numpy as jnp
|
|||
|
||||
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
|
||||
# less general (only handles the case we currently need).
|
||||
def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
||||
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
|
@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
|||
embeddings. :return: an [N x dim] tensor of positional embeddings.
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = math.log(10000) / (half_dim - freq_shift)
|
||||
emb = jnp.exp(jnp.arange(half_dim) * -emb)
|
||||
emb = timesteps[:, None] * emb[None, :]
|
||||
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
|
||||
|
@ -50,7 +50,8 @@ class FlaxTimestepEmbedding(nn.Module):
|
|||
|
||||
class FlaxTimesteps(nn.Module):
|
||||
dim: int = 32
|
||||
freq_shift: float = 1
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, timesteps):
|
||||
return get_sinusoidal_embeddings(timesteps, self.dim)
|
||||
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
|
||||
|
|
|
@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
cross_attention_dim: int = 1280
|
||||
dropout: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
freq_shift: int = 0
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
# init input tensors
|
||||
|
@ -100,7 +101,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
)
|
||||
|
||||
# time
|
||||
self.time_proj = FlaxTimesteps(block_out_channels[0])
|
||||
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
|
||||
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||
|
||||
# down
|
||||
|
|
|
@ -354,7 +354,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
# TODO(Patrick, Suraj) - delete later
|
||||
if class_name == "DummyChecker":
|
||||
library_name = "stable_diffusion"
|
||||
class_name = "StableDiffusionSafetyChecker"
|
||||
class_name = "FlaxStableDiffusionSafetyChecker"
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
|
@ -421,16 +421,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
loaded_sub_model = cached_folder
|
||||
|
||||
if issubclass(class_obj, FlaxModelMixin):
|
||||
# TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
|
||||
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
||||
params[name] = loaded_params
|
||||
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||
# make sure we don't initialize the weights to save time
|
||||
if name == "safety_checker":
|
||||
loaded_sub_model = DummyChecker()
|
||||
loaded_params = DummyChecker()
|
||||
else:
|
||||
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
||||
params[name] = loaded_params
|
||||
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||
# make sure we don't initialize the weights to save time
|
||||
if from_pt:
|
||||
elif from_pt:
|
||||
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
||||
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
||||
loaded_params = loaded_sub_model.params
|
||||
|
|
|
@ -341,6 +341,10 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
|
||||
|
|
|
@ -178,7 +178,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
rngs={},
|
||||
).sample
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
||||
|
|
|
@ -222,6 +222,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
|
|
|
@ -216,6 +216,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
# - pred_sample_direction -> "direction pointing to x_t"
|
||||
# - pred_prev_sample -> "x_t-1"
|
||||
|
||||
# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
|
||||
eta = 0.0
|
||||
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
|
||||
|
||||
|
@ -224,6 +227,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
# 2. compute alphas, betas
|
||||
alpha_prod_t = alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
|
@ -233,7 +237,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
|
||||
std_dev_t = variance ** (0.5)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
||||
|
|
Loading…
Reference in New Issue