100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
import typing as T
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from riffusion.cli import audio_to_image
|
|
from riffusion.spectrogram_params import SpectrogramParams
|
|
|
|
from .test_case import TestCase
|
|
|
|
|
|
class AudioToImageTest(TestCase):
    """
    Test riffusion.cli audio-to-image
    """

    @classmethod
    def default_params(cls) -> T.Dict[str, T.Any]:
        """
        Return a fresh dict of keyword arguments for audio_to_image.

        A new dict is built on every call, so tests may mutate the result freely.
        """
        return {
            "step_size_ms": 10,
            "num_frequencies": 512,
            # TODO(hayk): Change these to [20, 20000] once a model is updated
            "min_frequency": 0,
            "max_frequency": 10000,
            "stereo": False,
            "device": cls.DEVICE,
        }

    def test_audio_to_image(self) -> None:
        """
        Test audio-to-image with default params.
        """
        params = self.default_params()
        self.helper_test_with_params(params)

    def test_stereo(self) -> None:
        """
        Test audio-to-image with stereo=True.
        """
        params = self.default_params()
        params["stereo"] = True
        self.helper_test_with_params(params)

    def helper_test_with_params(self, params: T.Dict[str, T.Any]) -> None:
        """
        Run audio_to_image on the test clip with the given params and validate the output.

        Checks that the image file exists, is RGB, has the expected dimensions, has
        the expected channel layout for mono/stereo, and carries EXIF data that
        round-trips to the expected SpectrogramParams.

        Args:
            params: Keyword arguments forwarded to audio_to_image. Must contain the
                keys "stereo", "step_size_ms", "num_frequencies" and "max_frequency".
        """
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir("audio_to_image_")

        stem = f"{audio_path.stem}_stereo" if params["stereo"] else audio_path.stem
        image_path = output_dir / f"{stem}.png"

        audio_to_image(audio=str(audio_path), image=str(image_path), **params)

        # Check that the image exists
        self.assertTrue(image_path.exists())

        # Open as a context manager so the underlying file handle is always released,
        # even when one of the assertions below fails.
        with Image.open(image_path) as pil_image:
            # Check the image mode
            self.assertEqual(pil_image.mode, "RGB")

            # Check the image dimensions. The duration is hard-coded, so first verify
            # that it still agrees with the clip's file name.
            duration_ms = 5678
            self.assertIn(str(duration_ms), audio_path.name)
            expected_image_width = round(duration_ms / params["step_size_ms"])
            self.assertEqual(pil_image.width, expected_image_width)
            self.assertEqual(pil_image.height, params["num_frequencies"])

            # Get channels as numpy arrays
            channels = [np.array(pil_image.getchannel(band)) for band in pil_image.getbands()]
            self.assertEqual(len(channels), 3)

            if params["stereo"]:
                # Check that the first channel is zero
                self.assertTrue(np.all(channels[0] == 0))
            else:
                # Check that all channels are the same
                self.assertTrue(np.array_equal(channels[0], channels[1]))
                self.assertTrue(np.array_equal(channels[0], channels[2]))

            # Check that the image has exif data
            exif = pil_image.getexif()
            self.assertIsNotNone(exif)
            params_from_exif = SpectrogramParams.from_exif(exif)

        # NOTE(review): min_frequency is passed to audio_to_image but is not set here;
        # presumably the value 0 matches the SpectrogramParams default -- confirm.
        expected_params = SpectrogramParams(
            stereo=params["stereo"],
            step_size_ms=params["step_size_ms"],
            num_frequencies=params["num_frequencies"],
            max_frequency=params["max_frequency"],
        )
        self.assertEqual(params_from_exif, expected_params)