Command line interface for common operations, plus tests
Add riffusion.cli tool for common operations. Add a test for each one. Topic: clean_rewrite
This commit is contained in:
parent
52bec9575b
commit
dc4e2d8d64
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
Command line tools for riffusion.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import argh
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
|
||||
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
|
||||
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
|
||||
def audio_to_image(
|
||||
*,
|
||||
audio: str,
|
||||
image: str,
|
||||
step_size_ms: int = 10,
|
||||
num_frequencies: int = 512,
|
||||
min_frequency: int = 0,
|
||||
max_frequency: int = 10000,
|
||||
window_duration_ms: int = 100,
|
||||
padded_duration_ms: int = 400,
|
||||
power_for_image: float = 0.25,
|
||||
stereo: bool = False,
|
||||
device: str = "cuda",
|
||||
):
|
||||
"""
|
||||
Compute a spectrogram image from a waveform.
|
||||
"""
|
||||
segment = pydub.AudioSegment.from_file(audio)
|
||||
|
||||
params = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=stereo,
|
||||
window_duration_ms=window_duration_ms,
|
||||
padded_duration_ms=padded_duration_ms,
|
||||
step_size_ms=step_size_ms,
|
||||
min_frequency=min_frequency,
|
||||
max_frequency=max_frequency,
|
||||
num_frequencies=num_frequencies,
|
||||
power_for_image=power_for_image,
|
||||
)
|
||||
|
||||
converter = SpectrogramImageConverter(params=params, device=device)
|
||||
|
||||
pil_image = converter.spectrogram_image_from_audio(segment)
|
||||
|
||||
pil_image.save(image, exif=pil_image.getexif(), format="PNG")
|
||||
print(f"Wrote {image}")
|
||||
|
||||
|
||||
def print_exif(*, image: str) -> None:
|
||||
"""
|
||||
Print the params of a spectrogram image as saved in the exif data.
|
||||
"""
|
||||
pil_image = Image.open(image)
|
||||
exif_data = image_util.exif_from_image(pil_image)
|
||||
|
||||
for name, value in exif_data.items():
|
||||
print(f"{name:<20} = {value:>15}")
|
||||
|
||||
|
||||
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
|
||||
"""
|
||||
Reconstruct an audio clip from a spectrogram image.
|
||||
"""
|
||||
pil_image = Image.open(image)
|
||||
|
||||
# Get parameters from image exif
|
||||
img_exif = pil_image.getexif()
|
||||
assert img_exif is not None
|
||||
|
||||
try:
|
||||
params = SpectrogramParams.from_exif(exif=img_exif)
|
||||
except KeyError:
|
||||
print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
|
||||
params = SpectrogramParams()
|
||||
|
||||
converter = SpectrogramImageConverter(params=params, device=device)
|
||||
segment = converter.audio_from_spectrogram_image(pil_image)
|
||||
|
||||
extension = Path(audio).suffix[1:]
|
||||
segment.export(audio, format=extension)
|
||||
|
||||
print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
|
||||
|
||||
|
||||
def sample_clips(
|
||||
*,
|
||||
audio: str,
|
||||
output_dir: str,
|
||||
num_clips: int = 1,
|
||||
duration_ms: int = 5000,
|
||||
mono: bool = False,
|
||||
extension: str = "wav",
|
||||
seed: int = -1,
|
||||
):
|
||||
"""
|
||||
Slice an audio file into clips of the given duration.
|
||||
"""
|
||||
if seed >= 0:
|
||||
np.random.seed(seed)
|
||||
|
||||
segment = pydub.AudioSegment.from_file(audio)
|
||||
|
||||
if mono:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
output_dir_path = Path(output_dir)
|
||||
if not output_dir_path.exists():
|
||||
output_dir_path.mkdir(parents=True)
|
||||
|
||||
# TODO(hayk): Might be a lot easier with pydub
|
||||
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
|
||||
|
||||
segment_duration_ms = int(segment.duration_seconds * 1000)
|
||||
for i in range(num_clips):
|
||||
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
|
||||
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
|
||||
|
||||
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
|
||||
clip_path = output_dir_path / clip_name
|
||||
clip.export(clip_path, format=extension)
|
||||
print(f"Wrote {clip_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argh.dispatch_commands(
|
||||
[
|
||||
audio_to_image,
|
||||
image_to_audio,
|
||||
sample_clips,
|
||||
print_exif,
|
||||
]
|
||||
)
|
|
@ -0,0 +1,99 @@
|
|||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.cli import audio_to_image
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class AudioToImageTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli audio-to-image
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def default_params(cls) -> T.Dict:
|
||||
return dict(
|
||||
step_size_ms=10,
|
||||
num_frequencies=512,
|
||||
# TODO(hayk): Change these to [20, 20000] once a model is updated
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
stereo=False,
|
||||
device=cls.DEVICE,
|
||||
)
|
||||
|
||||
def test_audio_to_image(self) -> None:
|
||||
"""
|
||||
Test audio-to-image with default params.
|
||||
"""
|
||||
params = self.default_params()
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_stereo(self) -> None:
|
||||
"""
|
||||
Test audio-to-image with stereo=True.
|
||||
"""
|
||||
params = self.default_params()
|
||||
params["stereo"] = True
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def helper_test_with_params(self, params: T.Dict) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir("audio_to_image_")
|
||||
|
||||
if params["stereo"]:
|
||||
stem = f"{audio_path.stem}_stereo"
|
||||
else:
|
||||
stem = audio_path.stem
|
||||
|
||||
image_path = output_dir / f"{stem}.png"
|
||||
|
||||
audio_to_image(audio=str(audio_path), image=str(image_path), **params)
|
||||
|
||||
# Check that the image exists
|
||||
self.assertTrue(image_path.exists())
|
||||
|
||||
pil_image = Image.open(image_path)
|
||||
|
||||
# Check the image mode
|
||||
self.assertEqual(pil_image.mode, "RGB")
|
||||
|
||||
# Check the image dimensions
|
||||
duration_ms = 5678
|
||||
self.assertTrue(str(duration_ms) in audio_path.name)
|
||||
expected_image_width = round(duration_ms / params["step_size_ms"])
|
||||
self.assertEqual(pil_image.width, expected_image_width)
|
||||
self.assertEqual(pil_image.height, params["num_frequencies"])
|
||||
|
||||
# Get channels as numpy arrays
|
||||
channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))]
|
||||
self.assertEqual(len(channels), 3)
|
||||
|
||||
if params["stereo"]:
|
||||
# Check that the first channel is zero
|
||||
self.assertTrue(np.all(channels[0] == 0))
|
||||
else:
|
||||
# Check that all channels are the same
|
||||
self.assertTrue(np.all(channels[0] == channels[1]))
|
||||
self.assertTrue(np.all(channels[0] == channels[2]))
|
||||
|
||||
# Check that the image has exif data
|
||||
exif = pil_image.getexif()
|
||||
self.assertIsNotNone(exif)
|
||||
params_from_exif = SpectrogramParams.from_exif(exif)
|
||||
expected_params = SpectrogramParams(
|
||||
stereo=params["stereo"],
|
||||
step_size_ms=params["step_size_ms"],
|
||||
num_frequencies=params["num_frequencies"],
|
||||
max_frequency=params["max_frequency"],
|
||||
)
|
||||
self.assertTrue(params_from_exif == expected_params)
|
|
@ -0,0 +1,71 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.cli import image_to_audio
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class ImageToAudioTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli image-to-audio
|
||||
"""
|
||||
|
||||
def test_image_to_audio_mono(self) -> None:
|
||||
self.helper_image_to_audio(
|
||||
song_dir=self.TEST_DATA_PATH / "tired_traveler",
|
||||
clip_name="clip_2_start_103694_ms_duration_5678_ms",
|
||||
stereo=False,
|
||||
)
|
||||
|
||||
def test_image_to_audio_stereo(self) -> None:
|
||||
self.helper_image_to_audio(
|
||||
song_dir=self.TEST_DATA_PATH / "tired_traveler",
|
||||
clip_name="clip_2_start_103694_ms_duration_5678_ms",
|
||||
stereo=True,
|
||||
)
|
||||
|
||||
def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
|
||||
if stereo:
|
||||
image_stem = clip_name + "_stereo"
|
||||
else:
|
||||
image_stem = clip_name
|
||||
|
||||
image_path = song_dir / "images" / f"{image_stem}.png"
|
||||
output_dir = self.get_tmp_dir("image_to_audio_")
|
||||
audio_path = output_dir / f"{image_path.stem}.wav"
|
||||
|
||||
image_to_audio(
|
||||
image=str(image_path),
|
||||
audio=str(audio_path),
|
||||
device=self.DEVICE,
|
||||
)
|
||||
|
||||
# Check that the audio exists
|
||||
self.assertTrue(audio_path.exists())
|
||||
|
||||
# Load the reconstructed audio and the original clip
|
||||
segment = pydub.AudioSegment.from_file(str(audio_path))
|
||||
expected_segment = pydub.AudioSegment.from_file(
|
||||
str(song_dir / "clips" / f"{clip_name}.wav")
|
||||
)
|
||||
|
||||
# Check sample rate
|
||||
self.assertEqual(segment.frame_rate, expected_segment.frame_rate)
|
||||
|
||||
# Check duration
|
||||
actual_duration_ms = round(segment.duration_seconds * 1000)
|
||||
expected_duration_ms = round(expected_segment.duration_seconds * 1000)
|
||||
self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
|
||||
|
||||
# Check the number of channels
|
||||
self.assertEqual(expected_segment.channels, 2)
|
||||
if stereo:
|
||||
self.assertEqual(segment.channels, 2)
|
||||
else:
|
||||
self.assertEqual(segment.channels, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestCase.main()
|
|
@ -0,0 +1,32 @@
|
|||
import contextlib
|
||||
import io
|
||||
|
||||
from riffusion.cli import print_exif
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class PrintExifTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli print-exif
|
||||
"""
|
||||
|
||||
def test_print_exif(self) -> None:
|
||||
"""
|
||||
Test print-exif.
|
||||
"""
|
||||
image_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "images"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.png"
|
||||
)
|
||||
|
||||
# Redirect stdout
|
||||
stdout = io.StringIO()
|
||||
with contextlib.redirect_stdout(stdout):
|
||||
print_exif(image=str(image_path))
|
||||
|
||||
# Check that a couple of values are printed
|
||||
self.assertTrue("NUM_FREQUENCIES: 512" in stdout.getvalue())
|
||||
self.assertTrue("SAMPLE_RATE: 44100" in stdout.getvalue())
|
|
@ -0,0 +1,88 @@
|
|||
import typing as T
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.cli import sample_clips
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SampleClipsTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli sample-clips
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def default_params() -> T.Dict:
|
||||
return dict(
|
||||
num_clips=3,
|
||||
duration_ms=5678,
|
||||
mono=False,
|
||||
extension="wav",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
def test_sample_clips(self) -> None:
|
||||
"""
|
||||
Test sample-clips with default params.
|
||||
"""
|
||||
params = self.default_params()
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_mono(self) -> None:
|
||||
"""
|
||||
Test sample-clips with mono=True.
|
||||
"""
|
||||
params = self.default_params()
|
||||
params["mono"] = True
|
||||
params["num_clips"] = 1
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_mp3(self) -> None:
|
||||
"""
|
||||
Test sample-clips with extension=mp3.
|
||||
"""
|
||||
if pydub.AudioSegment.converter is None:
|
||||
self.skipTest("skipping, ffmpeg not found")
|
||||
|
||||
params = self.default_params()
|
||||
params["extension"] = "mp3"
|
||||
params["num_clips"] = 1
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def helper_test_with_params(self, params: T.Dict) -> None:
|
||||
"""
|
||||
Test sample-clips with the given params.
|
||||
"""
|
||||
audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
|
||||
output_dir = self.get_tmp_dir("sample_clips_")
|
||||
|
||||
sample_clips(
|
||||
audio=str(audio_path),
|
||||
output_dir=str(output_dir),
|
||||
**params,
|
||||
)
|
||||
|
||||
# For each file in output dir
|
||||
counter = 0
|
||||
for clip_path in output_dir.iterdir():
|
||||
# Check that it has the right extension
|
||||
self.assertEqual(clip_path.suffix, f".{params['extension']}")
|
||||
|
||||
# Check that it has the right duration
|
||||
segment = pydub.AudioSegment.from_file(clip_path)
|
||||
self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
|
||||
|
||||
# Check that it has the right number of channels
|
||||
if params["mono"]:
|
||||
self.assertEqual(segment.channels, 1)
|
||||
else:
|
||||
self.assertEqual(segment.channels, 2)
|
||||
|
||||
counter += 1
|
||||
|
||||
self.assertEqual(counter, params["num_clips"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestCase.main()
|
Loading…
Reference in New Issue