2023-01-07 13:36:43 -07:00
|
|
|
import shutil
|
|
|
|
import subprocess
|
|
|
|
import tempfile
|
2023-01-04 21:43:44 -07:00
|
|
|
import typing as T
|
2023-01-07 13:36:43 -07:00
|
|
|
from pathlib import Path
|
2023-01-04 21:43:44 -07:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pydub
|
|
|
|
import torch
|
|
|
|
import torchaudio
|
|
|
|
from torchaudio.transforms import Fade
|
|
|
|
|
|
|
|
from riffusion.util import audio_util
|
|
|
|
|
|
|
|
|
2023-01-07 13:36:43 -07:00
|
|
|
def split_audio(
    segment: pydub.AudioSegment,
    model_name: str = "htdemucs_6s",
    extension: str = "wav",
    jobs: int = 4,
    device: str = "cuda",
) -> T.Dict[str, pydub.AudioSegment]:
    """
    Split audio into stems using the demucs command line tool.

    The segment is exported to a temporary mp3 file, ``demucs`` is run on it as a subprocess,
    and every output file with the requested extension is loaded back as a pydub segment.
    The temporary directory is always removed, even if demucs fails.

    Args:
        segment: Audio to split.
        model_name: Demucs model name, passed as ``--name``.
        extension: Extension of the output files to load. ``"mp3"`` additionally passes
            ``--mp3`` to demucs; any other value leaves demucs's output format flags unset.
        jobs: Number of parallel jobs, passed as ``--jobs``.
        device: Device passed as ``--device``. ``"mps"`` is replaced with ``"cpu"``.

    Returns:
        Mapping from stem name (output file name without its extension) to the loaded audio.

    Raises:
        subprocess.CalledProcessError: If demucs exits with a non-zero status.
    """
    # The context manager guarantees cleanup even when demucs fails (check=True raises).
    with tempfile.TemporaryDirectory(prefix="split_audio_") as tmp_dir_name:
        tmp_dir = Path(tmp_dir_name)

        # Save the audio to a temporary file
        audio_path = tmp_dir / "audio.mp3"
        segment.export(audio_path, format="mp3")

        # Assemble command
        command = [
            "demucs",
            str(audio_path),
            "--name",
            model_name,
            "--out",
            str(tmp_dir),
            "--jobs",
            str(jobs),
            "--device",
            # NOTE(review): presumably demucs cannot run on MPS, so fall back to CPU — confirm.
            device if device != "mps" else "cpu",
        ]
        if extension == "mp3":
            command.append("--mp3")

        # Log only after every flag is appended so the printed command is the one that runs.
        print(" ".join(command))

        # Run demucs
        subprocess.run(command, check=True)

        # Load the stems. The glob assumes demucs writes <out>/<model_name>/<track>/<stem>.<ext>,
        # where <track> is the input file's stem ("audio").
        stems = {}
        for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"):
            stems[stem_path.stem] = pydub.AudioSegment.from_file(stem_path)

    return stems
