Command line interface for common operations, plus tests

Add riffusion.cli tool for common operations. Add a test for
each one.

Topic: clean_rewrite
This commit is contained in:
Hayk Martiros 2022-12-26 17:28:09 -08:00
parent 52bec9575b
commit dc4e2d8d64
5 changed files with 431 additions and 0 deletions

141
riffusion/cli.py Normal file
View File

@ -0,0 +1,141 @@
"""
Command line tools for riffusion.
"""
from pathlib import Path
import argh
import numpy as np
from PIL import Image
import pydub
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
def audio_to_image(
    *,
    audio: str,
    image: str,
    step_size_ms: int = 10,
    num_frequencies: int = 512,
    min_frequency: int = 0,
    max_frequency: int = 10000,
    window_duration_ms: int = 100,
    padded_duration_ms: int = 400,
    power_for_image: float = 0.25,
    stereo: bool = False,
    device: str = "cuda",
):
    """
    Render an audio file as a spectrogram image and save it as a PNG.

    The file at `audio` is loaded with pydub and handed to a
    SpectrogramImageConverter configured from the keyword arguments. The
    sample rate is not a parameter: it is taken from the loaded audio. The
    resulting image is saved to `image` along with the exif data it carries.
    """
    waveform = pydub.AudioSegment.from_file(audio)

    # Everything except the sample rate is controlled by the caller.
    spectrogram_params = SpectrogramParams(
        sample_rate=waveform.frame_rate,
        stereo=stereo,
        step_size_ms=step_size_ms,
        window_duration_ms=window_duration_ms,
        padded_duration_ms=padded_duration_ms,
        num_frequencies=num_frequencies,
        min_frequency=min_frequency,
        max_frequency=max_frequency,
        power_for_image=power_for_image,
    )
    image_converter = SpectrogramImageConverter(params=spectrogram_params, device=device)
    spectrogram_image = image_converter.spectrogram_image_from_audio(waveform)

    # Hand the image's own exif to save() explicitly.
    spectrogram_image.save(image, format="PNG", exif=spectrogram_image.getexif())
    print(f"Wrote {image}")
def print_exif(*, image: str) -> None:
    """
    Print the params of a spectrogram image as saved in the exif data.

    Each entry is printed on its own line as `NAME = VALUE`, with the name
    left-aligned in 20 columns and the value right-aligned in 15 columns.

    Args:
        image: Path of the image file to inspect.
    """
    # Open the image as a context manager so the file handle is released
    # deterministically instead of whenever the object is garbage collected.
    with Image.open(image) as pil_image:
        exif_data = image_util.exif_from_image(pil_image)

        for name, value in exif_data.items():
            print(f"{name:<20} = {value:>15}")
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
    """
    Reconstruct an audio clip from a spectrogram image.

    Spectrogram parameters are read from the image's exif data. If
    `SpectrogramParams.from_exif` raises a KeyError, a warning is printed and
    default parameters are used instead.

    Args:
        image: Path of the spectrogram image to read.
        audio: Path of the audio file to write. The file extension (without
            the dot) is passed to the export call as the format.
        device: Device string forwarded to SpectrogramImageConverter.

    Raises:
        ValueError: If `audio` has no file extension to take the format from.
    """
    # Validate the output path first so that a bad path fails immediately,
    # before any conversion work is done.
    extension = Path(audio).suffix[1:]
    if not extension:
        raise ValueError(
            f"Cannot determine the audio format: '{audio}' has no file extension"
        )

    # The context manager closes the image file once the audio is reconstructed.
    with Image.open(image) as pil_image:
        # Get parameters from image exif
        img_exif = pil_image.getexif()
        assert img_exif is not None

        try:
            params = SpectrogramParams.from_exif(exif=img_exif)
        except KeyError:
            print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
            params = SpectrogramParams()

        converter = SpectrogramImageConverter(params=params, device=device)
        segment = converter.audio_from_spectrogram_image(pil_image)

    segment.export(audio, format=extension)
    print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
