import warnings

import numpy as np
import pydub
import torch
import torchaudio

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util, torch_util


class SpectrogramConverter:
    """
    Convert between audio segments and spectrogram tensors using torchaudio.

    In this class a "spectrogram" is defined as a (batch, frequency, time) tensor with float
    values that represent the amplitude of the frequency at that time bucket (in the frequency
    domain). Frequencies are given in the perceptual Mel scale defined by the params. A more
    specific term used in some functions is "mel amplitudes".

    The STFT computed inside `spectrogram_from_audio` is complex valued, but only its amplitude
    is returned, because the phase is chaotic and hard to learn. The function
    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`: it
    approximates the missing phase information using the Griffin-Lim algorithm.

    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
    equal to the number of channels in the input audio segment.

    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params

        self.device = torch_util.check_device(device)

        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)
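        # Note: power=None makes this transform return the complex STFT; the
        # magnitude is taken later in mel_amplitudes_from_waveform(). With
        # center=True, a waveform of N samples yields 1 + N // hop_length frames.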

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)
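        # Note: n_iter trades reconstruction quality for speed. Griffin-Lim only
        # estimates the discarded phase, so this inverse is inherently lossy.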

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)
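        # Note: n_stft matches the one-sided STFT above, which produces
        # n_fft // 2 + 1 frequency bins.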

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            max_iter=params.max_mel_iters,
            tolerance_loss=1e-5,
            tolerance_change=1e-8,
            sgdargs=None,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)
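        # Note: there is no exact inverse of the Mel filter bank, so this transform
        # estimates the linear spectrogram iteratively. This is the other lossy step
        # besides Griffin-Lim.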

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute a spectrogram from an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time)
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # Get the samples as a numpy array in (batch, samples) shape
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)
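        # Note: pydub yields integer PCM samples, used here without rescaling;
        # audio_from_waveform(..., normalize=True) renormalizes on the way back out.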

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
        return amplitudes_mel.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram.

        Args:
            spectrogram: (batch, frequency, time)
            apply_filters: Post-process with normalization and compression

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_util.audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalize the waveform to the range [-1, 1]
            normalize=True,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = audio_util.apply_filters(
                segment,
                compression=False,
            )

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to compute Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude
        amplitudes = torch.abs(spectrogram_complex)

        # Convert to mel scale
        return self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)
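

# A minimal round-trip usage sketch (illustrative, not part of the original
# module). Assumes default-constructible SpectrogramParams and a hypothetical
# input file "clip.wav"; pydub needs ffmpeg available to decode most formats.
if __name__ == "__main__":
    params = SpectrogramParams()
    converter = SpectrogramConverter(params=params, device="cuda")

    # Resample so the segment matches the sample rate expected by the params
    segment = pydub.AudioSegment.from_file("clip.wav")
    segment = segment.set_frame_rate(params.sample_rate)

    # Audio -> mel amplitudes -> audio (lossy: Mel scaling and Griffin-Lim phase)
    spectrogram = converter.spectrogram_from_audio(segment)
    print("spectrogram shape (channel, frequency, time):", spectrogram.shape)

    reconstructed = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
    reconstructed.export("clip_reconstructed.wav", format="wav")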