riffusion-inference/riffusion/cli.py

142 lines
3.9 KiB
Python
Raw Normal View History

"""
Command line tools for riffusion.
"""
from pathlib import Path
import argh
import numpy as np
from PIL import Image
import pydub
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
def audio_to_image(
    *,
    audio: str,
    image: str,
    step_size_ms: int = 10,
    num_frequencies: int = 512,
    min_frequency: int = 0,
    max_frequency: int = 10000,
    window_duration_ms: int = 100,
    padded_duration_ms: int = 400,
    power_for_image: float = 0.25,
    stereo: bool = False,
    device: str = "cuda",
):
    """
    Compute a spectrogram image from a waveform.
    """
    # Decode the input file; its frame rate is used as the spectrogram sample rate.
    waveform = pydub.AudioSegment.from_file(audio)

    # Gather the command-line options into the keyword set SpectrogramParams expects.
    spectrogram_settings = dict(
        sample_rate=waveform.frame_rate,
        stereo=stereo,
        window_duration_ms=window_duration_ms,
        padded_duration_ms=padded_duration_ms,
        step_size_ms=step_size_ms,
        min_frequency=min_frequency,
        max_frequency=max_frequency,
        num_frequencies=num_frequencies,
        power_for_image=power_for_image,
    )
    image_converter = SpectrogramImageConverter(
        params=SpectrogramParams(**spectrogram_settings),
        device=device,
    )

    spectrogram = image_converter.spectrogram_image_from_audio(waveform)

    # Pass the image's own exif data to save() so it is written into the PNG.
    embedded_exif = spectrogram.getexif()
    spectrogram.save(image, exif=embedded_exif, format="PNG")

    print(f"Wrote {image}")
def print_exif(*, image: str) -> None:
    """
    Print the params of a spectrogram image as saved in the exif data.

    Each entry is printed on its own line as ``name = value``, with the name
    left-aligned and the value right-aligned.

    Args:
        image: Path of the image file whose exif data should be printed.
    """
    # The context manager releases the underlying file handle when done.
    with Image.open(image) as pil_image:
        exif_data = image_util.exif_from_image(pil_image)

        for name, value in exif_data.items():
            # Convert the value with `!s` before aligning: an alignment spec
            # applied directly to a tuple, bytes or None value raises TypeError.
            print(f"{name:<20} = {value!s:>15}")
def image_to_audio(*, image: str, audio: str, device: str = "cuda") -> None:
    """
    Reconstruct an audio clip from a spectrogram image.

    The spectrogram parameters are read from the image's exif data. If they
    cannot be found there, a warning is printed and default parameters are used.

    Args:
        image: Path of the spectrogram image to read.
        audio: Path of the audio file to write. The file suffix (without the
            leading dot) is passed to the exporter as the output format.
        device: Device string handed to SpectrogramImageConverter.

    Raises:
        ValueError: If `audio` has no file extension, or if no exif object
            could be read from the image.
    """
    # Validate the output format first so a bad path fails immediately rather
    # than after the (potentially expensive) reconstruction has already run.
    extension = Path(audio).suffix[1:]
    if not extension:
        raise ValueError(
            f"Cannot infer the audio format: '{audio}' has no file extension"
        )

    # The context manager releases the image's file handle once converted.
    with Image.open(image) as pil_image:
        # Get parameters from image exif
        img_exif = pil_image.getexif()
        if img_exif is None:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(f"Could not read exif data from '{image}'")

        try:
            params = SpectrogramParams.from_exif(exif=img_exif)
        except KeyError:
            print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
            params = SpectrogramParams()

        converter = SpectrogramImageConverter(params=params, device=device)
        segment = converter.audio_from_spectrogram_image(pil_image)

    segment.export(audio, format=extension)

    print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
def sample_clips(
    *,
    audio: str,
    output_dir: str,
    num_clips: int = 1,
    duration_ms: int = 5000,
    mono: bool = False,
    extension: str = "wav",
    seed: int = -1,
) -> None:
    """
    Slice an audio file into clips of the given duration.

    Clip start offsets are drawn at random, so clips may overlap. Each clip is
    written to `output_dir` with its index, start offset and duration encoded
    in the file name.

    Args:
        audio: Path of the audio file to sample from.
        output_dir: Directory to write the clips to; created if missing.
        num_clips: Number of clips to write.
        duration_ms: Length of each clip in milliseconds.
        mono: If true, convert the audio to a single channel before slicing.
        extension: File extension of the clips, also used as the export format.
        seed: Seed for numpy's global random state; ignored when negative.

    Raises:
        ValueError: If the audio is shorter than `duration_ms`.
    """
    if seed >= 0:
        np.random.seed(seed)

    segment = pydub.AudioSegment.from_file(audio)

    if mono:
        segment = segment.set_channels(1)

    # TODO(hayk): Might be a lot easier with pydub
    # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
    segment_duration_ms = int(segment.duration_seconds * 1000)

    # Largest offset at which a full-length clip could still start.
    max_start_ms = segment_duration_ms - duration_ms
    if max_start_ms < 0:
        raise ValueError(
            f"Audio is {segment_duration_ms} ms long, which is shorter than the "
            f"requested clip duration of {duration_ms} ms"
        )

    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    for i in range(num_clips):
        # randint() rejects an empty range, so audio exactly as long as one
        # clip is handled explicitly: the only possible start is 0.
        clip_start_ms = np.random.randint(0, max_start_ms) if max_start_ms > 0 else 0

        clip = segment[clip_start_ms : clip_start_ms + duration_ms]

        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
        clip_path = output_dir_path / clip_name
        clip.export(clip_path, format=extension)
        print(f"Wrote {clip_path}")
if __name__ == "__main__":
    # Functions handed to argh for dispatch when the module is run as a script.
    cli_commands = [audio_to_image, image_to_audio, sample_clips, print_exif]
    argh.dispatch_commands(cli_commands)