Add SpectrogramConverter class and test
This class is a helper to convert between spectrogram tensors and audio.
This commit is contained in:
parent 3ab5087c7a
commit 7d0e08711c

riffusion/spectrogram_converter.py
@@ -0,0 +1,201 @@
import warnings

import numpy as np
import pydub
import torch
import torchaudio

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util
from riffusion.util import torch_util


class SpectrogramConverter:
    """
    Convert between audio segments and spectrogram tensors using torchaudio.

    In this class a "spectrogram" is defined as a (batch, frequency, time) tensor with float
    values that represent the amplitude of the frequency at that time bucket (in the frequency
    domain). Frequencies are given in the perceptual Mel scale defined by the params. A more
    specific term used in some functions is "mel amplitudes".

    The spectrogram computed by `spectrogram_from_audio` is complex valued, but it only
    returns the amplitude, because the phase is chaotic and hard to learn. The function
    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
    approximates the phase information using the Griffin-Lim algorithm.

    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
    equal to the number of channels in the input audio segment.

    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params

        self.device = torch_util.check_device(device)

        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)
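        # Griffin-Lim alternates between the inverse STFT and the forward STFT, keeping
        # the known magnitudes while re-estimating phase on each of the
        # num_griffin_lim_iters iterations; more iterations trade speed for quality.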

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            max_iter=params.max_mel_iters,
            tolerance_loss=1e-5,
            tolerance_change=1e-8,
            sgdargs=None,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute a spectrogram from an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time)
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # Get the samples as a numpy array in (batch, samples) shape
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
        return amplitudes_mel.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram.

        Args:
            spectrogram: (batch, frequency, time)
            apply_filters: Post-process with normalization and compression

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_util.audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalize the waveform to the range [-1, 1]
            normalize=True,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = audio_util.apply_filters(segment)

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to compute Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude
        amplitudes = torch.abs(spectrogram_complex)

        # Convert to mel scale
        return self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)
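A minimal sketch of how the class above might be used for a round trip, assuming a 16-bit WAV
file whose sample rate matches the params (the file path and parameter values here are
hypothetical, not from this commit):

    import pydub
    from riffusion.spectrogram_converter import SpectrogramConverter
    from riffusion.spectrogram_params import SpectrogramParams

    # Hypothetical input file, downmixed to mono so the batch dimension is 1
    segment = pydub.AudioSegment.from_file("example.wav").set_channels(1)
    params = SpectrogramParams(sample_rate=segment.frame_rate, stereo=False)
    converter = SpectrogramConverter(params=params, device="cpu")

    spectrogram = converter.spectrogram_from_audio(segment)  # (channel, frequency, time)
    reconstructed = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
    reconstructed.export("example_round_trip.wav", format="wav")

The reconstruction is lossy, since both the Mel scaling and the Griffin-Lim phase estimate are
approximate; that loss is exactly what the test further below listens for.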
riffusion/util/fft_util.py
@@ -0,0 +1,60 @@
"""
|
||||
FFT tools to analyze frequency content of audio segments. This is not code for
|
||||
dealing with spectrogram images, but for analysis of waveforms.
|
||||
"""
|
||||
import struct
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import pydub
|
||||
from scipy.fft import rfft, rfftfreq
|
||||
|
||||
|
||||
def plot_ffts(
|
||||
segments: T.Dict[str, pydub.AudioSegment],
|
||||
title: str = "FFT",
|
||||
min_frequency: float = 20,
|
||||
max_frequency: float = 20000,
|
||||
) -> None:
|
||||
"""
|
||||
Plot an FFT analysis of the given audio segments.
|
||||
"""
|
||||
ffts = {name: compute_fft(seg) for name, seg in segments.items()}
|
||||
|
||||
fig = go.Figure(
|
||||
data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
|
||||
layout={"title": title},
|
||||
)
|
||||
fig.update_xaxes(
|
||||
range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
|
||||
type="log",
|
||||
title="Frequency",
|
||||
)
|
||||
fig.update_yaxes(title="Value")
|
||||
fig.show()
|
||||
|
||||
|
||||
def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute the FFT of the given audio segment as a mono signal.
|
||||
|
||||
Returns:
|
||||
frequencies: FFT computed frequencies
|
||||
amplitudes: Amplitudes of each frequency
|
||||
"""
|
||||
# Convert to mono if needed.
|
||||
if sound.channels > 1:
|
||||
sound = sound.set_channels(1)
|
||||
|
||||
sample_rate = sound.frame_rate
|
||||
|
||||
num_samples = int(sound.frame_count())
|
||||
samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
|
||||
|
||||
fft_values = rfft(samples)
|
||||
amplitudes = np.abs(fft_values)
|
||||
|
||||
frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
|
||||
|
||||
return frequencies, amplitudes
|
|
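A short sketch of how these helpers might be used to compare two clips, for example an original
and its round-trip reconstruction (the file names here are hypothetical):

    import pydub
    from riffusion.util import fft_util

    original = pydub.AudioSegment.from_file("original.wav")          # hypothetical file
    reconstructed = pydub.AudioSegment.from_file("round_trip.wav")   # hypothetical file

    # Opens a browser tab with both FFT curves overlaid on a log-frequency axis
    fft_util.plot_ffts({"original": original, "reconstructed": reconstructed})

Note that compute_fft unpacks raw samples as signed 16-bit integers, which holds for typical
16-bit WAV files but not for other sample widths.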
@@ -0,0 +1,86 @@
import dataclasses
import typing as T

import pydub

from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util

from .test_case import TestCase


class SpectrogramConverterTest(TestCase):
    """
    Test going from audio to spectrogram and back to audio, without converting to
    an image, to check the quality loss of the reconstruction.

    This test allows comparing multiple sets of spectrogram params by listening to the output
    audio and by plotting their FFTs.
    """

    # TODO(hayk): Do an ablation of Griffin-Lim and how much loss it introduces.

    def test_round_trip(self) -> None:
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")

        # Load up the audio file
        segment = pydub.AudioSegment.from_file(audio_path)

        # Convert to mono if desired
        use_stereo = False
        if use_stereo:
            assert segment.channels == 2
        else:
            segment = segment.set_channels(1)

        # Define named sets of parameters
        param_sets: T.Dict[str, SpectrogramParams] = {}

        param_sets["default"] = SpectrogramParams(
            sample_rate=segment.frame_rate,
            stereo=use_stereo,
            step_size_ms=10,
            min_frequency=20,
            max_frequency=20000,
            num_frequencies=512,
        )

        if self.DEBUG:
            param_sets["freq_0_to_10k"] = dataclasses.replace(
                param_sets["default"],
                min_frequency=0,
                max_frequency=10000,
            )

        segments: T.Dict[str, pydub.AudioSegment] = {
            "original": segment,
        }
        for name, params in param_sets.items():
            converter = SpectrogramConverter(params=params, device=self.DEVICE)
            spectrogram = converter.spectrogram_from_audio(segment)
            segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)

        # Save segments to disk
        for name, seg in segments.items():
            audio_out = output_dir / f"{name}.wav"
            seg.export(audio_out, format="wav")
            print(f"Saved {audio_out}")

        # Check params
        self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
        self.assertEqual(segments["original"].channels, segments["default"].channels)
        self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
        self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)

        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.

        # If debugging, load up a browser tab plotting the FFTs
        if self.DEBUG:
            fft_util.plot_ffts(segments)
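Assuming the tests live in a test/ package alongside the shared TestCase (the directory layout
is a guess, not stated in this commit), the test above could be run with the standard unittest
runner:

    python -m unittest discover -s test -v   # hypothetical directory layout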