diff --git a/riffusion/cli.py b/riffusion/cli.py
new file mode 100644
index 0000000..2022c05
--- /dev/null
+++ b/riffusion/cli.py
@@ -0,0 +1,141 @@
+"""
+Command line tools for riffusion.
+"""
+
+from pathlib import Path
+
+import argh
+import numpy as np
+from PIL import Image
+import pydub
+
+from riffusion.spectrogram_image_converter import SpectrogramImageConverter
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.util import image_util
+
+
+@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
+@argh.arg("--num-frequencies", help="Number of frequency bins in the Y axis of the spectrogram image")
+def audio_to_image(
+    *,
+    audio: str,
+    image: str,
+    step_size_ms: int = 10,
+    num_frequencies: int = 512,
+    min_frequency: int = 0,
+    max_frequency: int = 10000,
+    window_duration_ms: int = 100,
+    padded_duration_ms: int = 400,
+    power_for_image: float = 0.25,
+    stereo: bool = False,
+    device: str = "cuda",
+):
+    """
+    Compute a spectrogram image from a waveform.
+    """
+    segment = pydub.AudioSegment.from_file(audio)
+
+    params = SpectrogramParams(
+        sample_rate=segment.frame_rate,
+        stereo=stereo,
+        window_duration_ms=window_duration_ms,
+        padded_duration_ms=padded_duration_ms,
+        step_size_ms=step_size_ms,
+        min_frequency=min_frequency,
+        max_frequency=max_frequency,
+        num_frequencies=num_frequencies,
+        power_for_image=power_for_image,
+    )
+
+    converter = SpectrogramImageConverter(params=params, device=device)
+
+    pil_image = converter.spectrogram_image_from_audio(segment)
+
+    pil_image.save(image, exif=pil_image.getexif(), format="PNG")
+    print(f"Wrote {image}")
+
+
+def print_exif(*, image: str) -> None:
+    """
+    Print the params of a spectrogram image as saved in the exif data.
+    """
+    pil_image = Image.open(image)
+    exif_data = image_util.exif_from_image(pil_image)
+
+    for name, value in exif_data.items():
+        print(f"{name:<20} = {value:>15}")
+
+
+def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
+    """
+    Reconstruct an audio clip from a spectrogram image.
+    """
+    pil_image = Image.open(image)
+
+    # Get parameters from image exif
+    img_exif = pil_image.getexif()
+    assert img_exif is not None
+
+    try:
+        params = SpectrogramParams.from_exif(exif=img_exif)
+    except KeyError:
+        print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
+        params = SpectrogramParams()
+
+    converter = SpectrogramImageConverter(params=params, device=device)
+    segment = converter.audio_from_spectrogram_image(pil_image)
+
+    extension = Path(audio).suffix[1:]
+    segment.export(audio, format=extension)
+
+    print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
+
+
+def sample_clips(
+    *,
+    audio: str,
+    output_dir: str,
+    num_clips: int = 1,
+    duration_ms: int = 5000,
+    mono: bool = False,
+    extension: str = "wav",
+    seed: int = -1,
+):
+    """
+    Slice an audio file into clips of the given duration.
+ """ + if seed >= 0: + np.random.seed(seed) + + segment = pydub.AudioSegment.from_file(audio) + + if mono: + segment = segment.set_channels(1) + + output_dir_path = Path(output_dir) + if not output_dir_path.exists(): + output_dir_path.mkdir(parents=True) + + # TODO(hayk): Might be a lot easier with pydub + # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file + + segment_duration_ms = int(segment.duration_seconds * 1000) + for i in range(num_clips): + clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms) + clip = segment[clip_start_ms : clip_start_ms + duration_ms] + + clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}" + clip_path = output_dir_path / clip_name + clip.export(clip_path, format=extension) + print(f"Wrote {clip_path}") + + +if __name__ == "__main__": + argh.dispatch_commands( + [ + audio_to_image, + image_to_audio, + sample_clips, + print_exif, + ] + ) diff --git a/test/audio_to_image_test.py b/test/audio_to_image_test.py new file mode 100644 index 0000000..4a84d76 --- /dev/null +++ b/test/audio_to_image_test.py @@ -0,0 +1,99 @@ +import typing as T + +import numpy as np +from PIL import Image + +from riffusion.cli import audio_to_image +from riffusion.spectrogram_params import SpectrogramParams + +from .test_case import TestCase + + +class AudioToImageTest(TestCase): + """ + Test riffusion.cli audio-to-image + """ + + @classmethod + def default_params(cls) -> T.Dict: + return dict( + step_size_ms=10, + num_frequencies=512, + # TODO(hayk): Change these to [20, 20000] once a model is updated + min_frequency=0, + max_frequency=10000, + stereo=False, + device=cls.DEVICE, + ) + + def test_audio_to_image(self) -> None: + """ + Test audio-to-image with default params. + """ + params = self.default_params() + self.helper_test_with_params(params) + + def test_stereo(self) -> None: + """ + Test audio-to-image with stereo=True. 
+ """ + params = self.default_params() + params["stereo"] = True + self.helper_test_with_params(params) + + def helper_test_with_params(self, params: T.Dict) -> None: + audio_path = ( + self.TEST_DATA_PATH + / "tired_traveler" + / "clips" + / "clip_2_start_103694_ms_duration_5678_ms.wav" + ) + output_dir = self.get_tmp_dir("audio_to_image_") + + if params["stereo"]: + stem = f"{audio_path.stem}_stereo" + else: + stem = audio_path.stem + + image_path = output_dir / f"{stem}.png" + + audio_to_image(audio=str(audio_path), image=str(image_path), **params) + + # Check that the image exists + self.assertTrue(image_path.exists()) + + pil_image = Image.open(image_path) + + # Check the image mode + self.assertEqual(pil_image.mode, "RGB") + + # Check the image dimensions + duration_ms = 5678 + self.assertTrue(str(duration_ms) in audio_path.name) + expected_image_width = round(duration_ms / params["step_size_ms"]) + self.assertEqual(pil_image.width, expected_image_width) + self.assertEqual(pil_image.height, params["num_frequencies"]) + + # Get channels as numpy arrays + channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))] + self.assertEqual(len(channels), 3) + + if params["stereo"]: + # Check that the first channel is zero + self.assertTrue(np.all(channels[0] == 0)) + else: + # Check that all channels are the same + self.assertTrue(np.all(channels[0] == channels[1])) + self.assertTrue(np.all(channels[0] == channels[2])) + + # Check that the image has exif data + exif = pil_image.getexif() + self.assertIsNotNone(exif) + params_from_exif = SpectrogramParams.from_exif(exif) + expected_params = SpectrogramParams( + stereo=params["stereo"], + step_size_ms=params["step_size_ms"], + num_frequencies=params["num_frequencies"], + max_frequency=params["max_frequency"], + ) + self.assertTrue(params_from_exif == expected_params) diff --git a/test/image_to_audio_test.py b/test/image_to_audio_test.py new file mode 100644 index 0000000..7ae1424 --- /dev/null +++ b/test/image_to_audio_test.py @@ -0,0 +1,71 @@ +from pathlib import Path + +import pydub + +from riffusion.cli import image_to_audio + +from .test_case import TestCase + + +class ImageToAudioTest(TestCase): + """ + Test riffusion.cli image-to-audio + """ + + def test_image_to_audio_mono(self) -> None: + self.helper_image_to_audio( + song_dir=self.TEST_DATA_PATH / "tired_traveler", + clip_name="clip_2_start_103694_ms_duration_5678_ms", + stereo=False, + ) + + def test_image_to_audio_stereo(self) -> None: + self.helper_image_to_audio( + song_dir=self.TEST_DATA_PATH / "tired_traveler", + clip_name="clip_2_start_103694_ms_duration_5678_ms", + stereo=True, + ) + + def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None: + if stereo: + image_stem = clip_name + "_stereo" + else: + image_stem = clip_name + + image_path = song_dir / "images" / f"{image_stem}.png" + output_dir = self.get_tmp_dir("image_to_audio_") + audio_path = output_dir / f"{image_path.stem}.wav" + + image_to_audio( + image=str(image_path), + audio=str(audio_path), + device=self.DEVICE, + ) + + # Check that the audio exists + self.assertTrue(audio_path.exists()) + + # Load the reconstructed audio and the original clip + segment = pydub.AudioSegment.from_file(str(audio_path)) + expected_segment = pydub.AudioSegment.from_file( + str(song_dir / "clips" / f"{clip_name}.wav") + ) + + # Check sample rate + self.assertEqual(segment.frame_rate, expected_segment.frame_rate) + + # Check duration + actual_duration_ms = 
+        expected_duration_ms = round(expected_segment.duration_seconds * 1000)
+        self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
+
+        # Check the number of channels
+        self.assertEqual(expected_segment.channels, 2)
+        if stereo:
+            self.assertEqual(segment.channels, 2)
+        else:
+            self.assertEqual(segment.channels, 1)
+
+
+if __name__ == "__main__":
+    TestCase.main()
diff --git a/test/print_exif_test.py b/test/print_exif_test.py
new file mode 100644
index 0000000..82b6d94
--- /dev/null
+++ b/test/print_exif_test.py
@@ -0,0 +1,32 @@
+import contextlib
+import io
+
+from riffusion.cli import print_exif
+
+from .test_case import TestCase
+
+
+class PrintExifTest(TestCase):
+    """
+    Test riffusion.cli print-exif
+    """
+
+    def test_print_exif(self) -> None:
+        """
+        Test print-exif.
+        """
+        image_path = (
+            self.TEST_DATA_PATH
+            / "tired_traveler"
+            / "images"
+            / "clip_2_start_103694_ms_duration_5678_ms.png"
+        )
+
+        # Redirect stdout
+        stdout = io.StringIO()
+        with contextlib.redirect_stdout(stdout):
+            print_exif(image=str(image_path))
+
+        # Check that a couple of names and values are printed (print_exif pads
+        # columns, so match names and values separately rather than "NAME: value")
+        self.assertTrue("NUM_FREQUENCIES" in stdout.getvalue() and "512" in stdout.getvalue())
+        self.assertTrue("SAMPLE_RATE" in stdout.getvalue() and "44100" in stdout.getvalue())
diff --git a/test/sample_clips_test.py b/test/sample_clips_test.py
new file mode 100644
index 0000000..4b3080a
--- /dev/null
+++ b/test/sample_clips_test.py
@@ -0,0 +1,88 @@
+import typing as T
+
+import pydub
+
+from riffusion.cli import sample_clips
+
+from .test_case import TestCase
+
+
+class SampleClipsTest(TestCase):
+    """
+    Test riffusion.cli sample-clips
+    """
+
+    @staticmethod
+    def default_params() -> T.Dict:
+        return dict(
+            num_clips=3,
+            duration_ms=5678,
+            mono=False,
+            extension="wav",
+            seed=42,
+        )
+
+    def test_sample_clips(self) -> None:
+        """
+        Test sample-clips with default params.
+        """
+        params = self.default_params()
+        self.helper_test_with_params(params)
+
+    def test_mono(self) -> None:
+        """
+        Test sample-clips with mono=True.
+        """
+        params = self.default_params()
+        params["mono"] = True
+        params["num_clips"] = 1
+        self.helper_test_with_params(params)
+
+    def test_mp3(self) -> None:
+        """
+        Test sample-clips with extension=mp3.
+        """
+        if pydub.AudioSegment.converter is None:
+            self.skipTest("skipping, ffmpeg not found")
+
+        params = self.default_params()
+        params["extension"] = "mp3"
+        params["num_clips"] = 1
+        self.helper_test_with_params(params)
+
+    def helper_test_with_params(self, params: T.Dict) -> None:
+        """
+        Test sample-clips with the given params.
+        """
+        audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
+        output_dir = self.get_tmp_dir("sample_clips_")
+
+        sample_clips(
+            audio=str(audio_path),
+            output_dir=str(output_dir),
+            **params,
+        )
+
+        # For each file in output dir
+        counter = 0
+        for clip_path in output_dir.iterdir():
+            # Check that it has the right extension
+            self.assertEqual(clip_path.suffix, f".{params['extension']}")
+
+            # Check that it has the right duration
+            segment = pydub.AudioSegment.from_file(clip_path)
+            self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
+
+            # Check that it has the right number of channels
+            if params["mono"]:
+                self.assertEqual(segment.channels, 1)
+            else:
+                self.assertEqual(segment.channels, 2)
+
+            counter += 1
+
+        self.assertEqual(counter, params["num_clips"])
+
+
+if __name__ == "__main__":
+    TestCase.main()
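
For reviewers, a minimal round-trip sketch of the new commands as Python calls. This is an illustration, not part of the change: the file paths (example.wav, example.png) are hypothetical, and device="cpu" is an assumption for machines without a GPU (the CLI defaults to "cuda", and the tests use TestCase.DEVICE).

# Hedged usage sketch of riffusion/cli.py -- hypothetical paths, CPU device.
from riffusion.cli import audio_to_image, image_to_audio, print_exif

# Encode a waveform as a spectrogram image; the SpectrogramParams used are
# stored in the PNG's exif data by audio_to_image.
audio_to_image(audio="example.wav", image="example.png", device="cpu")

# Inspect the params saved in the image's exif data.
print_exif(image="example.png")

# Reconstruct audio from the image; image_to_audio reads the params back
# from exif and infers the output format from the file suffix.
image_to_audio(image="example.png", audio="example_reconstructed.wav", device="cpu")

The same commands are reachable from the shell through argh's dispatch_commands, using the dashed names referenced in the test docstrings above, e.g. python -m riffusion.cli audio-to-image --audio example.wav --image example.png.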