riffusion-inference/test/audio_to_image_test.py

100 lines
3.1 KiB
Python
Raw Normal View History

import typing as T
import numpy as np
from PIL import Image
from riffusion.cli import audio_to_image
from riffusion.spectrogram_params import SpectrogramParams
from .test_case import TestCase
class AudioToImageTest(TestCase):
    """
    Test riffusion.cli audio-to-image
    """

    @classmethod
    def default_params(cls) -> T.Dict[str, T.Any]:
        """
        Return the default keyword arguments forwarded to ``audio_to_image``.

        A new dict is built on every call, so callers may mutate the result
        (e.g. to flip ``stereo``) without affecting other tests.
        """
        return dict(
            step_size_ms=10,
            num_frequencies=512,
            # TODO(hayk): Change these to [20, 20000] once a model is updated
            min_frequency=0,
            max_frequency=10000,
            stereo=False,
            # NOTE(review): DEVICE is presumably provided by the project TestCase base.
            device=cls.DEVICE,
        )

    def test_audio_to_image(self) -> None:
        """
        Test audio-to-image with default params.
        """
        params = self.default_params()
        self.helper_test_with_params(params)

    def test_stereo(self) -> None:
        """
        Test audio-to-image with stereo=True.
        """
        params = self.default_params()
        params["stereo"] = True
        self.helper_test_with_params(params)

    def helper_test_with_params(self, params: T.Dict[str, T.Any]) -> None:
        """
        Run ``audio_to_image`` on a fixed test clip and validate the written PNG.

        Verifies that the image exists, is RGB, has the expected width/height,
        has the expected per-channel layout for mono vs. stereo, and that the
        spectrogram params stored in its EXIF data match the inputs.

        Args:
            params: Keyword arguments forwarded to ``audio_to_image``. Must
                contain at least the keys produced by ``default_params``.
        """
        # The clip duration is encoded in the file name; it drives the
        # expected image width below.
        duration_ms = 5678
        # NOTE(review): TEST_DATA_PATH and get_tmp_dir come from the project
        # TestCase base, which is not visible here.
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        self.assertIn(str(duration_ms), audio_path.name)

        output_dir = self.get_tmp_dir("audio_to_image_")
        stem = f"{audio_path.stem}_stereo" if params["stereo"] else audio_path.stem
        image_path = output_dir / f"{stem}.png"

        audio_to_image(audio=str(audio_path), image=str(image_path), **params)

        # Check that the image exists
        self.assertTrue(image_path.exists())

        # Open with a context manager so the underlying file handle is closed
        # even if one of the assertions below fails.
        with Image.open(image_path) as pil_image:
            # Check the image mode
            self.assertEqual(pil_image.mode, "RGB")

            # Check the image dimensions: one column per step_size_ms of audio,
            # one row per frequency bin.
            expected_image_width = round(duration_ms / params["step_size_ms"])
            self.assertEqual(pil_image.width, expected_image_width)
            self.assertEqual(pil_image.height, params["num_frequencies"])

            # Get channels as numpy arrays
            channels = [np.array(band) for band in pil_image.split()]
            self.assertEqual(len(channels), 3)

            if params["stereo"]:
                # Check that the first channel is zero
                np.testing.assert_array_equal(channels[0], 0)
            else:
                # Check that all channels are the same
                np.testing.assert_array_equal(channels[0], channels[1])
                np.testing.assert_array_equal(channels[0], channels[2])

            # Check that the image has exif data. getexif() returns a
            # (possibly empty) mapping, never None, so check it has entries.
            exif = pil_image.getexif()
            self.assertGreater(len(exif), 0)

            params_from_exif = SpectrogramParams.from_exif(exif)

        # Include every spectrogram parameter that was passed in, so the
        # round-trip check does not silently rely on SpectrogramParams defaults.
        expected_params = SpectrogramParams(
            stereo=params["stereo"],
            step_size_ms=params["step_size_ms"],
            num_frequencies=params["num_frequencies"],
            min_frequency=params["min_frequency"],
            max_frequency=params["max_frequency"],
        )
        self.assertEqual(params_from_exif, expected_params)