def sample_clips(
    *,
    audio: str,
    output_dir: str,
    num_clips: int = 1,
    duration_ms: int = 5000,
    mono: bool = False,
    extension: str = "wav",
    seed: int = -1,
):
    """
    Slice an audio file into clips of the given duration.

    Clip start times are drawn with `np.random.randint`, and each clip is
    written to `output_dir` as
    `clip_{i}_start_{start}_ms_duration_{duration_ms}_ms.{extension}`.

    Args:
        audio: Path of the audio file to sample from.
        output_dir: Directory to write clips to; created if it does not exist.
        num_clips: Number of clips to write.
        duration_ms: Length of each clip in milliseconds.
        mono: If True, convert the audio to a single channel before slicing.
        extension: Extension of the output files, also used as the export format.
        seed: If non-negative, seeds NumPy's global random state so the sampled
            start times are reproducible.

    Raises:
        ValueError: If clips are requested but the audio is not longer than
            `duration_ms`.
    """
    if seed >= 0:
        np.random.seed(seed)

    segment = pydub.AudioSegment.from_file(audio)
    if mono:
        segment = segment.set_channels(1)

    # exist_ok avoids the check-then-create race of testing exists() first.
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    # TODO(hayk): Might be a lot easier with pydub
    # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
    segment_duration_ms = int(segment.duration_seconds * 1000)

    # randint's upper bound is exclusive and must exceed the lower bound, so
    # report a too-short input explicitly rather than letting NumPy raise its
    # generic "low >= high" ValueError.
    max_start_ms = segment_duration_ms - duration_ms
    if num_clips > 0 and max_start_ms <= 0:
        raise ValueError(
            f"Audio '{audio}' is {segment_duration_ms} ms long, which is too short to "
            f"sample clips of {duration_ms} ms: the audio must be longer than the clip"
        )

    for i in range(num_clips):
        clip_start_ms = np.random.randint(0, max_start_ms)
        clip = segment[clip_start_ms : clip_start_ms + duration_ms]

        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
        clip_path = output_dir_path / clip_name
        clip.export(clip_path, format=extension)
        print(f"Wrote {clip_path}")
if __name__ == "__main__":
    # Hand the CLI functions to argh, which dispatches based on the command line.
    cli_commands = [
        audio_to_image,
        image_to_audio,
        sample_clips,
        print_exif,
    ]
    argh.dispatch_commands(cli_commands)

View File

