87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
import dataclasses
|
|
import typing as T
|
|
|
|
import pydub
|
|
|
|
from riffusion.spectrogram_converter import SpectrogramConverter
|
|
from riffusion.spectrogram_params import SpectrogramParams
|
|
from riffusion.util import fft_util
|
|
|
|
from .test_case import TestCase
|
|
|
|
|
|
class SpectrogramConverterTest(TestCase):
    """
    Test going from audio to spectrogram to audio, without converting to
    an image, to check quality loss of the reconstruction.

    This test allows comparing multiple sets of spectrogram params by listening to output audio
    and by plotting their FFTs. Extra parameter sets and the FFT plot are only enabled when
    ``self.DEBUG`` is set.
    """

    # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.

    def test_round_trip(self) -> None:
        """
        Round-trip a fixed test clip through each parameter set and check format invariants.

        For every entry in ``param_sets``, the clip is converted to a spectrogram with
        ``SpectrogramConverter.spectrogram_from_audio`` and back to audio with
        ``audio_from_spectrogram(..., apply_filters=True)``. The original clip and every
        reconstruction are exported as ``<name>.wav`` into the directory returned by
        ``self.get_tmp_dir``, so they can be listened to.

        The test then asserts that the "default" reconstruction has the expected channel
        count, and that it matches the original's channel count, frame rate and sample width.
        """
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")

        # Load up the audio file
        segment = pydub.AudioSegment.from_file(audio_path)

        # Convert to mono if desired
        use_stereo = False
        if use_stereo:
            # The stereo path needs a two-channel source clip. Use the TestCase
            # assertion (not a bare ``assert``) so the check survives ``python -O``
            # and reports like the other checks in this test.
            self.assertEqual(segment.channels, 2)
        else:
            segment = segment.set_channels(1)

        # Define named sets of parameters
        param_sets: T.Dict[str, SpectrogramParams] = {}

        param_sets["default"] = SpectrogramParams(
            sample_rate=segment.frame_rate,
            stereo=use_stereo,
            step_size_ms=10,
            min_frequency=20,
            max_frequency=20000,
            num_frequencies=512,
        )

        if self.DEBUG:
            # Same as "default" but with a narrower frequency band, for comparison.
            param_sets["freq_0_to_10k"] = dataclasses.replace(
                param_sets["default"],
                min_frequency=0,
                max_frequency=10000,
            )

        # The original clip is kept alongside the reconstructions for export and comparison.
        segments: T.Dict[str, pydub.AudioSegment] = {
            "original": segment,
        }
        for name, params in param_sets.items():
            converter = SpectrogramConverter(params=params, device=self.DEVICE)
            spectrogram = converter.spectrogram_from_audio(segment)
            segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)

        # Save segments to disk. The loop variable has its own name so that
        # ``segment`` keeps referring to the source clip after this loop.
        for name, audio_segment in segments.items():
            audio_out = output_dir / f"{name}.wav"
            audio_segment.export(audio_out, format="wav")
            print(f"Saved {audio_out}")

        # Check params
        self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
        self.assertEqual(segments["original"].channels, segments["default"].channels)
        self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
        self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)

        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.

        # If debugging, load up a browser tab plotting the FFTs
        if self.DEBUG:
            fft_util.plot_ffts(segments)