2022-12-26 18:28:09 -07:00
|
|
|
"""
|
|
|
|
Command line tools for riffusion.
|
|
|
|
"""
|
|
|
|
|
2023-01-28 11:46:04 -07:00
|
|
|
import random
|
|
|
|
import typing as T
|
|
|
|
from multiprocessing.pool import ThreadPool
|
2022-12-26 18:28:09 -07:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import argh
|
|
|
|
import numpy as np
|
|
|
|
import pydub
|
2023-01-28 11:46:04 -07:00
|
|
|
import tqdm
|
2022-12-26 19:12:02 -07:00
|
|
|
from PIL import Image
|
2022-12-26 18:28:09 -07:00
|
|
|
|
|
|
|
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
|
|
|
from riffusion.spectrogram_params import SpectrogramParams
|
|
|
|
from riffusion.util import image_util
|
|
|
|
|
|
|
|
|
|
|
|
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
def audio_to_image(
    *,
    audio: str,
    image: str,
    step_size_ms: int = 10,
    num_frequencies: int = 512,
    min_frequency: int = 0,
    max_frequency: int = 10000,
    window_duration_ms: int = 100,
    padded_duration_ms: int = 400,
    power_for_image: float = 0.25,
    stereo: bool = False,
    device: str = "cuda",
):
    """
    Compute a spectrogram image from a waveform.
    """
    waveform = pydub.AudioSegment.from_file(audio)

    spectrogram_params = SpectrogramParams(
        sample_rate=waveform.frame_rate,
        stereo=stereo,
        window_duration_ms=window_duration_ms,
        padded_duration_ms=padded_duration_ms,
        step_size_ms=step_size_ms,
        min_frequency=min_frequency,
        max_frequency=max_frequency,
        num_frequencies=num_frequencies,
        power_for_image=power_for_image,
    )

    image_converter = SpectrogramImageConverter(params=spectrogram_params, device=device)
    spectrogram = image_converter.spectrogram_image_from_audio(waveform)

    # Embed the params as exif so image_to_audio can invert the image later.
    spectrogram.save(image, exif=spectrogram.getexif(), format="PNG")
    print(f"Wrote {image}")
|
|
|
|
|
|
|
|
|
|
|
|
def print_exif(*, image: str) -> None:
    """
    Print the params of a spectrogram image as saved in the exif data.
    """
    exif = image_util.exif_from_image(Image.open(image))

    # Aligned columns: left-justified names, right-justified values.
    for key, val in exif.items():
        print(f"{key:<20} = {val:>15}")
|
|
|
|
|
|
|
|
|
|
|
|
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
    """
    Reconstruct an audio clip from a spectrogram image.
    """
    spectrogram = Image.open(image)

    # Recover the spectrogram params embedded in the image's exif data;
    # fall back to defaults if they're absent.
    exif = spectrogram.getexif()
    assert exif is not None
    try:
        params = SpectrogramParams.from_exif(exif=exif)
    except (KeyError, AttributeError):
        print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
        params = SpectrogramParams()

    converter = SpectrogramImageConverter(params=params, device=device)
    reconstructed = converter.audio_from_spectrogram_image(spectrogram)

    # Infer the export format from the output file extension (without the dot).
    audio_format = Path(audio).suffix[1:]
    reconstructed.export(audio, format=audio_format)

    print(f"Wrote {audio} ({reconstructed.duration_seconds:.2f} seconds)")
|
|
|
|
|
|
|
|
|
|
|
|
def sample_clips(
    *,
    audio: str,
    output_dir: str,
    num_clips: int = 1,
    duration_ms: int = 5120,
    mono: bool = False,
    extension: str = "wav",
    seed: int = -1,
):
    """
    Slice an audio file into clips of the given duration.

    Args:
        audio: Path to the source audio file.
        output_dir: Directory to write clips into (created if missing).
        num_clips: Number of random clips to sample.
        duration_ms: Duration of each clip in milliseconds.
        mono: If True, down-mix the audio to a single channel first.
        extension: Audio format / file extension for the exported clips.
        seed: If non-negative, seed numpy's RNG for reproducible clip starts.
    """
    if seed >= 0:
        np.random.seed(seed)

    segment = pydub.AudioSegment.from_file(audio)

    if mono:
        segment = segment.set_channels(1)

    output_dir_path = Path(output_dir)
    # Consistent with the batch tools: create parents, tolerate existing dirs.
    output_dir_path.mkdir(parents=True, exist_ok=True)

    segment_duration_ms = int(segment.duration_seconds * 1000)

    # Clamp the upper bound so audio shorter than duration_ms still yields a
    # (truncated) clip instead of np.random.randint raising ValueError on a
    # non-positive high bound.
    max_start_ms = max(segment_duration_ms - duration_ms, 1)

    for i in range(num_clips):
        clip_start_ms = np.random.randint(0, max_start_ms)
        clip = segment[clip_start_ms : clip_start_ms + duration_ms]

        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
        clip_path = output_dir_path / clip_name
        clip.export(clip_path, format=extension)
        print(f"Wrote {clip_path}")
