import dataclasses
import typing as T

import pydub

from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util

from .test_case import TestCase


class SpectrogramConverterTest(TestCase):
    """
    Test going from audio to spectrogram to audio, without converting to
    an image, to check quality loss of the reconstruction.

    This test allows comparing multiple sets of spectrogram params by listening to output audio
    and by plotting their FFTs.
    """

    # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.

    def test_round_trip(self) -> None:
        """
        Round-trip a test clip through SpectrogramConverter for each named param set.

        Every reconstruction (plus the original clip) is exported as a WAV file into a
        temporary directory for listening. The test then checks that the "default"
        reconstruction keeps the original's channel count, frame rate and sample width.
        """
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")

        # Load up the audio file
        segment = pydub.AudioSegment.from_file(audio_path)

        # Convert to mono if desired
        use_stereo = False
        if use_stereo:
            assert segment.channels == 2
        else:
            segment = segment.set_channels(1)

        # Define named sets of parameters
        param_sets: T.Dict[str, SpectrogramParams] = {}

        param_sets["default"] = SpectrogramParams(
            sample_rate=segment.frame_rate,
            stereo=use_stereo,
            step_size_ms=10,
            min_frequency=20,
            max_frequency=20000,
            num_frequencies=512,
        )

        if self.DEBUG:
            param_sets["freq_0_to_10k"] = dataclasses.replace(
                param_sets["default"],
                min_frequency=0,
                max_frequency=10000,
            )

        segments: T.Dict[str, pydub.AudioSegment] = {
            "original": segment,
        }
        for name, params in param_sets.items():
            converter = SpectrogramConverter(params=params, device=self.DEVICE)
            spectrogram = converter.spectrogram_from_audio(segment)
            segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)

        # Save segments to disk. Use a distinct loop name so the input clip bound to
        # `segment` is not silently rebound to whichever clip is exported last.
        for name, audio in segments.items():
            audio_out = output_dir / f"{name}.wav"
            # pydub's export() returns the open output file object; close it so the
            # handle is not leaked.
            audio.export(audio_out, format="wav").close()
            print(f"Saved {audio_out}")

        # Check params
        self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
        self.assertEqual(segments["original"].channels, segments["default"].channels)
        self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
        self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)

        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.

        # If debugging, load up a browser tab plotting the FFTs
        if self.DEBUG:
            fft_util.plot_ffts(segments)