From 539aafde3eb388eec61dcd7d152835eccac0a21e Mon Sep 17 00:00:00 2001 From: Hayk Martiros Date: Mon, 26 Dec 2022 17:15:05 -0800 Subject: [PATCH] Pull out basic utilities into util package Topic: clean_rewrite --- riffusion/util/__init__.py | 0 riffusion/util/audio_util.py | 66 +++++++++++++++++++ riffusion/util/base64_util.py | 9 +++ riffusion/util/image_util.py | 118 ++++++++++++++++++++++++++++++++++ riffusion/util/torch_util.py | 48 ++++++++++++++ 5 files changed, 241 insertions(+) create mode 100644 riffusion/util/__init__.py create mode 100644 riffusion/util/audio_util.py create mode 100644 riffusion/util/base64_util.py create mode 100644 riffusion/util/image_util.py create mode 100644 riffusion/util/torch_util.py diff --git a/riffusion/util/__init__.py b/riffusion/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/riffusion/util/audio_util.py b/riffusion/util/audio_util.py new file mode 100644 index 0000000..251b5b8 --- /dev/null +++ b/riffusion/util/audio_util.py @@ -0,0 +1,66 @@ +""" +Audio utility functions. +""" + +import io + +import numpy as np +import pydub +from scipy.io import wavfile + + +def audio_from_waveform( + samples: np.ndarray, sample_rate: int, normalize: bool = False +) -> pydub.AudioSegment: + """ + Convert a numpy array of samples of a waveform to an audio segment. + """ + # Normalize volume to fit in int16 + if normalize: + samples *= np.iinfo(np.int16).max / np.max(np.abs(samples)) + + # Transpose and convert to int16 + samples = samples.transpose(1, 0) + samples = samples.astype(np.int16) + + # Write to the bytes of a WAV file + wav_bytes = io.BytesIO() + wavfile.write(wav_bytes, sample_rate, samples) + wav_bytes.seek(0) + + # Read into pydub + return pydub.AudioSegment.from_wav(wav_bytes) + + +def apply_filters(segment: pydub.AudioSegment) -> pydub.AudioSegment: + """ + Apply post-processing filters to the audio segment to compress it and + keep at a -10 dBFS level. + """ + # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end. + # TODO(hayk): Is this going to make audio unbalanced between sequential clips? + + segment = pydub.effects.normalize( + segment, + headroom=0.1, + ) + + segment = segment.apply_gain(-10 - segment.dBFS) + + segment = pydub.effects.compress_dynamic_range( + segment, + threshold=-20.0, + ratio=4.0, + attack=5.0, + release=50.0, + ) + + desired_db = -12 + segment = segment.apply_gain(desired_db - segment.dBFS) + + segment = pydub.effects.normalize( + segment, + headroom=0.1, + ) + + return segment diff --git a/riffusion/util/base64_util.py b/riffusion/util/base64_util.py new file mode 100644 index 0000000..84b455e --- /dev/null +++ b/riffusion/util/base64_util.py @@ -0,0 +1,9 @@ +import base64 +import io + + +def encode(buffer: io.BytesIO) -> str: + """ + Encode the given buffer as base64. + """ + return base64.encodebytes(buffer.getvalue()).decode("ascii") diff --git a/riffusion/util/image_util.py b/riffusion/util/image_util.py new file mode 100644 index 0000000..22e42fb --- /dev/null +++ b/riffusion/util/image_util.py @@ -0,0 +1,118 @@ +""" +Module for converting between spectrograms tensors and spectrogram images, as well as +general helpers for operating on pillow images. +""" +import typing as T + +import numpy as np +from PIL import Image + +from riffusion.spectrogram_params import SpectrogramParams + + +def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image: + """ + Compute a spectrogram image from a spectrogram magnitude array. + + This is the inverse of spectrogram_from_image, except for discretization error from + quantizing to uint8. + + Args: + spectrogram: (channels, frequency, time) + power: A power curve to apply to the spectrogram to preserve contrast + + Returns: + image: (frequency, time, channels) + """ + # Rescale to 0-1 + max_value = np.max(spectrogram) + data = spectrogram / max_value + + # Apply the power curve + data = np.power(data, power) + + # Rescale to 0-255 + data = data * 255 + + # Invert + data = 255 - data + + # Convert to uint8 + data = data.astype(np.uint8) + + # Munge channels into a PIL image + if data.shape[0] == 1: + # TODO(hayk): Do we want to write single channel to disk instead? + image = Image.fromarray(data[0], mode="L").convert("RGB") + elif data.shape[0] == 2: + data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0) + image = Image.fromarray(data, mode="RGB") + else: + raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}") + + # Flip Y + image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) + + return image + + +def spectrogram_from_image( + image: Image.Image, + power: float = 0.25, + stereo: bool = False, + max_value: float = 30e6, +) -> np.ndarray: + """ + Compute a spectrogram magnitude array from a spectrogram image. + + This is the inverse of image_from_spectrogram, except for discretization error from + quantizing to uint8. + + Args: + image: (frequency, time, channels) + power: The power curve applied to the spectrogram + stereo: Whether the spectrogram encodes stereo data + max_value: The max value of the original spectrogram. In practice doesn't matter. + + Returns: + spectrogram: (channels, frequency, time) + """ + # Flip Y + image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) + + # Munge channels into a numpy array of (channels, frequency, time) + data = np.array(image).transpose(2, 0, 1) + if stereo: + # Take the G and B channels as done in image_from_spectrogram + data = data[[1, 2], :, :] + else: + data = data[0:1, :, :] + + # Convert to floats + data = data.astype(np.float32) + + # Invert + data = 255 - data + + # Rescale to 0-1 + data = data / 255 + + # Reverse the power curve + data = np.power(data, 1 / power) + + # Rescale to max value + data = data * max_value + + return data + + +def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]: + """ + Get the EXIF data from a PIL image as a dict. + """ + exif = pil_image.getexif() + + if exif is None or len(exif) == 0: + return {} + + return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()} diff --git a/riffusion/util/torch_util.py b/riffusion/util/torch_util.py new file mode 100644 index 0000000..953ef5d --- /dev/null +++ b/riffusion/util/torch_util.py @@ -0,0 +1,48 @@ +import warnings + +import numpy as np +import torch + + +def check_device(device: str, backup: str = "cpu") -> str: + """ + Check that the device is valid and available. If not, + """ + cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available() + mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available() + + if cuda_not_found or mps_not_found: + warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3) + return backup + + return device + + +def slerp( + t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995 +) -> torch.Tensor: + """ + Helper function to spherically interpolate two arrays v1 v2. + """ + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + input_device = v0.device + v0 = v0.cpu().numpy() + v1 = v1.cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > dot_threshold: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(input_device) + + return v2