riffusion-inference/riffusion/audio.py

"""
Audio processing tools to convert between spectrogram images and waveforms.
"""
import io
import numpy as np
from PIL import Image
import pydub
from scipy.io import wavfile
import torch
import torchaudio


def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
    """
    Reconstruct a WAV audio clip from a spectrogram image.
    """
    max_volume = 50
    power_for_image = 0.25
    Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)

    sample_rate = 44100  # [Hz]
    clip_duration_ms = 5000  # [ms]

    bins_per_image = 512
    n_mels = 512

    # FFT parameters
    window_duration_ms = 100  # [ms]
    padded_duration_ms = 400  # [ms]
    step_size_ms = 10  # [ms]

    # Derived parameters
    num_samples = int(image.width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate)
    n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
    hop_length = int(step_size_ms / 1000.0 * sample_rate)
    win_length = int(window_duration_ms / 1000.0 * sample_rate)

    waveform = waveform_from_spectrogram(
        Sxx=Sxx,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        num_samples=num_samples,
        sample_rate=sample_rate,
        mel_scale=True,
        n_mels=n_mels,
        max_mel_iters=200,
        num_griffin_lim_iters=32,
    )

    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, waveform.astype(np.int16))
    wav_bytes.seek(0)

    return wav_bytes
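

# Example usage (a minimal sketch; "spectrogram.png" and "clip.wav" are
# placeholder paths, not files shipped with this repo):
#
#   image = Image.open("spectrogram.png").convert("RGB")
#   wav_bytes = wav_bytes_from_spectrogram_image(image)
#   with open("clip.wav", "wb") as f:
#       f.write(wav_bytes.read())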


def spectrogram_from_image(
    image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
) -> np.ndarray:
    """
    Compute a spectrogram magnitude array from a spectrogram image.

    TODO(hayk): Add image_from_spectrogram and call this out as the reverse.
    """
    # Convert to a numpy array of floats
    data = np.array(image).astype(np.float32)

    # Flip Y and take a single channel
    data = data[::-1, :, 0]

    # Invert
    data = 255 - data

    # Rescale to max volume
    data = data * max_volume / 255

    # Reverse the power curve
    data = np.power(data, 1 / power_for_image)

    return data
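

# A sketch of the inverse conversion mentioned in the TODO above: it reverses
# each step of spectrogram_from_image. The function name and its exact
# behavior are assumptions, not part of the original module.
def image_from_spectrogram(
    spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
) -> Image.Image:
    """
    Compute a spectrogram image from a spectrogram magnitude array.

    Approximate inverse of spectrogram_from_image.
    """
    # Apply the power curve used for display
    data = np.power(spectrogram, power_for_image)

    # Rescale from volume to 0-255 pixel values
    data = data * 255 / max_volume

    # Invert
    data = 255 - data

    # Flip Y back and replicate into three identical channels
    data = data[::-1, :]
    data = np.clip(data, 0, 255).astype(np.uint8)
    data = np.stack([data] * 3, axis=-1)

    return Image.fromarray(data)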


def spectrogram_from_waveform(
    waveform: np.ndarray,
    sample_rate: int,
    n_fft: int,
    hop_length: int,
    win_length: int,
    mel_scale: bool = True,
    n_mels: int = 512,
) -> np.ndarray:
    """
    Compute a spectrogram from a waveform.
    """
    spectrogram_func = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        power=None,
        hop_length=hop_length,
        win_length=win_length,
    )

    waveform_tensor = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)
    Sxx_complex = spectrogram_func(waveform_tensor).numpy()[0]
    Sxx_mag = np.abs(Sxx_complex)

    if mel_scale:
        mel_scaler = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,
            n_stft=n_fft // 2 + 1,
            norm=None,
            mel_scale="htk",
        )

        Sxx_mag = mel_scaler(torch.from_numpy(Sxx_mag)).numpy()

    return Sxx_mag
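

# Example (a sketch; "clip.wav" is a placeholder path to a mono 44100 Hz file,
# and the FFT parameters match those derived in wav_bytes_from_spectrogram_image):
#
#   sample_rate, waveform = wavfile.read("clip.wav")
#   Sxx = spectrogram_from_waveform(
#       waveform=waveform,
#       sample_rate=sample_rate,
#       n_fft=17640,      # 400 ms of padding at 44100 Hz
#       hop_length=441,   # 10 ms step
#       win_length=4410,  # 100 ms window
#   )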


def waveform_from_spectrogram(
    Sxx: np.ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    num_samples: int,
    sample_rate: int,
    mel_scale: bool = True,
    n_mels: int = 512,
    max_mel_iters: int = 200,
    num_griffin_lim_iters: int = 32,
    device: str = "cuda:0",
) -> np.ndarray:
    """
    Reconstruct a waveform from a spectrogram.

    This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim algorithm
    to approximate the phase.
    """
    Sxx_torch = torch.from_numpy(Sxx).to(device)

    # TODO(hayk): Make this a class that caches the two things

    if mel_scale:
        mel_inv_scaler = torchaudio.transforms.InverseMelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,
            n_stft=n_fft // 2 + 1,
            norm=None,
            mel_scale="htk",
            max_iter=max_mel_iters,
        ).to(device)

        Sxx_torch = mel_inv_scaler(Sxx_torch)

    griffin_lim = torchaudio.transforms.GriffinLim(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=1.0,
        n_iter=num_griffin_lim_iters,
    ).to(device)

    waveform = griffin_lim(Sxx_torch).cpu().numpy()

    return waveform
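

# Example round trip (a sketch, assuming a CUDA device; pass device="cpu"
# otherwise). Griffin-Lim recovers only approximate phase, so the output is
# close to but not identical to the input waveform:
#
#   Sxx = spectrogram_from_waveform(waveform, 44100, n_fft=17640, hop_length=441, win_length=4410)
#   reconstructed = waveform_from_spectrogram(
#       Sxx,
#       n_fft=17640,
#       hop_length=441,
#       win_length=4410,
#       num_samples=len(waveform),
#       sample_rate=44100,
#   )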


def mp3_bytes_from_wav_bytes(wav_bytes: io.BytesIO) -> io.BytesIO:
    """
    Convert WAV bytes to MP3 bytes using pydub.
    """
    mp3_bytes = io.BytesIO()
    sound = pydub.AudioSegment.from_wav(wav_bytes)
    sound.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)
    return mp3_bytes
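

# Example (a sketch; pydub's mp3 export relies on an ffmpeg or libav binary
# being installed, and "clip.mp3" is a placeholder path):
#
#   mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
#   with open("clip.mp3", "wb") as f:
#       f.write(mp3_bytes.read())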