riffusion-inference/test/spectrogram_converter_test.py

87 lines
2.9 KiB
Python
Raw Normal View History

import dataclasses
import typing as T
import pydub
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util
from .test_case import TestCase
class SpectrogramConverterTest(TestCase):
    """
    Test going from audio to spectrogram to audio, without converting to
    an image, to check quality loss of the reconstruction.

    This test allows comparing multiple sets of spectrogram params by listening to output audio
    and by plotting their FFTs.
    """

    # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.

    def test_round_trip(self) -> None:
        """
        Round-trip a test clip through SpectrogramConverter and check the output format.

        The clip is converted to a spectrogram and back to audio once per named
        parameter set. Every resulting segment (plus the original) is exported as a
        WAV file into a temporary directory, and the "default" reconstruction is
        checked to have the same channel count, frame rate, and sample width as the
        original segment.

        NOTE(review): TEST_DATA_PATH, DEBUG, DEVICE and get_tmp_dir are provided by
        the TestCase base class, which is not visible here -- confirm their semantics
        there.
        """
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")

        # Load up the audio file
        segment = pydub.AudioSegment.from_file(audio_path)

        # Convert to mono if desired
        use_stereo = False
        if use_stereo:
            # Use the unittest assertion rather than a bare assert so the check is
            # not stripped under `python -O` and reports like the checks below.
            self.assertEqual(segment.channels, 2)
        else:
            segment = segment.set_channels(1)

        # Define named sets of parameters
        param_sets: T.Dict[str, SpectrogramParams] = {}

        param_sets["default"] = SpectrogramParams(
            sample_rate=segment.frame_rate,
            stereo=use_stereo,
            step_size_ms=10,
            min_frequency=20,
            max_frequency=20000,
            num_frequencies=512,
        )

        # Only in debug mode: add a second parameter set with a different
        # frequency range so its output can be compared against the default.
        if self.DEBUG:
            param_sets["freq_0_to_10k"] = dataclasses.replace(
                param_sets["default"],
                min_frequency=0,
                max_frequency=10000,
            )

        # The untouched input is kept under "original" as the comparison baseline.
        segments: T.Dict[str, pydub.AudioSegment] = {
            "original": segment,
        }

        # Audio -> spectrogram -> audio, once per parameter set
        for name, params in param_sets.items():
            converter = SpectrogramConverter(params=params, device=self.DEVICE)
            spectrogram = converter.spectrogram_from_audio(segment)
            segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)

        # Save segments to disk. The loop variable is deliberately not named
        # `segment`, so the source clip loaded above is not rebound.
        for name, out_segment in segments.items():
            audio_out = output_dir / f"{name}.wav"
            out_segment.export(audio_out, format="wav")
            print(f"Saved {audio_out}")

        # Check params
        self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
        self.assertEqual(segments["original"].channels, segments["default"].channels)
        self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
        self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)

        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.

        # If debugging, load up a browser tab plotting the FFTs
        if self.DEBUG:
            fft_util.plot_ffts(segments)