@ -0,0 +1,99 @@
import typing as T
import numpy as np
from PIL import Image
from riffusion.cli import audio_to_image
from riffusion.spectrogram_params import SpectrogramParams
from .test_case import TestCase
class AudioToImageTest(TestCase):
    """
    Test riffusion.cli audio-to-image
    """

    @classmethod
    def default_params(cls) -> T.Dict:
        """
        Return the baseline keyword arguments passed to audio_to_image.
        """
        return dict(
            step_size_ms=10,
            num_frequencies=512,
            # TODO(hayk): Change these to [20, 20000] once a model is updated
            min_frequency=0,
            max_frequency=10000,
            stereo=False,
            device=cls.DEVICE,
        )

    def test_audio_to_image(self) -> None:
        """
        Test audio-to-image with default params.
        """
        params = self.default_params()
        self.helper_test_with_params(params)

    def test_stereo(self) -> None:
        """
        Test audio-to-image with stereo=True.
        """
        params = self.default_params()
        params["stereo"] = True
        self.helper_test_with_params(params)

    def helper_test_with_params(self, params: T.Dict) -> None:
        """
        Run audio_to_image on the test clip with `params` and check the
        resulting image's mode, dimensions, channel contents and exif data.
        """
        audio_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "clips"
            / "clip_2_start_103694_ms_duration_5678_ms.wav"
        )
        output_dir = self.get_tmp_dir("audio_to_image_")

        stem = f"{audio_path.stem}_stereo" if params["stereo"] else audio_path.stem
        image_path = output_dir / f"{stem}.png"

        audio_to_image(audio=str(audio_path), image=str(image_path), **params)

        # Check that the image exists
        self.assertTrue(image_path.exists())

        # Keep the file open only for the duration of the checks.
        with Image.open(image_path) as pil_image:
            # Check the image mode
            self.assertEqual(pil_image.mode, "RGB")

            # Check the image dimensions. The duration is encoded in the clip
            # file name, so guard against the fixture being swapped out.
            duration_ms = 5678
            self.assertIn(str(duration_ms), audio_path.name)
            expected_image_width = round(duration_ms / params["step_size_ms"])
            self.assertEqual(pil_image.width, expected_image_width)
            self.assertEqual(pil_image.height, params["num_frequencies"])

            # Get channels as numpy arrays
            channels = [np.array(pil_image.getchannel(band)) for band in pil_image.getbands()]
            self.assertEqual(len(channels), 3)

            if params["stereo"]:
                # Check that the first channel is zero
                self.assertTrue(np.all(channels[0] == 0))
            else:
                # Check that all channels are the same
                self.assertTrue(np.array_equal(channels[0], channels[1]))
                self.assertTrue(np.array_equal(channels[0], channels[2]))

            # Check that the image has exif data
            exif = pil_image.getexif()
            self.assertIsNotNone(exif)

            params_from_exif = SpectrogramParams.from_exif(exif)

        # Build the expectation from every frequency param that was passed in,
        # so the check stays valid if the defaults above are changed.
        expected_params = SpectrogramParams(
            stereo=params["stereo"],
            step_size_ms=params["step_size_ms"],
            num_frequencies=params["num_frequencies"],
            min_frequency=params["min_frequency"],
            max_frequency=params["max_frequency"],
        )

        self.assertEqual(params_from_exif, expected_params)

View File

@ -0,0 +1,71 @@
from pathlib import Path
import pydub
from riffusion.cli import image_to_audio
from .test_case import TestCase
class ImageToAudioTest(TestCase):
    """
    Test riffusion.cli image-to-audio
    """

    def test_image_to_audio_mono(self) -> None:
        """
        Test image-to-audio on the mono spectrogram image.
        """
        self.helper_image_to_audio(
            song_dir=self.TEST_DATA_PATH / "tired_traveler",
            clip_name="clip_2_start_103694_ms_duration_5678_ms",
            stereo=False,
        )

    def test_image_to_audio_stereo(self) -> None:
        """
        Test image-to-audio on the stereo spectrogram image.
        """
        self.helper_image_to_audio(
            song_dir=self.TEST_DATA_PATH / "tired_traveler",
            clip_name="clip_2_start_103694_ms_duration_5678_ms",
            stereo=True,
        )

    def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
        """
        Convert the image for `clip_name` back to audio and compare its sample
        rate, duration and channel count against the original clip.
        """
        image_stem = f"{clip_name}_stereo" if stereo else clip_name
        image_path = song_dir / "images" / f"{image_stem}.png"

        output_dir = self.get_tmp_dir("image_to_audio_")
        audio_path = output_dir / f"{image_path.stem}.wav"

        image_to_audio(
            image=str(image_path),
            audio=str(audio_path),
            device=self.DEVICE,
        )

        # Check that the audio exists
        self.assertTrue(audio_path.exists())

        # Load the reconstructed audio and the original clip
        segment = pydub.AudioSegment.from_file(str(audio_path))
        expected_segment = pydub.AudioSegment.from_file(
            str(song_dir / "clips" / f"{clip_name}.wav")
        )

        # Check sample rate
        self.assertEqual(segment.frame_rate, expected_segment.frame_rate)

        # Check duration (allow a small difference, under 10 ms)
        actual_duration_ms = round(segment.duration_seconds * 1000)
        expected_duration_ms = round(expected_segment.duration_seconds * 1000)
        self.assertLess(abs(actual_duration_ms - expected_duration_ms), 10)

        # Check the number of channels: the source clip is stereo, while the
        # reconstruction follows the kind of image it was built from.
        self.assertEqual(expected_segment.channels, 2)
        self.assertEqual(segment.channels, 2 if stereo else 1)
