From 45d36a32a6f8e4c7f3db2274065bd82018dafd1c Mon Sep 17 00:00:00 2001
From: Hayk Martiros
Date: Sat, 28 Jan 2023 18:46:04 +0000
Subject: [PATCH] Add audio_to_images_batch and sample_clips_batch

Topic: batch_cli_commands
---
 riffusion/cli.py | 134 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 133 insertions(+), 1 deletion(-)

diff --git a/riffusion/cli.py b/riffusion/cli.py
index 266d12b..fc83071 100644
--- a/riffusion/cli.py
+++ b/riffusion/cli.py
@@ -2,11 +2,15 @@
 Command line tools for riffusion.
 """
+import random
+import typing as T
+from multiprocessing.pool import ThreadPool
 from pathlib import Path

 import argh
 import numpy as np
 import pydub
+import tqdm
 from PIL import Image

 from riffusion.spectrogram_image_converter import SpectrogramImageConverter
@@ -96,7 +100,7 @@ def sample_clips(
     audio: str,
     output_dir: str,
     num_clips: int = 1,
-    duration_ms: int = 5000,
+    duration_ms: int = 5120,
     mono: bool = False,
     extension: str = "wav",
     seed: int = -1,
@@ -127,6 +131,132 @@ def sample_clips(
         print(f"Wrote {clip_path}")


+def audio_to_images_batch(
+    *,
+    audio_dir: str,
+    output_dir: str,
+    image_extension: str = "jpg",
+    step_size_ms: int = 10,
+    num_frequencies: int = 512,
+    min_frequency: int = 0,
+    max_frequency: int = 10000,
+    power_for_image: float = 0.25,
+    mono: bool = False,
+    sample_rate: int = 44100,
+    device: str = "cuda",
+    num_threads: T.Optional[int] = None,
+    limit: int = -1,
+):
+    """
+    Process audio clips into spectrograms in batch, multi-threaded.
+    """
+    audio_paths = list(Path(audio_dir).glob("*"))
+    audio_paths.sort()
+
+    if limit > 0:
+        audio_paths = audio_paths[:limit]
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    params = SpectrogramParams(
+        step_size_ms=step_size_ms,
+        num_frequencies=num_frequencies,
+        min_frequency=min_frequency,
+        max_frequency=max_frequency,
+        power_for_image=power_for_image,
+        stereo=not mono,
+        sample_rate=sample_rate,
+    )
+
+    converter = SpectrogramImageConverter(params=params, device=device)
+
+    def process_one(audio_path: Path) -> None:
+        # Load, skipping files that fail to decode
+        try:
+            segment = pydub.AudioSegment.from_file(str(audio_path))
+        except Exception:
+            return
+
+        # TODO(hayk): Sanity checks on clip
+
+        if mono and segment.channels != 1:
+            segment = segment.set_channels(1)
+        elif not mono and segment.channels != 2:
+            segment = segment.set_channels(2)
+
+        # Frame rate
+        if segment.frame_rate != params.sample_rate:
+            segment = segment.set_frame_rate(params.sample_rate)
+
+        # Convert
+        image = converter.spectrogram_image_from_audio(segment)
+
+        # Save
+        image_path = output_path / f"{audio_path.stem}.{image_extension}"
+        image_format = {"jpg": "JPEG", "jpeg": "JPEG", "png": "PNG"}[image_extension]
+        image.save(image_path, exif=image.getexif(), format=image_format)
+
+    # Create thread pool
+    pool = ThreadPool(processes=num_threads)
+    with tqdm.tqdm(total=len(audio_paths)) as pbar:
+        for _ in pool.imap_unordered(process_one, audio_paths):
+            pbar.update()
+
+
+def sample_clips_batch(
+    *,
+    audio_dir: str,
+    output_dir: str,
+    num_clips_per_file: int = 1,
+    duration_ms: int = 5120,
+    mono: bool = False,
+    extension: str = "mp3",
+    num_threads: T.Optional[int] = None,
+    limit: int = -1,
+    seed: int = -1,
+):
+    """
+    Sample short clips from a directory of audio files, multi-threaded.
+    """
+    audio_paths = list(Path(audio_dir).glob("*"))
+    audio_paths.sort()
+
+    if limit > 0:
+        audio_paths = audio_paths[:limit]
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    if seed >= 0:
+        random.seed(seed)
+
+    def process_one(audio_path: Path) -> None:
+        # Load, skipping files that fail to decode
+        try:
+            segment = pydub.AudioSegment.from_file(str(audio_path))
+        except Exception:
+            return
+
+        if mono:
+            segment = segment.set_channels(1)
+
+        segment_duration_ms = int(segment.duration_seconds * 1000)
+        for i in range(num_clips_per_file):
+            clip_start_ms = random.randint(0, segment_duration_ms - duration_ms)
+            clip = segment[clip_start_ms : clip_start_ms + duration_ms]
+
+            clip_name = (
+                f"{audio_path.stem}_{i}_"
+                f"start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
+            )
+            clip.export(output_path / clip_name, format=extension)
+
+    pool = ThreadPool(processes=num_threads)
+    with tqdm.tqdm(total=len(audio_paths)) as pbar:
+        for _ in pool.imap_unordered(process_one, audio_paths):
+            pbar.update()
+
+
 if __name__ == "__main__":
     argh.dispatch_commands(
         [
@@ -134,5 +264,7 @@ if __name__ == "__main__":
             image_to_audio,
             sample_clips,
             print_exif,
+            audio_to_images_batch,
+            sample_clips_batch,
         ]
     )
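
Example invocations of the two new commands, as a sketch only: it assumes argh's default mapping of underscores to dashes for command and flag names, and the directory paths are placeholders.

    python -m riffusion.cli audio-to-images-batch --audio-dir clips/ --output-dir spectrograms/
    python -m riffusion.cli sample-clips-batch --audio-dir songs/ --output-dir clips/ --num-clips-per-file 8

Each flag corresponds to a keyword-only argument of the functions above; flags left unspecified fall back to the defaults in the signatures.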