# Patch (reflowed): Add SpectrogramImageConverter and test
# From: Hayk Martiros, Mon, 26 Dec 2022
# This class converts between spectrogram images and audio. Uses
# SpectrogramConverter internally, which only deals with tensors.

# --- riffusion/spectrogram_image_converter.py ---

import numpy as np
from PIL import Image
import pydub

from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util


class SpectrogramImageConverter:
    """
    Convert between spectrogram images and audio segments.

    Thin wrapper over SpectrogramConverter: the tensor-level math between
    audio and spectrograms lives there, while this class adds the extra hop
    between spectrograms and pillow images (including EXIF metadata).
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params
        self.device = device
        self.converter = SpectrogramConverter(params=params, device=device)

    def spectrogram_image_from_audio(
        self,
        segment: pydub.AudioSegment,
    ) -> Image.Image:
        """
        Compute a spectrogram image from an audio segment.

        Args:
            segment: Audio segment to convert

        Returns:
            Spectrogram image (in pillow format), with the conversion params
            stored in its EXIF metadata
        """
        assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"

        # Coerce the channel count to match the stereo setting, warning
        # whenever a conversion has to happen.
        if self.p.stereo:
            if segment.channels == 1:
                print("WARNING: Mono audio but stereo=True, cloning channel")
                segment = segment.set_channels(2)
            elif segment.channels > 2:
                print("WARNING: Multi channel audio, reducing to stereo")
                segment = segment.set_channels(2)
        elif segment.channels > 1:
            print("WARNING: Stereo audio but stereo=False, setting to mono")
            segment = segment.set_channels(1)

        spectrogram = self.converter.spectrogram_from_audio(segment)

        image = image_util.image_from_spectrogram(
            spectrogram,
            power=self.p.power_for_image,
        )

        # Persist the conversion params, plus the spectrogram's max value
        # (needed to undo the scaling on reconstruction), as EXIF metadata.
        exif_fields = self.p.to_exif()
        exif_fields[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
        exif = image.getexif()
        exif.update(exif_fields.items())

        return image

    def audio_from_spectrogram_image(
        self,
        image: Image.Image,
        apply_filters: bool = True,
        max_value: float = 30e6,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram image.

        Args:
            image: Spectrogram image (in pillow format)
            apply_filters: Apply post-processing to improve the reconstructed audio
            max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
        """
        spectrogram = image_util.spectrogram_from_image(
            image,
            max_value=max_value,
            power=self.p.power_for_image,
            stereo=self.p.stereo,
        )
        return self.converter.audio_from_spectrogram(
            spectrogram,
            apply_filters=apply_filters,
        )


# --- test/spectrogram_image_converter_test.py ---

import dataclasses
import typing as T

from PIL import Image
import pydub

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util

from .test_case import TestCase


class SpectrogramImageConverterTest(TestCase):
    """
    Test going from audio to spectrogram images to audio, testing the quality loss of the
    end-to-end pipeline.

    This test allows comparing multiple sets of spectrogram params by listening to output audio
    and by plotting their FFTs.

    See spectrogram_converter_test.py for a similar test that does not convert to images.
    """
+ """ + + def test_round_trip(self) -> None: + audio_path = ( + self.TEST_DATA_PATH + / "tired_traveler" + / "clips" + / "clip_2_start_103694_ms_duration_5678_ms.wav" + ) + output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_") + + # Load up the audio file + segment = pydub.AudioSegment.from_file(audio_path) + + # Convert to mono if desired + use_stereo = False + if use_stereo: + assert segment.channels == 2 + else: + segment = segment.set_channels(1) + + # Define named sets of parameters + param_sets: T.Dict[str, SpectrogramParams] = {} + + param_sets["default"] = SpectrogramParams( + sample_rate=segment.frame_rate, + stereo=use_stereo, + step_size_ms=10, + min_frequency=20, + max_frequency=20000, + num_frequencies=512, + ) + + if self.DEBUG: + param_sets["freq_0_to_10k"] = dataclasses.replace( + param_sets["default"], + min_frequency=0, + max_frequency=10000, + ) + + segments: T.Dict[str, pydub.AudioSegment] = { + "original": segment, + } + images: T.Dict[str, Image.Image] = {} + for name, params in param_sets.items(): + converter = SpectrogramImageConverter(params=params, device=self.DEVICE) + images[name] = converter.spectrogram_image_from_audio(segment) + segments[name] = converter.audio_from_spectrogram_image( + image=images[name], + apply_filters=True, + ) + + # Save images to disk + for name, image in images.items(): + image_out = output_dir / f"{name}.png" + image.save(image_out, exif=image.getexif(), format="PNG") + print(f"Saved {image_out}") + + # Save segments to disk + for name, segment in segments.items(): + audio_out = output_dir / f"{name}.wav" + segment.export(audio_out, format="wav") + print(f"Saved {audio_out}") + + # Check params + self.assertEqual(segments["default"].channels, 2 if use_stereo else 1) + self.assertEqual(segments["original"].channels, segments["default"].channels) + self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate) + self.assertEqual(segments["original"].sample_width, 
segments["default"].sample_width) + + # TODO(hayk): Test something more rigorous about the quality of the reconstruction. + + # If debugging, load up a browser tab plotting the FFTs + if self.DEBUG: + fft_util.plot_ffts(segments)