Audio splitting with demucs hybrid transformer model
Topic: audio_splitter_transformer
This commit is contained in:
parent
f8595d7b29
commit
8e87c133c8
|
@ -1,6 +1,7 @@
|
||||||
accelerate
|
accelerate
|
||||||
argh
|
argh
|
||||||
dacite
|
dacite
|
||||||
|
demucs
|
||||||
diffusers>=0.9.0
|
diffusers>=0.9.0
|
||||||
flask
|
flask
|
||||||
flask_cors
|
flask_cors
|
||||||
|
|
|
@ -1,4 +1,8 @@
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
import typing as T
|
import typing as T
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pydub
|
import pydub
|
||||||
|
@ -9,10 +13,65 @@ from torchaudio.transforms import Fade
|
||||||
from riffusion.util import audio_util
|
from riffusion.util import audio_util
|
||||||
|
|
||||||
|
|
||||||
|
def split_audio(
    segment: pydub.AudioSegment,
    model_name: str = "htdemucs_6s",
    extension: str = "wav",
    jobs: int = 4,
    device: str = "cuda",
) -> T.Dict[str, pydub.AudioSegment]:
    """
    Split audio into instrument stems using the demucs CLI.

    Args:
        segment: Input audio to split.
        model_name: Demucs model name (default: 6-source hybrid transformer).
        extension: File extension of the produced stems ("wav" or "mp3").
        jobs: Number of parallel demucs jobs.
        device: Torch device string. "mps" is mapped to "cpu" because demucs
            does not support the MPS backend.

    Returns:
        Mapping from stem name (e.g. "drums", "vocals") to its audio segment.

    Raises:
        subprocess.CalledProcessError: If the demucs command exits non-zero.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_"))
    try:
        # Save the audio to a temporary file for demucs to read
        audio_path = tmp_dir / "audio.mp3"
        segment.export(audio_path, format="mp3")

        # Assemble command (list argv, shell=False for safety)
        command = [
            "demucs",
            str(audio_path),
            "--name",
            model_name,
            "--out",
            str(tmp_dir),
            "--jobs",
            str(jobs),
            "--device",
            # NOTE: demucs does not support MPS; fall back to CPU
            device if device != "mps" else "cpu",
        ]
        if extension == "mp3":
            command.append("--mp3")

        # Log the fully-assembled command. The original printed before the
        # conditional --mp3 flag was appended, so the logged command was
        # incomplete for mp3 output.
        print(" ".join(command))

        # Run demucs
        subprocess.run(
            command,
            check=True,
        )

        # Load each produced stem, keyed by its filename stem (e.g. "drums")
        stems = {}
        for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"):
            stems[stem_path.stem] = pydub.AudioSegment.from_file(stem_path)
    finally:
        # Always clean up the temp dir, even if demucs fails (the original
        # leaked it on error)
        shutil.rmtree(tmp_dir, ignore_errors=True)

    return stems
|
||||||
|
|
||||||
|
|
||||||
class AudioSplitter:
|
class AudioSplitter:
|
||||||
"""
|
"""
|
||||||
Split audio into instrument stems like {drums, bass, vocals, etc.}
|
Split audio into instrument stems like {drums, bass, vocals, etc.}
|
||||||
|
|
||||||
|
NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
|
||||||
|
model in the demucs repo. See the function above. Probably just delete this.
|
||||||
|
|
||||||
See:
|
See:
|
||||||
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
|
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -2,6 +2,7 @@ import io
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from riffusion.audio_splitter import split_audio
|
||||||
from riffusion.streamlit import util as streamlit_util
|
from riffusion.streamlit import util as streamlit_util
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,11 +33,13 @@ def render_split_audio() -> None:
|
||||||
|
|
||||||
audio_file = st.file_uploader(
|
audio_file = st.file_uploader(
|
||||||
"Upload audio",
|
"Upload audio",
|
||||||
type=["mp3", "m4a", "ogg", "wav", "flac"],
|
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
||||||
label_visibility="collapsed",
|
label_visibility="collapsed",
|
||||||
)
|
)
|
||||||
|
|
||||||
splitter = streamlit_util.get_audio_splitter(device=device)
|
recombine = st.sidebar.checkbox(
|
||||||
|
"Recombine", value=False, help="Show recombined audio at the end for comparison"
|
||||||
|
)
|
||||||
|
|
||||||
if not audio_file:
|
if not audio_file:
|
||||||
st.info("Upload audio to get started")
|
st.info("Upload audio to get started")
|
||||||
|
@ -51,7 +54,7 @@ def render_split_audio() -> None:
|
||||||
segment = streamlit_util.load_audio_file(audio_file)
|
segment = streamlit_util.load_audio_file(audio_file)
|
||||||
|
|
||||||
# Split
|
# Split
|
||||||
stems = splitter.split(segment)
|
stems = split_audio(segment, device=device)
|
||||||
|
|
||||||
# Display each
|
# Display each
|
||||||
for name, stem in stems.items():
|
for name, stem in stems.items():
|
||||||
|
@ -60,6 +63,18 @@ def render_split_audio() -> None:
|
||||||
stem.export(audio_bytes, format="mp3")
|
stem.export(audio_bytes, format="mp3")
|
||||||
st.audio(audio_bytes)
|
st.audio(audio_bytes)
|
||||||
|
|
||||||
|
if recombine:
|
||||||
|
stems_list = list(stems.values())
|
||||||
|
recombined = stems_list[0]
|
||||||
|
for stem in stems_list[1:]:
|
||||||
|
recombined = recombined.overlay(stem)
|
||||||
|
|
||||||
|
# Display
|
||||||
|
st.write("#### recombined")
|
||||||
|
audio_bytes = io.BytesIO()
|
||||||
|
recombined.export(audio_bytes, format="mp3")
|
||||||
|
st.audio(audio_bytes)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
render_split_audio()
|
render_split_audio()
|
||||||
|
|
Loading…
Reference in New Issue