# File: riffusion-inference/riffusion/spectrogram_converter.py
#
# (Source-viewer metadata: 205 lines, 7.0 KiB, Python)

import warnings
import numpy as np
import pydub
import torch
import torchaudio
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util, torch_util
class SpectrogramConverter:
    """
    Convert between audio segments and spectrogram tensors using torchaudio.

    In this class a "spectrogram" is defined as a (batch, frequency, time) tensor with float
    values that represent the amplitude of the frequency at that time bucket (in the frequency
    domain). Frequencies are given in the perceptual Mel scale defined by the params. A more
    specific term used in some functions is "mel amplitudes".

    The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
    returns the amplitude, because the phase is chaotic and hard to learn. The function
    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
    approximates the phase information using the Griffin-Lim algorithm.

    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
    equal to the number of channels in the input audio segment.

    Both the Griffin Lim algorithm and the Mel scaling process are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        """
        Build the torchaudio transforms used by the forward and inverse conversions.

        Args:
            params: Spectrogram parameters (FFT/window sizes, mel range and count, sample rate,
                iteration counts for Griffin-Lim and inverse mel scaling).
            device: Requested torch device. It is passed through `torch_util.check_device`;
                if the resulting device is MPS, CPU is used instead.
        """
        self.p = params

        # NOTE(review): check_device presumably validates the request and may return a
        # fallback device -- confirm in riffusion.util.torch_util.
        self.device = torch_util.check_device(device)

        # Inspect the *resolved* device, not the raw argument: if check_device already moved
        # away from MPS there is nothing to warn about, and if it resolved to MPS the audio
        # transforms below must still be placed on CPU.
        if str(self.device).lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        # power=None keeps the complex STFT; the magnitude is taken explicitly later.
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        # power=1.0 because the inverse receives magnitudes, not power spectra.
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        # n_stft = n_fft // 2 + 1 matches the one-sided STFT produced above.
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            max_iter=params.max_mel_iters,
            tolerance_loss=1e-5,
            tolerance_change=1e-8,
            sgdargs=None,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute a spectrogram from an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time) mel amplitudes, one batch entry per channel
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # Get the samples as a numpy array in (batch, samples) shape
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
        return amplitudes_mel.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram.

        Args:
            spectrogram: (batch, frequency, time) mel amplitudes. Any real dtype or memory
                layout is accepted; it is converted to a contiguous float32 array first.
            apply_filters: Post-process the segment with `audio_util.apply_filters`
                (called here with `compression=False`)

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        # Coerce to a contiguous float32 array, the same dtype the forward path produces.
        # Without this, float64/integer or non-contiguous input can fail inside the float32
        # mel / Griffin-Lim transforms. For contiguous float32 input this does not copy.
        spectrogram = np.ascontiguousarray(spectrogram, dtype=np.float32)

        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_util.audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalize the waveform to the range [-1, 1]
            normalize=True,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = audio_util.apply_filters(
                segment,
                compression=False,
            )

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to compute Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude (the phase is discarded)
        amplitudes = torch.abs(spectrogram_complex)

        # Convert to mel scale
        return self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)