|
|
|
|
|
|
|
|
|
2023-01-04 21:43:44 -07:00
|
|
|
class AudioSplitter:
    """
    Split audio into instrument stems like {drums, bass, vocals, etc.}

    NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
    model in the demucs repo. See the function above. Probably just delete this.

    See:
        https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
    """

    def __init__(
        self,
        segment_length_s: float = 10.0,
        overlap_s: float = 0.1,
        device: str = "cuda",
    ):
        """
        Load the pretrained HDemucs model onto the given device.

        Args:
            segment_length_s: Length in seconds of each chunk fed to the model.
            overlap_s: Overlap between chunks. NOTE: ``separate_sources`` uses this both as a
                fraction of ``segment_length_s`` (to size chunks) and as seconds (to size the
                cross-fade), following the torchaudio tutorial.
            device: Torch device for the model and the intermediate tensors.
        """
        self.segment_length_s = segment_length_s
        self.overlap_s = overlap_s
        self.device = device

        self.model = self.load_model().to(device)

    @staticmethod
    def load_model(model_path: str = "models/hdemucs_high_trained.pt") -> torchaudio.models.HDemucs:
        """
        Load the trained HDEMUCS pytorch model.

        Args:
            model_path: Asset path resolved via ``torchaudio.utils.download_asset``.

        Returns:
            The model in eval mode, with weights on the CPU.
        """
        # NOTE(hayk): The sources are baked into the pretrained model and can't be changed
        model = torchaudio.models.hdemucs_high(sources=["drums", "bass", "other", "vocals"])

        path = torchaudio.utils.download_asset(model_path)
        # Load onto the CPU regardless of where the checkpoint was saved; the caller moves the
        # model to its target device afterwards.
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()

        return model

    def split(self, audio: pydub.AudioSegment) -> T.Dict[str, pydub.AudioSegment]:
        """
        Split the given audio segment into instrument stems.

        Args:
            audio: Mono or stereo audio. Mono input is upmixed to stereo for the model and each
                stem is downmixed back to mono.

        Returns:
            Mapping from source name (``self.model.sources``) to that stem's audio.

        Raises:
            ValueError: If the audio has more than two channels.
        """
        if audio.channels == 1:
            audio_stereo = audio.set_channels(2)
        elif audio.channels == 2:
            audio_stereo = audio
        else:
            raise ValueError(f"Audio must be mono or stereo, but got {audio.channels} channels")

        # Get as (samples, channels) array in the segment's own sample dtype, plus a float copy
        waveform_np = np.array(audio_stereo.get_array_of_samples())
        waveform_np = waveform_np.reshape(-1, audio_stereo.channels)
        waveform_np_float = waveform_np.astype(np.float32)

        # To torch and channels-first: (channels, samples)
        waveform = torch.from_numpy(waveform_np_float).to(self.device)
        waveform = waveform.transpose(1, 0)

        # Normalize with the statistics of the channel-averaged signal
        ref = waveform.mean(0)
        ref_mean = ref.mean()
        # Silent input has zero std; clamp so normalization does not produce NaNs.
        ref_std = ref.std().clamp(min=1e-8)
        waveform = (waveform - ref_mean) / ref_std

        # Split; [None] adds a batch dim and [0] removes it -> (n_sources, channels, samples)
        sources = self.separate_sources(
            waveform[None],
            sample_rate=audio.frame_rate,
        )[0]

        # De-normalize
        sources = sources * ref_std + ref_mean

        # To numpy in the original sample dtype
        sources_np = sources.cpu().numpy()
        sample_dtype = waveform_np.dtype
        if np.issubdtype(sample_dtype, np.integer):
            # Separated stems can exceed the input's range; an unchecked integer cast would wrap
            # around (audible clicks), so round and clip to the representable range first.
            info = np.iinfo(sample_dtype)
            sources_np = np.clip(np.rint(sources_np), info.min, info.max)
        sources_np = sources_np.astype(sample_dtype)

        # Convert to pydub; each stem array is (channels, samples).
        # NOTE(review): assumes audio_util.audio_from_waveform accepts that layout — confirm.
        stem_segments = [
            audio_util.audio_from_waveform(stem_waveform, audio.frame_rate)
            for stem_waveform in sources_np
        ]

        # Convert back to mono if necessary
        if audio.channels == 1:
            stem_segments = [stem.set_channels(1) for stem in stem_segments]

        return dict(zip(self.model.sources, stem_segments))

    def separate_sources(
        self,
        waveform: torch.Tensor,
        sample_rate: int = 44100,
    ) -> torch.Tensor:
        """
        Apply model to a given waveform in chunks. Use fade and overlap to smooth the edges.

        Args:
            waveform: Tensor of shape (batch, channels, length) on ``self.device``.
            sample_rate: Sample rate used to convert chunk/overlap durations into frames.

        Returns:
            Tensor of shape (batch, n_sources, channels, length). Inputs no longer than the
            overlap are not processed and come back as zeros.
        """
        batch, channels, length = waveform.shape

        chunk_len = int(sample_rate * self.segment_length_s * (1 + self.overlap_s))
        start = 0
        end = chunk_len
        overlap_frames = self.overlap_s * sample_rate

        # The first chunk never fades in. It fades out (to cross-fade with the next chunk) only
        # if there is a next chunk; otherwise the real tail of a short clip would be silenced.
        fade = Fade(
            fade_in_len=0,
            fade_out_len=0 if end >= length else int(overlap_frames),
            fade_shape="linear",
        )

        final = torch.zeros(batch, len(self.model.sources), channels, length, device=self.device)

        # TODO(hayk): Improve this code, which came from the torchaudio docs
        while start < length - overlap_frames:
            chunk = waveform[:, :, start:end]
            with torch.no_grad():
                out = self.model.forward(chunk)
            out = fade(out)
            final[:, :, :, start:end] += out

            if start == 0:
                # Every chunk after the first fades in over the previous chunk's fade-out.
                fade.fade_in_len = int(overlap_frames)
                start += int(chunk_len - overlap_frames)
            else:
                start += chunk_len
            end += chunk_len

            # The last chunk has nothing to cross-fade into, so it must not fade out.
            if end >= length:
                fade.fade_out_len = 0

        return final
|