diff --git a/requirements.txt b/requirements.txt index a857ac0..fd97636 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ accelerate argh dacite +demucs diffusers>=0.9.0 flask flask_cors diff --git a/riffusion/audio_splitter.py b/riffusion/audio_splitter.py index c2a8310..8d90ff1 100644 --- a/riffusion/audio_splitter.py +++ b/riffusion/audio_splitter.py @@ -1,4 +1,8 @@ +import shutil +import subprocess +import tempfile import typing as T +from pathlib import Path import numpy as np import pydub @@ -9,10 +13,65 @@ from torchaudio.transforms import Fade from riffusion.util import audio_util +def split_audio( + segment: pydub.AudioSegment, + model_name: str = "htdemucs_6s", + extension: str = "wav", + jobs: int = 4, + device: str = "cuda", +) -> T.Dict[str, pydub.AudioSegment]: + """ + Split audio into stems using demucs. + """ + tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_")) + + # Save the audio to a temporary file + audio_path = tmp_dir / "audio.mp3" + segment.export(audio_path, format="mp3") + + # Assemble command + command = [ + "demucs", + str(audio_path), + "--name", + model_name, + "--out", + str(tmp_dir), + "--jobs", + str(jobs), + "--device", + device if device != "mps" else "cpu", + ] + print(" ".join(command)) + + if extension == "mp3": + command.append("--mp3") + + # Run demucs + subprocess.run( + command, + check=True, + ) + + # Load the stems + stems = {} + for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"): + stem = pydub.AudioSegment.from_file(stem_path) + stems[stem_path.stem] = stem + + # Delete tmp dir + shutil.rmtree(tmp_dir) + + return stems + + class AudioSplitter: """ Split audio into instrument stems like {drums, bass, vocals, etc.} + NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer + model in the demucs repo. See the function above. Probably just delete this. + See: https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html """ diff --git a/riffusion/streamlit/pages/split_audio.py b/riffusion/streamlit/pages/split_audio.py index c4c5241..55da427 100644 --- a/riffusion/streamlit/pages/split_audio.py +++ b/riffusion/streamlit/pages/split_audio.py @@ -2,6 +2,7 @@ import io import streamlit as st +from riffusion.audio_splitter import split_audio from riffusion.streamlit import util as streamlit_util @@ -32,11 +33,13 @@ def render_split_audio() -> None: audio_file = st.file_uploader( "Upload audio", - type=["mp3", "m4a", "ogg", "wav", "flac"], + type=["mp3", "m4a", "ogg", "wav", "flac", "webm"], label_visibility="collapsed", ) - splitter = streamlit_util.get_audio_splitter(device=device) + recombine = st.sidebar.checkbox( + "Recombine", value=False, help="Show recombined audio at the end for comparison" + ) if not audio_file: st.info("Upload audio to get started") @@ -51,7 +54,7 @@ def render_split_audio() -> None: segment = streamlit_util.load_audio_file(audio_file) # Split - stems = splitter.split(segment) + stems = split_audio(segment, device=device) # Display each for name, stem in stems.items(): @@ -60,6 +63,18 @@ def render_split_audio() -> None: stem.export(audio_bytes, format="mp3") st.audio(audio_bytes) + if recombine: + stems_list = list(stems.values()) + recombined = stems_list[0] + for stem in stems_list[1:]: + recombined = recombined.overlay(stem) + + # Display + st.write("#### recombined") + audio_bytes = io.BytesIO() + recombined.export(audio_bytes, format="mp3") + st.audio(audio_bytes) + if __name__ == "__main__": render_split_audio()