Audio splitting with demucs hybrid transformer model

Topic: audio_splitter_transformer
This commit is contained in:
Hayk Martiros 2023-01-07 20:36:43 +00:00
parent f8595d7b29
commit 8e87c133c8
3 changed files with 78 additions and 3 deletions

View File

@ -1,6 +1,7 @@
accelerate accelerate
argh argh
dacite dacite
demucs
diffusers>=0.9.0 diffusers>=0.9.0
flask flask
flask_cors flask_cors

View File

@ -1,4 +1,8 @@
import shutil
import subprocess
import tempfile
import typing as T import typing as T
from pathlib import Path
import numpy as np import numpy as np
import pydub import pydub
@ -9,10 +13,65 @@ from torchaudio.transforms import Fade
from riffusion.util import audio_util from riffusion.util import audio_util
def split_audio(
    segment: pydub.AudioSegment,
    model_name: str = "htdemucs_6s",
    extension: str = "wav",
    jobs: int = 4,
    device: str = "cuda",
) -> T.Dict[str, pydub.AudioSegment]:
    """
    Split audio into stems by running the demucs command line tool.

    The segment is exported to an mp3 file in a fresh temporary directory,
    demucs is invoked on that file as a subprocess, and every produced stem
    file matching `<tmp>/<model_name>/audio/*.<extension>` is loaded back
    into memory. The temporary directory is removed before returning, whether
    or not the split succeeded.

    Args:
        segment: Audio to split.
        model_name: Passed to demucs via `--name`; also the name of the output
            sub-directory that is searched for stems.
        extension: File extension of the stem files to load. When it is "mp3"
            the `--mp3` flag is added to the demucs command; no flag is added
            for any other value.
        jobs: Passed to demucs via `--jobs`.
        device: Passed to demucs via `--device`, except that "mps" is replaced
            by "cpu" (presumably because demucs cannot run on MPS -- confirm).

    Returns:
        Mapping from stem file name (without extension) to the loaded audio.
        Empty if no output file matched the expected pattern.

    Raises:
        subprocess.CalledProcessError: If demucs exits with a non-zero status.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_"))

    try:
        # Save the audio to a temporary file
        audio_path = tmp_dir / "audio.mp3"
        segment.export(audio_path, format="mp3")

        # Assemble command
        command = [
            "demucs",
            str(audio_path),
            "--name",
            model_name,
            "--out",
            str(tmp_dir),
            "--jobs",
            str(jobs),
            "--device",
            device if device != "mps" else "cpu",
        ]
        if extension == "mp3":
            command.append("--mp3")

        # Log only once the command is complete, so the output matches what runs
        print(" ".join(command))

        # Run demucs
        subprocess.run(
            command,
            check=True,
        )

        # Load the stems fully into memory; sorted so the dict order does not
        # depend on filesystem enumeration order
        stems = {}
        for stem_path in sorted(tmp_dir.glob(f"{model_name}/audio/*.{extension}")):
            stems[stem_path.stem] = pydub.AudioSegment.from_file(stem_path)

        return stems
    finally:
        # Always delete the tmp dir, including when export or demucs fails.
        # Errors are ignored so a cleanup failure cannot mask the real exception.
        shutil.rmtree(tmp_dir, ignore_errors=True)
class AudioSplitter: class AudioSplitter:
""" """
Split audio into instrument stems like {drums, bass, vocals, etc.} Split audio into instrument stems like {drums, bass, vocals, etc.}
NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
model in the demucs repo. See the function above. Probably just delete this.
See: See:
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
""" """

View File

@ -2,6 +2,7 @@ import io
import streamlit as st import streamlit as st
from riffusion.audio_splitter import split_audio
from riffusion.streamlit import util as streamlit_util from riffusion.streamlit import util as streamlit_util
@ -32,11 +33,13 @@ def render_split_audio() -> None:
audio_file = st.file_uploader( audio_file = st.file_uploader(
"Upload audio", "Upload audio",
type=["mp3", "m4a", "ogg", "wav", "flac"], type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
label_visibility="collapsed", label_visibility="collapsed",
) )
splitter = streamlit_util.get_audio_splitter(device=device) recombine = st.sidebar.checkbox(
"Recombine", value=False, help="Show recombined audio at the end for comparison"
)
if not audio_file: if not audio_file:
st.info("Upload audio to get started") st.info("Upload audio to get started")
@ -51,7 +54,7 @@ def render_split_audio() -> None:
segment = streamlit_util.load_audio_file(audio_file) segment = streamlit_util.load_audio_file(audio_file)
# Split # Split
stems = splitter.split(segment) stems = split_audio(segment, device=device)
# Display each # Display each
for name, stem in stems.items(): for name, stem in stems.items():
@ -60,6 +63,18 @@ def render_split_audio() -> None:
stem.export(audio_bytes, format="mp3") stem.export(audio_bytes, format="mp3")
st.audio(audio_bytes) st.audio(audio_bytes)
if recombine:
stems_list = list(stems.values())
recombined = stems_list[0]
for stem in stems_list[1:]:
recombined = recombined.overlay(stem)
# Display
st.write("#### recombined")
audio_bytes = io.BytesIO()
recombined.export(audio_bytes, format="mp3")
st.audio(audio_bytes)
if __name__ == "__main__": if __name__ == "__main__":
render_split_audio() render_split_audio()