|
|
|
|
|
|
|
|
|
2023-01-28 11:46:04 -07:00
|
|
|
def audio_to_images_batch(
    *,
    audio_dir: str,
    output_dir: str,
    image_extension: str = "jpg",
    step_size_ms: int = 10,
    num_frequencies: int = 512,
    min_frequency: int = 0,
    max_frequency: int = 10000,
    power_for_image: float = 0.25,
    mono: bool = False,
    sample_rate: int = 44100,
    device: str = "cuda",
    num_threads: T.Optional[int] = None,
    limit: int = -1,
):
    """
    Process audio clips into spectrograms in batch, multi-threaded.

    Args:
        audio_dir: Directory containing the audio files to convert.
        output_dir: Directory to write spectrogram images into (created if missing).
        image_extension: Output image type; one of "jpg", "jpeg", or "png".
        mono: If True produce single-channel spectrograms, else stereo.
        sample_rate: Target sample rate; audio is resampled if it differs.
        device: Torch device for the spectrogram conversion.
        num_threads: Thread pool size (None uses the default).
        limit: If positive, only process the first `limit` files (sorted order).
    """
    audio_paths = sorted(Path(audio_dir).glob("*"))

    if limit > 0:
        audio_paths = audio_paths[:limit]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Validate the extension up front so a typo fails immediately instead of
    # raising KeyError inside a worker thread mid-batch.
    image_format = {"jpg": "JPEG", "jpeg": "JPEG", "png": "PNG"}[image_extension]

    params = SpectrogramParams(
        step_size_ms=step_size_ms,
        num_frequencies=num_frequencies,
        min_frequency=min_frequency,
        max_frequency=max_frequency,
        power_for_image=power_for_image,
        stereo=not mono,
        sample_rate=sample_rate,
    )

    converter = SpectrogramImageConverter(params=params, device=device)

    def process_one(audio_path: Path) -> None:
        # Load (best-effort: skip unreadable files, but say so)
        try:
            segment = pydub.AudioSegment.from_file(str(audio_path))
        except Exception as e:
            print(f"WARNING: Failed to load {audio_path}: {e}")
            return

        # TODO(hayk): Sanity checks on clip

        # Normalize channel count to match the mono/stereo setting
        if mono and segment.channels != 1:
            segment = segment.set_channels(1)
        elif not mono and segment.channels != 2:
            segment = segment.set_channels(2)

        # Resample to the target frame rate if needed
        if segment.frame_rate != params.sample_rate:
            segment = segment.set_frame_rate(params.sample_rate)

        # Convert
        image = converter.spectrogram_image_from_audio(segment)

        # Save, embedding the params as exif for later inversion
        image_path = output_path / f"{audio_path.stem}.{image_extension}"
        image.save(image_path, exif=image.getexif(), format=image_format)

    # Context manager ensures the pool's worker threads are terminated/joined.
    with ThreadPool(processes=num_threads) as pool:
        with tqdm.tqdm(total=len(audio_paths)) as pbar:
            for _ in pool.imap_unordered(process_one, audio_paths):
                pbar.update()
|
|
|
|
|
|
|
|
|
|
|
|
def sample_clips_batch(
    *,
    audio_dir: str,
    output_dir: str,
    num_clips_per_file: int = 1,
    duration_ms: int = 5120,
    mono: bool = False,
    extension: str = "mp3",
    num_threads: T.Optional[int] = None,
    limit: int = -1,
    seed: int = -1,
):
    """
    Sample short clips from a directory of audio files, multi-threaded.

    Args:
        audio_dir: Directory containing the audio files to sample from.
        output_dir: Directory to write clips into (created if missing).
        num_clips_per_file: Number of random clips to sample per source file.
        duration_ms: Duration of each clip in milliseconds.
        mono: If True, down-mix each file to a single channel first.
        extension: Audio format / file extension for the exported clips.
        num_threads: Thread pool size (None uses the default).
        limit: If positive, only process the first `limit` files (sorted order).
        seed: If non-negative, seed numpy's RNG for reproducible clip starts.
    """
    audio_paths = sorted(Path(audio_dir).glob("*"))

    if limit > 0:
        audio_paths = audio_paths[:limit]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    if seed >= 0:
        # Clip starts are drawn with np.random below, so seed numpy's RNG.
        # (Previously this seeded the stdlib `random` module, which this
        # function never uses, making `seed` a no-op.)
        np.random.seed(seed)

    def process_one(audio_path: Path) -> None:
        # Load (best-effort: skip unreadable files, but say so)
        try:
            segment = pydub.AudioSegment.from_file(str(audio_path))
        except Exception as e:
            print(f"WARNING: Failed to load {audio_path}: {e}")
            return

        if mono:
            segment = segment.set_channels(1)

        segment_duration_ms = int(segment.duration_seconds * 1000)

        # Clamp so files shorter than duration_ms yield a (truncated) clip
        # instead of np.random.randint raising on a non-positive high bound.
        max_start_ms = max(segment_duration_ms - duration_ms, 1)

        for i in range(num_clips_per_file):
            clip_start_ms = np.random.randint(0, max_start_ms)
            clip = segment[clip_start_ms : clip_start_ms + duration_ms]

            # NOTE: the second fragment was previously a plain string missing
            # its f-prefix (and the "_start" separator), so clips were named
            # with literal "{clip_start_ms}" placeholders.
            clip_name = (
                f"{audio_path.stem}_{i}"
                f"_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
            )
            clip.export(output_path / clip_name, format=extension)

    # Context manager ensures the pool's worker threads are terminated/joined.
    with ThreadPool(processes=num_threads) as pool:
        with tqdm.tqdm(total=len(audio_paths)) as pbar:
            for _ in pool.imap_unordered(process_one, audio_paths):
                pbar.update()
|
|
|
|
|
|
|
|
|
2022-12-26 18:28:09 -07:00
|
|
|
if __name__ == "__main__":
    # Expose each tool above as a CLI subcommand.
    commands = [
        audio_to_image,
        image_to_audio,
        sample_clips,
        print_exif,
        audio_to_images_batch,
        sample_clips_batch,
    ]
    argh.dispatch_commands(commands)
|