Add SpectrogramConverter class and test

This class is a helper to convert between spectrogram tensors and
audio.

Topic: clean_rewrite
Hayk Martiros 2022-12-26 17:20:56 -08:00
parent 3ab5087c7a
commit 7d0e08711c
3 changed files with 347 additions and 0 deletions


@@ -0,0 +1,201 @@
import warnings

import numpy as np
import pydub
import torch
import torchaudio

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util
from riffusion.util import torch_util


class SpectrogramConverter:
    """
    Convert between audio segments and spectrogram tensors using torchaudio.

    In this class a "spectrogram" is defined as a (batch, frequency, time) tensor with float values
    that represent the amplitude of the frequency at that time bucket (in the frequency domain).
    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
    used in some functions is "mel amplitudes".

    The spectrogram computed by `spectrogram_from_audio` is complex valued, but only the
    amplitude is returned, because the phase is chaotic and hard to learn. The function
    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`: it
    estimates the missing phase information using the Griffin-Lim algorithm.

    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
    equal to the number of channels in the input audio segment.

    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params

        self.device = torch_util.check_device(device)

        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            max_iter=params.max_mel_iters,
            tolerance_loss=1e-5,
            tolerance_change=1e-8,
            sgdargs=None,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute a spectrogram from an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time)
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # Get the samples as a numpy array in (batch, samples) shape
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
        return amplitudes_mel.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram.

        Args:
            spectrogram: (batch, frequency, time)
            apply_filters: Post-process with normalization and compression

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_util.audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalize the waveform to the range [-1, 1]
            normalize=True,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = audio_util.apply_filters(segment)

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to compute Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude
        amplitudes = torch.abs(spectrogram_complex)

        # Convert to mel scale
        return self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)
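
A minimal round-trip usage sketch (not part of this commit). The clip path, device choice, and the particular SpectrogramParams arguments are placeholder assumptions:

# Hypothetical usage sketch: audio -> mel amplitudes -> audio
import pydub

from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams

segment = pydub.AudioSegment.from_file("clip.wav").set_channels(1)
params = SpectrogramParams(sample_rate=segment.frame_rate, stereo=False)
converter = SpectrogramConverter(params=params, device="cpu")

# (channels, frequency, time) float array of mel amplitudes
spectrogram = converter.spectrogram_from_audio(segment)

# Approximate inverse; the phase is re-estimated with Griffin-Lim, so this is lossy
reconstructed = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
reconstructed.export("clip_round_trip.wav", format="wav")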


@@ -0,0 +1,60 @@
"""
FFT tools to analyze frequency content of audio segments. This is not code for
dealing with spectrogram images, but for analysis of waveforms.
"""
import struct
import typing as T
import numpy as np
import plotly.graph_objects as go
import pydub
from scipy.fft import rfft, rfftfreq
def plot_ffts(
segments: T.Dict[str, pydub.AudioSegment],
title: str = "FFT",
min_frequency: float = 20,
max_frequency: float = 20000,
) -> None:
"""
Plot an FFT analysis of the given audio segments.
"""
ffts = {name: compute_fft(seg) for name, seg in segments.items()}
fig = go.Figure(
data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
layout={"title": title},
)
fig.update_xaxes(
range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
type="log",
title="Frequency",
)
fig.update_yaxes(title="Value")
fig.show()
def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
"""
Compute the FFT of the given audio segment as a mono signal.
Returns:
frequencies: FFT computed frequencies
amplitudes: Amplitudes of each frequency
"""
# Convert to mono if needed.
if sound.channels > 1:
sound = sound.set_channels(1)
sample_rate = sound.frame_rate
num_samples = int(sound.frame_count())
samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
fft_values = rfft(samples)
amplitudes = np.abs(fft_values)
frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
return frequencies, amplitudes
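
A short usage sketch for these helpers (not part of this commit). The file paths are placeholder assumptions:

# Hypothetical sketch: compare the frequency content of two clips
import pydub

from riffusion.util import fft_util

original = pydub.AudioSegment.from_file("original.wav")
processed = pydub.AudioSegment.from_file("processed.wav")

# compute_fft mixes down to mono and returns (frequencies, amplitudes) arrays
frequencies, amplitudes = fft_util.compute_fft(original)
print(f"Peak bin: {frequencies[amplitudes.argmax()]:.1f} Hz")

# plot_ffts opens a plotly figure with one log-frequency trace per named segment
fft_util.plot_ffts({"original": original, "processed": processed}, title="FFT comparison")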


@@ -0,0 +1,86 @@
import dataclasses
import typing as T

import pydub

from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util

from .test_case import TestCase


class SpectrogramConverterTest(TestCase):
    """
    Test going from audio to spectrogram to audio, without converting to
    an image, to check quality loss of the reconstruction.

    This test allows comparing multiple sets of spectrogram params by listening to output audio
    and by plotting their FFTs.
    """

    # TODO(hayk): Do an ablation of Griffin-Lim and how much loss that introduces.

    def test_round_trip(self) -> None:
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")

        # Load up the audio file
        segment = pydub.AudioSegment.from_file(audio_path)

        # Convert to mono if desired
        use_stereo = False
        if use_stereo:
            assert segment.channels == 2
        else:
            segment = segment.set_channels(1)

        # Define named sets of parameters
        param_sets: T.Dict[str, SpectrogramParams] = {}

        param_sets["default"] = SpectrogramParams(
            sample_rate=segment.frame_rate,
            stereo=use_stereo,
            step_size_ms=10,
            min_frequency=20,
            max_frequency=20000,
            num_frequencies=512,
        )

        if self.DEBUG:
            param_sets["freq_0_to_10k"] = dataclasses.replace(
                param_sets["default"],
                min_frequency=0,
                max_frequency=10000,
            )

        segments: T.Dict[str, pydub.AudioSegment] = {
            "original": segment,
        }
        for name, params in param_sets.items():
            converter = SpectrogramConverter(params=params, device=self.DEVICE)
            spectrogram = converter.spectrogram_from_audio(segment)
            segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)

        # Save segments to disk
        for name, segment in segments.items():
            audio_out = output_dir / f"{name}.wav"
            segment.export(audio_out, format="wav")
            print(f"Saved {audio_out}")

        # Check params
        self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
        self.assertEqual(segments["original"].channels, segments["default"].channels)
        self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
        self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)

        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.

        # If debugging, load up a browser tab plotting the FFTs
        if self.DEBUG:
            fft_util.plot_ffts(segments)
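
A possible follow-up for the TODO above (not part of this commit): a minimal sketch of a quantitative check that could be appended to the end of test_round_trip, reusing its local variables. It crops both mel spectrograms to the shorter time axis before comparing, and the error tolerance is a placeholder assumption:

        # Hypothetical quality check (sketch): compare mel amplitudes of the
        # original and reconstructed audio and bound the mean absolute error.
        import numpy as np  # would normally live at the top of the file

        converter = SpectrogramConverter(params=param_sets["default"], device=self.DEVICE)
        original_mels = converter.spectrogram_from_audio(segments["original"])
        round_trip_mels = converter.spectrogram_from_audio(segments["default"])

        # The reconstructed clip can differ slightly in length, so crop to the shorter spectrogram
        num_frames = min(original_mels.shape[-1], round_trip_mels.shape[-1])
        mean_abs_error = np.mean(
            np.abs(original_mels[..., :num_frames] - round_trip_mels[..., :num_frames])
        )
        self.assertLess(mean_abs_error, 0.1)  # placeholder tolerance, not validated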