Audio splitting with demucs hybrid transformer model
Topic: audio_splitter_transformer
This commit is contained in:
parent
f8595d7b29
commit
8e87c133c8
|
@ -1,6 +1,7 @@
|
|||
accelerate
|
||||
argh
|
||||
dacite
|
||||
demucs
|
||||
diffusers>=0.9.0
|
||||
flask
|
||||
flask_cors
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
|
@ -9,10 +13,65 @@ from torchaudio.transforms import Fade
|
|||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def split_audio(
    segment: pydub.AudioSegment,
    model_name: str = "htdemucs_6s",
    extension: str = "wav",
    jobs: int = 4,
    device: str = "cuda",
) -> T.Dict[str, pydub.AudioSegment]:
    """
    Split audio into stems using the demucs command line tool.

    The segment is written to a temporary directory as an mp3, ``demucs`` is run
    on it as a subprocess, and every stem file it produces is loaded back into
    memory. The temporary directory is removed before returning, including when
    any step raises.

    Args:
        segment: Audio to split.
        model_name: Value passed to demucs via ``--name``; also the name of the
            sub-directory the stems are read from.
        extension: File extension of the stem files to load. ``"mp3"`` adds the
            ``--mp3`` flag to the command; no flag is added for other values.
            NOTE(review): presumably only "wav" (demucs default) and "mp3" will
            match any output files - confirm against the demucs CLI.
        jobs: Value passed to demucs via ``--jobs``.
        device: Value passed to demucs via ``--device``. ``"mps"`` is replaced
            with ``"cpu"``.

    Returns:
        Mapping from stem file name (without extension) to the loaded segment,
        ordered by file path. Empty if no matching stem files were produced.

    Raises:
        subprocess.CalledProcessError: If demucs exits with a non-zero status.
    """
    # The context manager guarantees the directory is deleted even if export,
    # demucs or stem loading fails part-way through.
    with tempfile.TemporaryDirectory(prefix="split_audio_") as tmp_name:
        tmp_dir = Path(tmp_name)

        # Save the audio to a temporary file
        audio_path = tmp_dir / "audio.mp3"
        segment.export(audio_path, format="mp3")

        # Assemble command
        command = [
            "demucs",
            str(audio_path),
            "--name",
            model_name,
            "--out",
            str(tmp_dir),
            "--jobs",
            str(jobs),
            "--device",
            # NOTE(review): "mps" is mapped to "cpu", presumably because demucs
            # cannot run on that backend - confirm.
            device if device != "mps" else "cpu",
        ]
        if extension == "mp3":
            command.append("--mp3")

        # Log the command only once it is complete, so the output matches what runs
        print(" ".join(command))

        # Run demucs; check=True raises if it fails
        subprocess.run(
            command,
            check=True,
        )

        # Load the stems while the files still exist. Sorting makes the order
        # of the returned dict deterministic instead of filesystem-dependent.
        stems = {}
        for stem_path in sorted(tmp_dir.glob(f"{model_name}/audio/*.{extension}")):
            stems[stem_path.stem] = pydub.AudioSegment.from_file(stem_path)

    return stems
|
||||
|
||||
|
||||
class AudioSplitter:
|
||||
"""
|
||||
Split audio into instrument stems like {drums, bass, vocals, etc.}
|
||||
|
||||
NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
|
||||
model in the demucs repo. See the function above. Probably just delete this.
|
||||
|
||||
See:
|
||||
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
|
||||
"""
|
||||
|
|
|
@ -2,6 +2,7 @@ import io
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.audio_splitter import split_audio
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
|
@ -32,11 +33,13 @@ def render_split_audio() -> None:
|
|||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload audio",
|
||||
type=["mp3", "m4a", "ogg", "wav", "flac"],
|
||||
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
|
||||
splitter = streamlit_util.get_audio_splitter(device=device)
|
||||
recombine = st.sidebar.checkbox(
|
||||
"Recombine", value=False, help="Show recombined audio at the end for comparison"
|
||||
)
|
||||
|
||||
if not audio_file:
|
||||
st.info("Upload audio to get started")
|
||||
|
@ -51,7 +54,7 @@ def render_split_audio() -> None:
|
|||
segment = streamlit_util.load_audio_file(audio_file)
|
||||
|
||||
# Split
|
||||
stems = splitter.split(segment)
|
||||
stems = split_audio(segment, device=device)
|
||||
|
||||
# Display each
|
||||
for name, stem in stems.items():
|
||||
|
@ -60,6 +63,18 @@ def render_split_audio() -> None:
|
|||
stem.export(audio_bytes, format="mp3")
|
||||
st.audio(audio_bytes)
|
||||
|
||||
if recombine:
|
||||
stems_list = list(stems.values())
|
||||
recombined = stems_list[0]
|
||||
for stem in stems_list[1:]:
|
||||
recombined = recombined.overlay(stem)
|
||||
|
||||
# Display
|
||||
st.write("#### recombined")
|
||||
audio_bytes = io.BytesIO()
|
||||
recombined.export(audio_bytes, format="mp3")
|
||||
st.audio(audio_bytes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_split_audio()
|
||||
|
|
Loading…
Reference in New Issue