Add SpectrogramImageConverter and test
This class converts between spectrogram images and audio. It uses SpectrogramConverter internally, which deals only with tensors. Topic: clean_rewrite
This commit is contained in:
parent
7d0e08711c
commit
4c78e1a228
|
@ -0,0 +1,91 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
|
||||
class SpectrogramImageConverter:
    """
    Bidirectional converter between audio segments and spectrogram images.

    Wraps a SpectrogramConverter (which operates purely on tensors) and adds the
    image encoding/decoding step, including stashing the conversion parameters in
    the image's EXIF metadata so the audio can be reconstructed later.
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params
        self.device = device
        # All tensor-level audio <-> spectrogram work is delegated to this converter.
        self.converter = SpectrogramConverter(params=params, device=device)

    def spectrogram_image_from_audio(
        self,
        segment: pydub.AudioSegment,
    ) -> Image.Image:
        """
        Render an audio segment as a spectrogram image.

        Args:
            segment: Audio segment to convert

        Returns:
            Spectrogram image (in pillow format)
        """
        assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"

        # Force the channel count to agree with the stereo setting before converting.
        segment = self._coerce_channels(segment)

        spectrogram = self.converter.spectrogram_from_audio(segment)
        image = image_util.image_from_spectrogram(spectrogram, power=self.p.power_for_image)

        # Record the conversion params in the image's EXIF metadata, along with the
        # spectrogram's max value (needed to undo normalization on the way back).
        exif_data = self.p.to_exif()
        exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
        exif = image.getexif()
        exif.update(exif_data.items())

        return image

    def _coerce_channels(self, segment: pydub.AudioSegment) -> pydub.AudioSegment:
        """Return a segment whose channel count matches self.p.stereo, warning on mismatch."""
        channels = segment.channels
        if self.p.stereo:
            if channels == 1:
                print("WARNING: Mono audio but stereo=True, cloning channel")
                return segment.set_channels(2)
            if channels > 2:
                print("WARNING: Multi channel audio, reducing to stereo")
                return segment.set_channels(2)
        elif channels > 1:
            print("WARNING: Stereo audio but stereo=False, setting to mono")
            return segment.set_channels(1)
        return segment

    def audio_from_spectrogram_image(
        self,
        image: Image.Image,
        apply_filters: bool = True,
        max_value: float = 30e6,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram image.

        Args:
            image: Spectrogram image (in pillow format)
            apply_filters: Apply post-processing to improve the reconstructed audio
            max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
        """
        spectrogram = image_util.spectrogram_from_image(
            image,
            max_value=max_value,
            power=self.p.power_for_image,
            stereo=self.p.stereo,
        )
        return self.converter.audio_from_spectrogram(spectrogram, apply_filters=apply_filters)
|
|
@ -0,0 +1,97 @@
|
|||
import dataclasses
|
||||
import typing as T
|
||||
|
||||
from PIL import Image
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import fft_util
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SpectrogramImageConverterTest(TestCase):
    """
    Round-trip test: audio -> spectrogram image -> audio, measuring the quality
    loss of the end-to-end pipeline.

    Multiple sets of spectrogram params can be compared by listening to the
    reconstructed audio and by plotting their FFTs (in debug mode).

    See spectrogram_converter_test.py for a similar test that skips the image step.
    """

    def test_round_trip(self) -> None:
        audio_path = self.TEST_DATA_PATH.joinpath(
            "tired_traveler",
            "clips",
            "clip_2_start_103694_ms_duration_5678_ms.wav",
        )
        output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_")

        # Load the audio clip from disk
        segment = pydub.AudioSegment.from_file(audio_path)

        # Collapse to mono unless explicitly testing stereo
        use_stereo = False
        if use_stereo:
            assert segment.channels == 2
        else:
            segment = segment.set_channels(1)

        # Named parameter sets to compare against each other
        param_sets: T.Dict[str, SpectrogramParams] = {
            "default": SpectrogramParams(
                sample_rate=segment.frame_rate,
                stereo=use_stereo,
                step_size_ms=10,
                min_frequency=20,
                max_frequency=20000,
                num_frequencies=512,
            ),
        }

        if self.DEBUG:
            param_sets["freq_0_to_10k"] = dataclasses.replace(
                param_sets["default"],
                min_frequency=0,
                max_frequency=10000,
            )

        # Round-trip each parameter set through the image pipeline
        segments: T.Dict[str, pydub.AudioSegment] = {"original": segment}
        images: T.Dict[str, Image.Image] = {}
        for name, params in param_sets.items():
            converter = SpectrogramImageConverter(params=params, device=self.DEVICE)
            images[name] = converter.spectrogram_image_from_audio(segment)
            segments[name] = converter.audio_from_spectrogram_image(
                image=images[name],
                apply_filters=True,
            )

        # Write the spectrogram images to disk for manual inspection
        for name, image in images.items():
            image_out = output_dir / f"{name}.png"
            image.save(image_out, exif=image.getexif(), format="PNG")
            print(f"Saved {image_out}")

        # Write the audio segments to disk for manual listening
        for name, segment in segments.items():
            audio_out = output_dir / f"{name}.wav"
            segment.export(audio_out, format="wav")
            print(f"Saved {audio_out}")

        # Sanity-check basic properties of the reconstructed audio
        reconstructed = segments["default"]
        original = segments["original"]
        self.assertEqual(reconstructed.channels, 2 if use_stereo else 1)
        self.assertEqual(original.channels, reconstructed.channels)
        self.assertEqual(original.frame_rate, reconstructed.frame_rate)
        self.assertEqual(original.sample_width, reconstructed.sample_width)

        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.

        # If debugging, load up a browser tab plotting the FFTs
        if self.DEBUG:
            fft_util.plot_ffts(segments)
|
Loading…
Reference in New Issue