# Entry point when this module is run as the main program: hand off to TestCase.main().
if __name__ == "__main__":
    TestCase.main()

32
test/print_exif_test.py Normal file
View File

@ -0,0 +1,32 @@
import contextlib
import io
from riffusion.cli import print_exif
from .test_case import TestCase
class PrintExifTest(TestCase):
    """
    Test riffusion.cli print-exif
    """

    def test_print_exif(self) -> None:
        """
        Test print-exif.
        """
        image_path = (
            self.TEST_DATA_PATH
            / "tired_traveler"
            / "images"
            / "clip_2_start_103694_ms_duration_5678_ms.png"
        )

        # Redirect stdout
        stdout = io.StringIO()
        with contextlib.redirect_stdout(stdout):
            print_exif(image=str(image_path))

        # print_exif writes one column-padded "NAME = VALUE" line per entry.
        # Parse the lines rather than matching exact substrings so the check
        # does not depend on the amount of padding.
        printed = {}
        for line in stdout.getvalue().splitlines():
            name, separator, value = line.partition("=")
            if separator:
                printed[name.strip()] = value.strip()

        # Check that a couple of values are printed
        self.assertEqual(printed.get("NUM_FREQUENCIES"), "512")
        self.assertEqual(printed.get("SAMPLE_RATE"), "44100")

88
test/sample_clips_test.py Normal file
View File

@ -0,0 +1,88 @@
import typing as T
import pydub
from riffusion.cli import sample_clips
from .test_case import TestCase
class SampleClipsTest(TestCase):
    """
    Test riffusion.cli sample-clips
    """

    @staticmethod
    def default_params() -> T.Dict:
        """
        Return the baseline keyword arguments passed to sample_clips.
        """
        return dict(
            num_clips=3,
            duration_ms=5678,
            mono=False,
            extension="wav",
            seed=42,
        )

    def test_sample_clips(self) -> None:
        """
        Test sample-clips with default params.
        """
        params = self.default_params()
        self.helper_test_with_params(params)

    def test_mono(self) -> None:
        """
        Test sample-clips with mono=True.
        """
        params = self.default_params()
        params["mono"] = True
        params["num_clips"] = 1
        self.helper_test_with_params(params)

    def test_mp3(self) -> None:
        """
        Test sample-clips with extension=mp3.
        """
        # Imported locally so this change does not depend on the file's import block.
        import shutil

        # pydub sets `converter` to a program name even when that program is
        # not installed, so a None check alone never triggers the skip. Look
        # the executable up on PATH instead.
        converter = pydub.AudioSegment.converter
        if converter is None or shutil.which(converter) is None:
            self.skipTest("skipping, ffmpeg not found")

        params = self.default_params()
        params["extension"] = "mp3"
        params["num_clips"] = 1
        self.helper_test_with_params(params)

    def helper_test_with_params(self, params: T.Dict) -> None:
        """
        Test sample-clips with the given params.
        """
        audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
        output_dir = self.get_tmp_dir("sample_clips_")

        sample_clips(
            audio=str(audio_path),
            output_dir=str(output_dir),
            **params,
        )

        # Check that the expected number of files was written
        clip_paths = list(output_dir.iterdir())
        self.assertEqual(len(clip_paths), params["num_clips"])

        expected_channels = 1 if params["mono"] else 2

        # For each file in output dir
        for clip_path in clip_paths:
            # Check that it has the right extension
            self.assertEqual(clip_path.suffix, f".{params['extension']}")

            # Check that it has the right duration
            segment = pydub.AudioSegment.from_file(clip_path)
            self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])

            # Check that it has the right number of channels
            self.assertEqual(segment.channels, expected_channels)
# Entry point when this module is run as the main program: hand off to TestCase.main().
if __name__ == "__main__":
    TestCase.main()