Add AudioLDM (#2232)

* Add AudioLDM

* up

* add vocoder

* start unet

* unconditional unet

* clap, vocoder and vae

* clean-up: conversion scripts

* fix: conversion script token_type_ids

* clean-up: pipeline docstring

* tests: from SD

* clean-up: cpu offload vocoder instead of safety checker

* feat: adapt tests to audioldm

* feat: add docs

* clean-up: amend pipeline docstrings

* clean-up: make style

* clean-up: make fix-copies

* fix: add doc path to toctree

* clean-up: args for conversion script

* clean-up: paths to checkpoints

* fix: use conditional unet

* clean-up: make style

* fix: type hints for UNet

* clean-up: docstring for UNet

* clean-up: make style

* clean-up: remove duplicate in docstring

* clean-up: make style

* clean-up: make fix-copies

* clean-up: move imports to start in code snippet

* fix: pass cross_attention_dim as a list/tuple to unet

* clean-up: make fix-copies

* fix: update checkpoint path

* fix: unet cross_attention_dim in tests

* film embeddings -> class embeddings

* Apply suggestions from code review

Co-authored-by: Will Berman <wlbberman@gmail.com>

* fix: unet film embed to use existing args

* fix: unet tests to use existing args

* fix: make style

* fix: transformers import and version in init

* clean-up: make style

* Revert "clean-up: make style"

This reverts commit 5d6d1f8b324f5583e7805dc01e2c86e493660d66.

* clean-up: make style

* clean-up: use pipeline tester mixin tests where poss

* clean-up: skip attn slicing test

* fix: add torch dtype to docs

* fix: remove conversion script out of src

* fix: remove .detach from 1d waveform

* fix: reduce default num inf steps

* fix: swap height/width -> audio_length_in_s

* clean-up: make style

* fix: remove nightly tests

* fix: imports in conversion script

* clean-up: slim-down to two slow tests

* clean-up: slim-down fast tests

* fix: batch consistent tests

* clean-up: make style

* clean-up: remove vae slicing fast test

* clean-up: propagate changes to doc

* fix: increase test tol to 1e-2

* clean-up: finish docs

* clean-up: make style

* feat: vocoder / VAE compatibility check

* feat: possibly expand / cut audio waveform

* fix: pipeline call signature test

* fix: slow tests output len

* clean-up: make style

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
This commit is contained in:
Sanchit Gandhi 2023-03-23 19:00:21 +01:00 committed by GitHub
parent 1870fb05a9
commit b94880e536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 2318 additions and 24 deletions

View File

@ -134,6 +134,8 @@
title: AltDiffusion
- local: api/pipelines/audio_diffusion
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/cycle_diffusion
title: Cycle Diffusion
- local: api/pipelines/dance_diffusion

View File

@ -0,0 +1,82 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AudioLDM
## Overview
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://arxiv.org/abs/2301.12503) by Haohe Liu et al.
Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)
latents. AudioLDM takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional
sound effects, human speech and music.
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found [here](https://github.com/haoheliu/AudioLDM).
## Text-to-Audio
The [`AudioLDMPipeline`] can be used to load pre-trained weights from [cvssp/audioldm](https://huggingface.co/cvssp/audioldm) and generate text-conditional audio outputs:
```python
from diffusers import AudioLDMPipeline
import torch
import scipy
repo_id = "cvssp/audioldm"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
# save the audio sample as a .wav file
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
```
### Tips
Prompts:
* Descriptive prompt inputs work best: you can use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context specific (e.g., "water stream in a forest" instead of "stream").
* It's best to use general terms like 'cat' or 'dog' instead of specific names or abstract objects that the model may not be familiar with.
Inference:
* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument: higher steps give higher quality audio at the expense of slower inference.
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
### How to load and use different schedulers
The AudioLDM pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers
that can be used with the AudioLDM pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
[`EulerAncestralDiscreteScheduler`] etc. We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest
scheduler there is.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`]
method, or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the
[`DPMSolverMultistepScheduler`], you can do the following:
```python
>>> from diffusers import AudioLDMPipeline, DPMSolverMultistepScheduler
>>> import torch
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
>>> pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained("cvssp/audioldm", subfolder="scheduler")
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", scheduler=dpm_scheduler, torch_dtype=torch.float16)
```
## AudioLDMPipeline
[[autodoc]] AudioLDMPipeline
- all
- __call__

File diff suppressed because it is too large Load Diff

View File

@ -112,6 +112,7 @@ else:
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AudioLDMPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,

View File

@ -86,13 +86,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
@ -106,6 +107,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
"""
_supports_gradient_checkpointing = True
@ -135,7 +138,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
@ -149,6 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
):
super().__init__()
@ -175,6 +179,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
@ -228,6 +237,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
@ -240,6 +255,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
# regular time embeddings
blocks_time_embed_dim = time_embed_dim * 2
else:
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@ -252,12 +278,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
@ -272,12 +298,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
@ -287,11 +313,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
@ -307,6 +333,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@ -330,12 +357,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@ -571,7 +598,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)

View File

@ -44,6 +44,7 @@ except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .audioldm import AudioLDMPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline

View File

@ -0,0 +1,17 @@
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AudioLDMPipeline,
)
else:
from .pipeline_audioldm import AudioLDMPipeline

View File

@ -0,0 +1,601 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AudioLDMPipeline
>>> pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "A hammer hitting a wooden surface"
>>> audio = pipe(prompt).audio[0]
```
"""
class AudioLDMPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-audio generation using AudioLDM.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations.
text_encoder ([`ClapTextModelWithProjection`]):
Frozen text-encoder. AudioLDM uses the text portion of
[CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection),
specifically the [RoBERTa HSTAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
tokenizer ([`PreTrainedTokenizer`]):
Tokenizer of class
[RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer).
unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
vocoder ([`SpeechT5HifiGan`]):
Vocoder of class
[SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan).
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: ClapTextModelWithProjection,
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
vocoder: SpeechT5HifiGan,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
vocoder=vocoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and vocoder have their state dicts saved to CPU and then are moved to a `torch.device('meta')
and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.vocoder]:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_waveforms_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device (`torch.device`):
torch device
num_waveforms_per_prompt (`int`):
number of waveforms that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the audio generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLAP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask.to(device),
)
prompt_embeds = prompt_embeds.text_embeds
# additional L_2 normalization over each hidden-state
prompt_embeds = F.normalize(prompt_embeds, dim=-1)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
(
bs_embed,
seq_len,
) = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_input.input_ids.to(device)
attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input_ids,
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds.text_embeds
# additional L_2 normalization over each hidden-state
negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
mel_spectrogram = self.vae.decode(latents).sample
return mel_spectrogram
def mel_spectrogram_to_waveform(self, mel_spectrogram):
if mel_spectrogram.dim() == 4:
mel_spectrogram = mel_spectrogram.squeeze(1)
waveform = self.vocoder(mel_spectrogram)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
waveform = waveform.cpu()
return waveform
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
audio_length_in_s,
vocoder_upsample_factor,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
if audio_length_in_s < min_audio_length_in_s:
raise ValueError(
f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
f"is {audio_length_in_s}."
)
if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
raise ValueError(
f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
f"{self.vae_scale_factor}."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
self.vocoder.config.model_in_dim // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
audio_length_in_s: Optional[float] = None,
num_inference_steps: int = 10,
guidance_scale: float = 2.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_waveforms_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
output_type: Optional[str] = "np",
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`.
instead.
audio_length_in_s (`int`, *optional*, defaults to 5.12):
The length of the generated audio sample in seconds.
num_inference_steps (`int`, *optional*, defaults to 10):
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 2.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate audios that are closely linked to the text `prompt`,
usually at the expense of lower sound quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the audio generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
The number of waveforms to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generate image. Choose between:
- `"np"`: Return Numpy `np.ndarray` objects.
- `"pt"`: Return PyTorch `torch.Tensor` objects.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated audios.
"""
# 0. Convert audio input length from seconds to spectrogram height
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
if audio_length_in_s is None:
audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
height = int(audio_length_in_s / vocoder_upsample_factor)
original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
if height % self.vae_scale_factor != 0:
height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
logger.info(
f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
f"denoising process."
)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
audio_length_in_s,
vocoder_upsample_factor,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_waveforms_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_waveforms_per_prompt,
num_channels_latents,
height,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=None,
class_labels=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
mel_spectrogram = self.decode_latents(latents)
audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
audio = audio[:, :original_waveform_length]
if output_type == "np":
audio = audio.numpy()
if not return_dict:
return (audio,)
return AudioPipelineOutput(audios=audio)

View File

@ -167,13 +167,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
@ -187,6 +188,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
"""
_supports_gradient_checkpointing = True
@ -221,7 +224,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
@ -235,6 +238,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
):
super().__init__()
@ -265,6 +269,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
@ -318,6 +328,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
@ -330,6 +346,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
# regular time embeddings
blocks_time_embed_dim = time_embed_dim * 2
else:
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@ -342,12 +369,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
@ -362,12 +389,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if mid_block_type == "UNetMidBlockFlatCrossAttn":
self.mid_block = UNetMidBlockFlatCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
@ -377,11 +404,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
@ -397,6 +424,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@ -420,12 +448,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@ -661,7 +689,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)

View File

@ -32,6 +32,21 @@ class AltDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AudioLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@ -199,6 +199,74 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (32, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_simple_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
init_dict["class_embed_type"] = "simple_projection"
init_dict["projection_class_embeddings_input_dim"] = sample_size
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_class_embeddings_concat(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
init_dict["class_embed_type"] = "simple_projection"
init_dict["projection_class_embeddings_input_dim"] = sample_size
init_dict["class_embeddings_concat"] = True
inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@ -103,6 +103,19 @@ UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
"audio_length_in_s",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
"cross_attention_kwargs",
]
)
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])

View File

View File

@ -0,0 +1,416 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
ClapTextConfig,
ClapTextModelWithProjection,
RobertaTokenizer,
SpeechT5HifiGan,
SpeechT5HifiGanConfig,
)
from diffusers import (
AudioLDMPipeline,
AutoencoderKL,
DDIMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from ...pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = AudioLDMPipeline
params = TEXT_TO_AUDIO_PARAMS
batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"num_waveforms_per_prompt",
"generator",
"latents",
"output_type",
"return_dict",
"callback",
"callback_steps",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=(32, 64),
class_embed_type="simple_projection",
projection_class_embeddings_input_dim=32,
class_embeddings_concat=True,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=1,
out_channels=1,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = ClapTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
projection_dim=32,
)
text_encoder = ClapTextModelWithProjection(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
vocoder_config = SpeechT5HifiGanConfig(
model_in_dim=8,
sampling_rate=16000,
upsample_initial_channel=16,
upsample_rates=[2, 2],
upsample_kernel_sizes=[4, 4],
resblock_kernel_sizes=[3, 7],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
normalize_before=False,
)
vocoder = SpeechT5HifiGan(vocoder_config)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
}
return inputs
def test_audioldm_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = audioldm_pipe(**inputs)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) == 256
audio_slice = audio[:10]
expected_slice = np.array(
[-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033]
)
assert np.abs(audio_slice - expected_slice).max() < 1e-2
def test_audioldm_prompt_embeds(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = audioldm_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = audioldm_pipe.tokenizer(
prompt,
padding="max_length",
max_length=audioldm_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
prompt_embeds = audioldm_pipe.text_encoder(
text_inputs,
)
prompt_embeds = prompt_embeds.text_embeds
# additional L_2 normalization over each hidden-state
prompt_embeds = F.normalize(prompt_embeds, dim=-1)
inputs["prompt_embeds"] = prompt_embeds
# forward
output = audioldm_pipe(**inputs)
audio_2 = output.audios[0]
assert np.abs(audio_1 - audio_2).max() < 1e-2
def test_audioldm_negative_prompt_embeds(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = audioldm_pipe(**inputs)
audio_1 = output.audios[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
embeds = []
for p in [prompt, negative_prompt]:
text_inputs = audioldm_pipe.tokenizer(
p,
padding="max_length",
max_length=audioldm_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
text_embeds = audioldm_pipe.text_encoder(
text_inputs,
)
text_embeds = text_embeds.text_embeds
# additional L_2 normalization over each hidden-state
text_embeds = F.normalize(text_embeds, dim=-1)
embeds.append(text_embeds)
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
# forward
output = audioldm_pipe(**inputs)
audio_2 = output.audios[0]
assert np.abs(audio_1 - audio_2).max() < 1e-2
def test_audioldm_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "egg cracking"
output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) == 256
audio_slice = audio[:10]
expected_slice = np.array(
[-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032]
)
assert np.abs(audio_slice - expected_slice).max() < 1e-2
def test_audioldm_num_waveforms_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(device)
audioldm_pipe.set_progress_bar_config(disable=None)
prompt = "A hammer hitting a wooden surface"
# test num_waveforms_per_prompt=1 (default)
audios = audioldm_pipe(prompt, num_inference_steps=2).audios
assert audios.shape == (1, 256)
# test num_waveforms_per_prompt=1 (default) for batch of prompts
batch_size = 2
audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
assert audios.shape == (batch_size, 256)
# test num_waveforms_per_prompt for single prompt
num_waveforms_per_prompt = 2
audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
assert audios.shape == (num_waveforms_per_prompt, 256)
# test num_waveforms_per_prompt for batch of prompts
batch_size = 2
audios = audioldm_pipe(
[prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
).audios
assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
def test_audioldm_audio_length_in_s(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
inputs = self.get_dummy_inputs(device)
output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) / vocoder_sampling_rate == 0.016
output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
audio = output.audios[0]
assert audio.ndim == 1
assert len(audio) / vocoder_sampling_rate == 0.032
def test_audioldm_vocoder_model_in_dim(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDMPipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
prompt = ["hey"]
output = audioldm_pipe(prompt, num_inference_steps=1)
audio_shape = output.audios.shape
assert audio_shape == (1, 256)
config = audioldm_pipe.vocoder.config
config.model_in_dim *= 2
audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
output = audioldm_pipe(prompt, num_inference_steps=1)
audio_shape = output.audios.shape
# waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
assert audio_shape == (1, 256)
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
@slow
# @require_torch_gpu
class AudioLDMPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "A hammer hitting a wooden surface",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 2.5,
}
return inputs
def test_audioldm(self):
audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
audio = audioldm_pipe(**inputs).audios[0]
assert audio.ndim == 1
assert len(audio) == 81920
audio_slice = audio[77230:77240]
expected_slice = np.array(
[-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315]
)
max_diff = np.abs(expected_slice - audio_slice).max()
assert max_diff < 1e-2
def test_audioldm_lms(self):
audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
audio = audioldm_pipe(**inputs).audios[0]
assert audio.ndim == 1
assert len(audio) == 81920
audio_slice = audio[27780:27790]
expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212])
max_diff = np.abs(expected_slice - audio_slice).max()
assert max_diff < 1e-2