diff --git a/riffusion/spectrogram_converter.py b/riffusion/spectrogram_converter.py
new file mode 100644
index 0000000..9283e31
--- /dev/null
+++ b/riffusion/spectrogram_converter.py
@@ -0,0 +1,201 @@
+import numpy as np
+import pydub
+import torch
+import torchaudio
+import warnings
+
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.util import audio_util
+from riffusion.util import torch_util
+
+
+class SpectrogramConverter:
+    """
+    Convert between audio segments and spectrogram tensors using torchaudio.
+
+    In this class a "spectrogram" is defined as a (batch, frequency, time) tensor with float values
+    that represent the amplitude of the frequency at that time bucket (in the frequency domain).
+    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
+    used in some functions is "mel amplitudes".
+
+    The spectrogram computed by `spectrogram_from_audio` is complex valued internally, but only
+    the amplitude is returned, because the phase is chaotic and hard to learn. The function
+    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
+    approximates the phase information using the Griffin-Lim algorithm.
+
+    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
+    equal to the number of channels in the input audio segment.
+
+    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.
+
+    For more information, see https://pytorch.org/audio/stable/transforms.html
+    """
+
+    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
+        self.p = params
+
+        self.device = torch_util.check_device(device)
+
+        if device.lower().startswith("mps"):
+            warnings.warn(
+                "WARNING: MPS does not support audio operations, falling back to CPU for them",
+                stacklevel=2,
+            )
+            self.device = "cpu"
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
+        self.spectrogram_func = torchaudio.transforms.Spectrogram(
+            n_fft=params.n_fft,
+            hop_length=params.hop_length,
+            win_length=params.win_length,
+            pad=0,
+            window_fn=torch.hann_window,
+            power=None,
+            normalized=False,
+            wkwargs=None,
+            center=True,
+            pad_mode="reflect",
+            onesided=True,
+        ).to(self.device)
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
+        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
+            n_fft=params.n_fft,
+            n_iter=params.num_griffin_lim_iters,
+            win_length=params.win_length,
+            hop_length=params.hop_length,
+            window_fn=torch.hann_window,
+            power=1.0,
+            wkwargs=None,
+            momentum=0.99,
+            length=None,
+            rand_init=True,
+        ).to(self.device)
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
+        self.mel_scaler = torchaudio.transforms.MelScale(
+            n_mels=params.num_frequencies,
+            sample_rate=params.sample_rate,
+            f_min=params.min_frequency,
+            f_max=params.max_frequency,
+            n_stft=params.n_fft // 2 + 1,
+            norm=params.mel_scale_norm,
+            mel_scale=params.mel_scale_type,
+        ).to(self.device)
+
+        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
+        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
+            n_stft=params.n_fft // 2 + 1,
+            n_mels=params.num_frequencies,
+            sample_rate=params.sample_rate,
+            f_min=params.min_frequency,
+            f_max=params.max_frequency,
+            max_iter=params.max_mel_iters,
+            tolerance_loss=1e-5,
+            tolerance_change=1e-8,
+            sgdargs=None,
+            norm=params.mel_scale_norm,
+            mel_scale=params.mel_scale_type,
+        ).to(self.device)
+
+    def spectrogram_from_audio(
+        self,
+        audio: pydub.AudioSegment,
+    ) -> np.ndarray:
+        """
+        Compute a spectrogram from an audio segment.
+
+        Args:
+            audio: Audio segment which must match the sample rate of the params
+
+        Returns:
+            spectrogram: (channel, frequency, time)
+        """
+        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
+
+        # Get the samples as a numpy array in (batch, samples) shape
+        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
+
+        # Convert to floats if necessary
+        if waveform.dtype != np.float32:
+            waveform = waveform.astype(np.float32)
+
+        waveform_tensor = torch.from_numpy(waveform).to(self.device)
+        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
+        return amplitudes_mel.cpu().numpy()
+
+    def audio_from_spectrogram(
+        self,
+        spectrogram: np.ndarray,
+        apply_filters: bool = True,
+    ) -> pydub.AudioSegment:
+        """
+        Reconstruct an audio segment from a spectrogram.
+
+        Args:
+            spectrogram: (batch, frequency, time)
+            apply_filters: Post-process with normalization and compression
+
+        Returns:
+            audio: Audio segment with channels equal to the batch dimension
+        """
+        # Move to device
+        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
+
+        # Reconstruct the waveform
+        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
+
+        # Convert to audio segment
+        segment = audio_util.audio_from_waveform(
+            samples=waveform.cpu().numpy(),
+            sample_rate=self.p.sample_rate,
+            # Normalize the waveform to the range [-1, 1]
+            normalize=True,
+        )
+
+        # Optionally apply post-processing filters
+        if apply_filters:
+            segment = audio_util.apply_filters(segment)
+
+        return segment
+
+    def mel_amplitudes_from_waveform(
+        self,
+        waveform: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Torch-only function to compute Mel-scale amplitudes from a waveform.
+
+        Args:
+            waveform: (batch, samples)
+
+        Returns:
+            amplitudes_mel: (batch, frequency, time)
+        """
+        # Compute the complex-valued spectrogram
+        spectrogram_complex = self.spectrogram_func(waveform)
+
+        # Take the magnitude
+        amplitudes = torch.abs(spectrogram_complex)
+
+        # Convert to mel scale
+        return self.mel_scaler(amplitudes)
+
+    def waveform_from_mel_amplitudes(
+        self,
+        amplitudes_mel: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
+
+        Args:
+            amplitudes_mel: (batch, frequency, time)
+
+        Returns:
+            waveform: (batch, samples)
+        """
+        # Convert from mel scale to linear
+        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
+
+        # Run the approximate algorithm to compute the phase and recover the waveform
+        return self.inverse_spectrogram_func(amplitudes_linear)
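For reviewers who want to exercise the converter outside the test file below, a minimal round-trip sketch follows. It is not part of this diff: the clip path is a placeholder, device="cpu" is chosen for portability, and it assumes the remaining `SpectrogramParams` fields have usable defaults.

import pydub

from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams

# Load a clip and downmix to mono so the spectrogram has a single batch channel.
segment = pydub.AudioSegment.from_file("clip.wav").set_channels(1)  # placeholder path

# Assumes the other SpectrogramParams fields default to reasonable values.
params = SpectrogramParams(sample_rate=segment.frame_rate, stereo=False)
converter = SpectrogramConverter(params=params, device="cpu")

# Audio -> mel amplitudes -> audio; the round trip is lossy (mel scaling + Griffin-Lim).
spectrogram = converter.spectrogram_from_audio(segment)  # (channels, frequency, time)
reconstructed = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
reconstructed.export("clip_reconstructed.wav", format="wav")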
+ """ + ffts = {name: compute_fft(seg) for name, seg in segments.items()} + + fig = go.Figure( + data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()], + layout={"title": title}, + ) + fig.update_xaxes( + range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)], + type="log", + title="Frequency", + ) + fig.update_yaxes(title="Value") + fig.show() + + +def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]: + """ + Compute the FFT of the given audio segment as a mono signal. + + Returns: + frequencies: FFT computed frequencies + amplitudes: Amplitudes of each frequency + """ + # Convert to mono if needed. + if sound.channels > 1: + sound = sound.set_channels(1) + + sample_rate = sound.frame_rate + + num_samples = int(sound.frame_count()) + samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data) + + fft_values = rfft(samples) + amplitudes = np.abs(fft_values) + + frequencies = rfftfreq(n=num_samples, d=1 / sample_rate) + + return frequencies, amplitudes diff --git a/test/spectrogram_converter_test.py b/test/spectrogram_converter_test.py new file mode 100644 index 0000000..4368b43 --- /dev/null +++ b/test/spectrogram_converter_test.py @@ -0,0 +1,86 @@ +import dataclasses +import typing as T + +import pydub + +from riffusion.spectrogram_converter import SpectrogramConverter +from riffusion.spectrogram_params import SpectrogramParams +from riffusion.util import fft_util + +from .test_case import TestCase + + +class SpectrogramConverterTest(TestCase): + """ + Test going from audio to spectrogram to audio, without converting to + an image, to check quality loss of the reconstruction. + + This test allows comparing multiple sets of spectrogram params by listening to output audio + and by plotting their FFTs. + """ + + # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces. 
diff --git a/test/spectrogram_converter_test.py b/test/spectrogram_converter_test.py
new file mode 100644
index 0000000..4368b43
--- /dev/null
+++ b/test/spectrogram_converter_test.py
@@ -0,0 +1,86 @@
+import dataclasses
+import typing as T
+
+import pydub
+
+from riffusion.spectrogram_converter import SpectrogramConverter
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.util import fft_util
+
+from .test_case import TestCase
+
+
+class SpectrogramConverterTest(TestCase):
+    """
+    Test going from audio to spectrogram to audio, without converting to
+    an image, to check quality loss of the reconstruction.
+
+    This test allows comparing multiple sets of spectrogram params by listening to output audio
+    and by plotting their FFTs.
+    """
+
+    # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.
+
+    def test_round_trip(self) -> None:
+        audio_path = (
+            self.TEST_DATA_PATH
+            / "tired_traveler"
+            / "clips"
+            / "clip_2_start_103694_ms_duration_5678_ms.wav"
+        )
+        output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")
+
+        # Load up the audio file
+        segment = pydub.AudioSegment.from_file(audio_path)
+
+        # Convert to mono if desired
+        use_stereo = False
+        if use_stereo:
+            assert segment.channels == 2
+        else:
+            segment = segment.set_channels(1)
+
+        # Define named sets of parameters
+        param_sets: T.Dict[str, SpectrogramParams] = {}
+
+        param_sets["default"] = SpectrogramParams(
+            sample_rate=segment.frame_rate,
+            stereo=use_stereo,
+            step_size_ms=10,
+            min_frequency=20,
+            max_frequency=20000,
+            num_frequencies=512,
+        )
+
+        if self.DEBUG:
+            param_sets["freq_0_to_10k"] = dataclasses.replace(
+                param_sets["default"],
+                min_frequency=0,
+                max_frequency=10000,
+            )
+
+        segments: T.Dict[str, pydub.AudioSegment] = {
+            "original": segment,
+        }
+        for name, params in param_sets.items():
+            converter = SpectrogramConverter(params=params, device=self.DEVICE)
+            spectrogram = converter.spectrogram_from_audio(segment)
+            segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
+
+        # Save segments to disk
+        for name, segment in segments.items():
+            audio_out = output_dir / f"{name}.wav"
+            segment.export(audio_out, format="wav")
+            print(f"Saved {audio_out}")
+
+        # Check params
+        self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
+        self.assertEqual(segments["original"].channels, segments["default"].channels)
+        self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
+        self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
+
+        # TODO(hayk): Test something more rigorous about the quality of the reconstruction.
+
+        # If debugging, load up a browser tab plotting the FFTs
+        if self.DEBUG:
+            fft_util.plot_ffts(segments)
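One possible direction for the reconstruction-quality TODO above, sketched here rather than added to the test: compare mel amplitudes of the original and reconstructed clips and bound their relative error. The metric and the example tolerance are assumptions that would need tuning against real clips before being asserted in CI.

import numpy as np

def mel_relative_error(converter, original, reconstructed) -> float:
    """Relative L1 error between the mel amplitudes of two segments at the same sample rate."""
    spec_a = converter.spectrogram_from_audio(original)
    spec_b = converter.spectrogram_from_audio(reconstructed)
    # Griffin-Lim reconstruction can differ in length by a frame or two; crop to the overlap.
    num_frames = min(spec_a.shape[-1], spec_b.shape[-1])
    spec_a, spec_b = spec_a[..., :num_frames], spec_b[..., :num_frames]
    return float(np.abs(spec_a - spec_b).sum() / (np.abs(spec_a).sum() + 1e-8))

# Hypothetical usage inside test_round_trip, with a placeholder threshold:
# error = mel_relative_error(converter, segments["original"], segments["default"])
# self.assertLess(error, 0.25)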