Merge pull request #36 from riffusion/hayk.mart/revup/main/clean_rewrite

Rewrite the codebase to be high quality
This commit is contained in:
Hayk Martiros 2022-12-26 17:50:51 -08:00 committed by GitHub
commit d0fe85a4db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 1963 additions and 576 deletions

6
.gitignore vendored
View File

@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
# VSCode
.vscode
# Distribution / packaging
.Python
build/
@ -27,6 +30,9 @@ share/python-wheels/
*.egg
MANIFEST
# OSX cruft
.DS_Store
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.

6
CITATION Normal file
View File

@ -0,0 +1,6 @@
@article{Forsgren_Martiros_2022,
author = {Forsgren, Seth* and Martiros, Hayk*},
title = {{Riffusion - Stable diffusion for real-time music generation}},
url = {https://riffusion.com/about},
year = {2022}
}

View File

@ -22,7 +22,7 @@ Tested with Python 3.9 and diffusers 0.9.0.
To run this model, you need a GPU with CUDA. To run it in real time, it needs to be able to run stable diffusion with approximately 50 steps in under five seconds.
You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
```
conda create --name riffusion-inference python=3.9
@ -32,14 +32,16 @@ python -m pip install -r requirements.txt
If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
You can open and save WAV files with pure python. For opening and saving non-wav files like mp3 you'll need ffmpeg or libav.
Guides:
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
## Run
## Run the model server
Start the Flask server:
```
python -m riffusion.server --port 3013 --host 127.0.0.1
python -m riffusion.server --host 127.0.0.1 --port 3013
```
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
@ -77,6 +79,52 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
}
```
Use the `--device` argument to specify the torch device to use.
`cuda` is recommended.
`cpu` works but is quite slow.
`mps` is supported for inference, but some operations fall back to CPU. You may need to set
PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.
## Test
Tests live in the `test/` directory and are implemented with `unittest`.
To run all tests:
```
python -m unittest test/*_test.py
```
To run a single test:
```
python -m unittest test.audio_to_image_test
```
To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
```
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
```
To run a single test case:
```
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
```
To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with
`cpu`, `cuda`, and `mps` backends.
## Development
Install additional packages for dev with `pip install -r dev_requirements.txt`.
* Linter: `ruff`
* Formatter: `black`
* Type checker: `mypy`
These are configured in `pyproject.toml`.
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
## Citation
If you build on this work, please cite it as follows:

View File

@ -1,6 +1,7 @@
black
ipdb
isort
mypy
pylint
ruff
types-Flask-Cors
types-Pillow
types-requests

3
integrations/README.md Normal file
View File

@ -0,0 +1,3 @@
# Integrations
This package contains integrations of Riffusion into third party apps and deployments.

0
integrations/__init__.py Normal file
View File

84
integrations/baseten.py Normal file
View File

@ -0,0 +1,84 @@
"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.
For more on the Truss file format, see https://truss.baseten.co/
"""
import typing as T
import torch
import dacite
from huggingface_hub import snapshot_download
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.server import compute_request
from riffusion.datatypes import InferenceInput
class Model:
"""
Baseten Truss model class for riffusion.
See: https://truss.baseten.co/reference/structure#model.py
"""
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._pipeline = None
self._vae = None
self.checkpoint_name = "riffusion/riffusion-model-v1"
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
def load(self):
"""
Load the model. Guaranteed to be called before `predict`.
"""
self._pipeline = RiffusionPipeline.load_checkpoint(
checkpoint=self.checkpoint_name,
use_traced_unet=True,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
def preprocess(self, request: T.Dict) -> T.Dict:
"""
Incorporate pre-processing required by the model if desired here.
These might be feature transformations that are tightly coupled to the model.
"""
return request
def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
"""
This is the main function that is called.
"""
assert self._pipeline is not None, "Model pipeline not loaded"
try:
inputs = dacite.from_dict(InferenceInput, request)
except dacite.exceptions.WrongTypeError as exception:
return str(exception), 400
except dacite.exceptions.MissingValueError as exception:
return str(exception), 400
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
response = compute_request(
inputs=inputs,
pipeline=self._pipeline,
seed_images_dir=self._seed_images_dir,
)
return response
def postprocess(self, request: T.Dict) -> T.Dict:
"""
Incorporate post-processing required by the model if desired here.
"""
return request

87
pyproject.toml Normal file
View File

@ -0,0 +1,87 @@
[tool.black]
line-length = 100
[tool.ruff]
line-length = 100
# Which rules to run
select = [
# Pyflakes
"F",
# Pycodestyle
"E",
"W",
# isort
# "I001"
]
ignore = []
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
per-file-ignores = {}
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
# Assume Python 3.10.
target-version = "py310"
[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10
[tool.mypy]
python_version = "3.10"
[[tool.mypy.overrides]]
module = "argh.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "diffusers.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "plotly.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "pydub.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.fft.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.io.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "torchaudio.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "transformers.*"
ignore_missing_imports = true

View File

@ -6,9 +6,11 @@ flask
flask_cors
numpy
pillow
plotly
pydub
scipy
soundfile
streamlit
torch
torchaudio
transformers

View File

@ -1,213 +0,0 @@
"""
Audio processing tools to convert between spectrogram images and waveforms.
"""
import io
import typing as T
import numpy as np
from PIL import Image
import pydub
from scipy.io import wavfile
import torch
import torchaudio
def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
"""
Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
"""
max_volume = 50
power_for_image = 0.25
Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)
sample_rate = 44100 # [Hz]
clip_duration_ms = 5000 # [ms]
bins_per_image = 512
n_mels = 512
# FFT parameters
window_duration_ms = 100 # [ms]
padded_duration_ms = 400 # [ms]
step_size_ms = 10 # [ms]
# Derived parameters
num_samples = int(image.width / float(bins_per_image) * clip_duration_ms) * sample_rate
n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
hop_length = int(step_size_ms / 1000.0 * sample_rate)
win_length = int(window_duration_ms / 1000.0 * sample_rate)
samples = waveform_from_spectrogram(
Sxx=Sxx,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
num_samples=num_samples,
sample_rate=sample_rate,
mel_scale=True,
n_mels=n_mels,
max_mel_iters=200,
num_griffin_lim_iters=32,
)
wav_bytes = io.BytesIO()
wavfile.write(wav_bytes, sample_rate, samples.astype(np.int16))
wav_bytes.seek(0)
duration_s = float(len(samples)) / sample_rate
return wav_bytes, duration_s
def spectrogram_from_image(
image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
) -> np.ndarray:
"""
Compute a spectrogram magnitude array from a spectrogram image.
TODO(hayk): Add image_from_spectrogram and call this out as the reverse.
"""
# Convert to a numpy array of floats
data = np.array(image).astype(np.float32)
# Flip Y take a single channel
data = data[::-1, :, 0]
# Invert
data = 255 - data
# Rescale to max volume
data = data * max_volume / 255
# Reverse the power curve
data = np.power(data, 1 / power_for_image)
return data
def image_from_spectrogram(
spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
) -> Image.Image:
"""
Compute a spectrogram image from a spectrogram magnitude array.
"""
# Apply the power curve
data = np.power(spectrogram, power_for_image)
# Rescale to 0-1
data = data / np.max(data)
# Rescale to 0-255
data = data * 255
# Invert
data = 255 - data
# Convert to a PIL image
image = Image.fromarray(data.astype(np.uint8))
# Flip Y
image = image.transpose(Image.FLIP_TOP_BOTTOM)
# Convert to RGB
image = image.convert("RGB")
return image
def spectrogram_from_waveform(
waveform: np.ndarray,
sample_rate: int,
n_fft: int,
hop_length: int,
win_length: int,
mel_scale: bool = True,
n_mels: int = 512,
) -> np.ndarray:
"""
Compute a spectrogram from a waveform.
"""
spectrogram_func = torchaudio.transforms.Spectrogram(
n_fft=n_fft,
power=None,
hop_length=hop_length,
win_length=win_length,
)
waveform_tensor = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)
Sxx_complex = spectrogram_func(waveform_tensor).numpy()[0]
Sxx_mag = np.abs(Sxx_complex)
if mel_scale:
mel_scaler = torchaudio.transforms.MelScale(
n_mels=n_mels,
sample_rate=sample_rate,
f_min=0,
f_max=10000,
n_stft=n_fft // 2 + 1,
norm=None,
mel_scale="htk",
)
Sxx_mag = mel_scaler(torch.from_numpy(Sxx_mag)).numpy()
return Sxx_mag
def waveform_from_spectrogram(
Sxx: np.ndarray,
n_fft: int,
hop_length: int,
win_length: int,
num_samples: int,
sample_rate: int,
mel_scale: bool = True,
n_mels: int = 512,
max_mel_iters: int = 200,
num_griffin_lim_iters: int = 32,
device: str = "cuda:0",
) -> np.ndarray:
"""
Reconstruct a waveform from a spectrogram.
This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim algorithm
to approximate the phase.
"""
Sxx_torch = torch.from_numpy(Sxx).to(device)
# TODO(hayk): Make this a class that caches the two things
if mel_scale:
mel_inv_scaler = torchaudio.transforms.InverseMelScale(
n_mels=n_mels,
sample_rate=sample_rate,
f_min=0,
f_max=10000,
n_stft=n_fft // 2 + 1,
norm=None,
mel_scale="htk",
max_iter=max_mel_iters,
).to(device)
Sxx_torch = mel_inv_scaler(Sxx_torch)
griffin_lim = torchaudio.transforms.GriffinLim(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
power=1.0,
n_iter=num_griffin_lim_iters,
).to(device)
waveform = griffin_lim(Sxx_torch).cpu().numpy()
return waveform
def mp3_bytes_from_wav_bytes(wav_bytes: io.BytesIO) -> io.BytesIO:
mp3_bytes = io.BytesIO()
sound = pydub.AudioSegment.from_wav(wav_bytes)
sound.export(mp3_bytes, format="mp3")
mp3_bytes.seek(0)
return mp3_bytes

View File

@ -1,175 +0,0 @@
"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.
For more on the Truss file format, see https://truss.baseten.co/
"""
import base64
import dataclasses
import json
import io
from pathlib import Path
from typing import Dict, List
import PIL
import torch
import dacite
from huggingface_hub import hf_hub_download, snapshot_download
from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput, InferenceOutput
from .riffusion_pipeline import RiffusionPipeline
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._model = None
self._vae = None
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download(
"riffusion/riffusion-model-v1", allow_patterns="*.png"
)
def load(self):
# Load Riffusion model here and assign to self._model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available() == False:
# Use only if you don't have a GPU with fp16 support
self._model = RiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
else:
# Model loading the model with fp16. This will fail if ran without a GPU with fp16 support
pipe = RiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
revision="fp16",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
# Deliberately not implementing channels_Last as it resulted in slower inference pipeline
# pipe.unet.to(memory_format=torch.channels_last)
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
# Use traced unet from hf hub
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = pipe.unet.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
pipe.unet = TracedUNet()
self._model = pipe
def preprocess(self, request: Dict) -> Dict:
"""
Incorporate pre-processing required by the model if desired here.
These might be feature transformations that are tightly coupled to the model.
"""
return request
def postprocess(self, request: Dict) -> Dict:
"""
Incorporate post-processing required by the model if desired here.
"""
return request
def predict(self, request: Dict) -> Dict[str, List]:
"""
This is the main function that is called.
"""
# Example request:
# {"alpha":0.25,"num_inference_steps":50,"seed_image_id":"og_beat","mask_image_id":None,"start":{"prompt":"lo-fi beat for the holidays","seed":906295,"denoising":0.75,"guidance":7},"end":{"prompt":"lo-fi beat for the holidays","seed":906296,"denoising":0.75,"guidance":7}}
# Parse an InferenceInput dataclass from the payload
try:
inputs = dacite.from_dict(InferenceInput, request)
except dacite.exceptions.WrongTypeError as exception:
# logging.info(json_data)
return str(exception), 400
except dacite.exceptions.MissingValueError as exception:
# logging.info(json_data)
return str(exception), 400
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
response = self.compute(inputs)
return response
def compute(self, inputs: InferenceInput) -> str:
"""
Does all the heavy lifting of the request.
"""
# Load the seed image by ID
init_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.seed_image_id}.png")
if not init_image_path.is_file():
return f"Invalid seed image: {inputs.seed_image_id}", 400
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Load the mask image by ID
if inputs.mask_image_id:
mask_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.mask_image_id}.png")
if not mask_image_path.is_file():
return f"Invalid mask image: {inputs.mask_image_id}", 400
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
else:
mask_image = None
# Execute the model to get the spectrogram image
image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image)
# Reconstruct audio from the image
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
# Compute the output as base64 encoded strings
image_bytes = self.image_bytes_from_image(image, mode="JPEG")
# Assemble the output dataclass
output = InferenceOutput(
image="data:image/jpeg;base64," + self.base64_encode(image_bytes),
audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes),
duration_s=duration_s,
)
return json.dumps(dataclasses.asdict(output))
def image_bytes_from_image(self, image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
"""
Convert a PIL image into bytes of the given image format.
"""
image_bytes = io.BytesIO()
image.save(image_bytes, mode)
image_bytes.seek(0)
return image_bytes
def base64_encode(self, buffer: io.BytesIO) -> str:
"""
Encode the given buffer as base64.
"""
return base64.encodebytes(buffer.getvalue()).decode("ascii")

141
riffusion/cli.py Normal file
View File

@ -0,0 +1,141 @@
"""
Command line tools for riffusion.
"""
from pathlib import Path
import argh
import numpy as np
from PIL import Image
import pydub
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
def audio_to_image(
*,
audio: str,
image: str,
step_size_ms: int = 10,
num_frequencies: int = 512,
min_frequency: int = 0,
max_frequency: int = 10000,
window_duration_ms: int = 100,
padded_duration_ms: int = 400,
power_for_image: float = 0.25,
stereo: bool = False,
device: str = "cuda",
):
"""
Compute a spectrogram image from a waveform.
"""
segment = pydub.AudioSegment.from_file(audio)
params = SpectrogramParams(
sample_rate=segment.frame_rate,
stereo=stereo,
window_duration_ms=window_duration_ms,
padded_duration_ms=padded_duration_ms,
step_size_ms=step_size_ms,
min_frequency=min_frequency,
max_frequency=max_frequency,
num_frequencies=num_frequencies,
power_for_image=power_for_image,
)
converter = SpectrogramImageConverter(params=params, device=device)
pil_image = converter.spectrogram_image_from_audio(segment)
pil_image.save(image, exif=pil_image.getexif(), format="PNG")
print(f"Wrote {image}")
def print_exif(*, image: str) -> None:
"""
Print the params of a spectrogram image as saved in the exif data.
"""
pil_image = Image.open(image)
exif_data = image_util.exif_from_image(pil_image)
for name, value in exif_data.items():
print(f"{name:<20} = {value:>15}")
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
"""
Reconstruct an audio clip from a spectrogram image.
"""
pil_image = Image.open(image)
# Get parameters from image exif
img_exif = pil_image.getexif()
assert img_exif is not None
try:
params = SpectrogramParams.from_exif(exif=img_exif)
except KeyError:
print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
params = SpectrogramParams()
converter = SpectrogramImageConverter(params=params, device=device)
segment = converter.audio_from_spectrogram_image(pil_image)
extension = Path(audio).suffix[1:]
segment.export(audio, format=extension)
print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
def sample_clips(
*,
audio: str,
output_dir: str,
num_clips: int = 1,
duration_ms: int = 5000,
mono: bool = False,
extension: str = "wav",
seed: int = -1,
):
"""
Slice an audio file into clips of the given duration.
"""
if seed >= 0:
np.random.seed(seed)
segment = pydub.AudioSegment.from_file(audio)
if mono:
segment = segment.set_channels(1)
output_dir_path = Path(output_dir)
if not output_dir_path.exists():
output_dir_path.mkdir(parents=True)
# TODO(hayk): Might be a lot easier with pydub
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
clip_path = output_dir_path / clip_name
clip.export(clip_path, format=extension)
print(f"Wrote {clip_path}")
if __name__ == "__main__":
argh.dispatch_commands(
[
audio_to_image,
image_to_audio,
sample_clips,
print_exif,
]
)

View File

@ -1,6 +1,7 @@
"""
Data model for the riffusion API.
"""
from __future__ import annotations
from dataclasses import dataclass
import typing as T
@ -58,6 +59,7 @@ class InferenceOutput:
"""
Response from the model inference server.
"""
# base64 encoded spectrogram image as a JPEG
image: str

3
riffusion/external/README.md vendored Normal file
View File

@ -0,0 +1,3 @@
# external
This package contains scripts and tools from external sources.

0
riffusion/external/__init__.py vendored Normal file
View File

View File

@ -5,10 +5,13 @@ This code is taken from the diffusers community pipeline:
License: Apache 2.0
"""
import re
from typing import List, Optional, Union
# ruff: noqa
# mypy: ignore-errors
import logging
import re
import typing as T
import torch
from diffusers import StableDiffusionPipeline
@ -123,7 +126,7 @@ def parse_prompt_attention(text):
return res
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
@ -192,8 +195,8 @@ def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
no_boseos_middle: T.Optional[bool] = True,
) -> torch.FloatTensor:
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
@ -232,14 +235,14 @@ def get_unweighted_text_embeddings(
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
prompt: T.Union[str, T.List[str]],
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
max_embeddings_multiples: T.Optional[int] = 3,
no_boseos_middle: T.Optional[bool] = False,
skip_parsing: T.Optional[bool] = False,
skip_weighting: T.Optional[bool] = False,
**kwargs,
):
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
@ -248,9 +251,9 @@ def get_weighted_text_embeddings(
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
prompt (`str` or `T.List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
uncond_prompt (`str` or `T.List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
@ -269,8 +272,6 @@ def get_weighted_text_embeddings(
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
print(f"tokens: {prompt_tokens}")
print(f"weights: {prompt_weights}")
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):

View File

@ -1,13 +1,15 @@
"""
Riffusion inference pipeline.
"""
from __future__ import annotations
import dataclasses
import functools
import inspect
import typing as T
import numpy as np
import PIL
from PIL import Image
import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@ -15,9 +17,12 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from huggingface_hub import hf_hub_download
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from .datatypes import InferenceInput
from riffusion.datatypes import InferenceInput
from riffusion.external.prompt_weighting import get_weighted_text_embeddings
from riffusion.util import torch_util
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -56,8 +61,110 @@ class RiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
@classmethod
def load_checkpoint(
cls,
checkpoint: str,
use_traced_unet: bool = True,
channels_last: bool = False,
dtype: torch.dtype = torch.float16,
device: str = "cuda",
) -> RiffusionPipeline:
"""
Load the riffusion model pipeline.
Args:
checkpoint: Model checkpoint on disk in diffusers format
use_traced_unet: Whether to use the traced unet for speedups
device: Device to load the model on
channels_last: Whether to use channels_last memory format
"""
device = torch_util.check_device(device)
if device == "cpu" or device.lower().startswith("mps"):
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
pipeline = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
# Disable the NSFW filter, causes incorrect false positives
# TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
safety_checker=lambda images, **kwargs: (images, False),
# Optionally attempt to use less memory
low_cpu_mem_usage=False,
).to(device)
if channels_last:
pipeline.unet.to(memory_format=torch.channels_last)
# Optionally load a traced unet
if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
traced_unet = cls.load_traced_unet(
checkpoint=checkpoint,
subfolder="unet_traced",
filename="unet_traced.pt",
in_channels=pipeline.unet.in_channels,
dtype=dtype,
device=device,
)
if traced_unet is not None:
pipeline.unet = traced_unet
model = pipeline.to(device)
return model
@staticmethod
def load_traced_unet(
checkpoint: str,
subfolder: str,
filename: str,
in_channels: int,
dtype: torch.dtype,
device: str = "cuda",
) -> T.Optional[torch.nn.Module]:
"""
Load a traced unet from the huggingface hub. This can improve performance.
"""
if device == "cpu" or device.lower().startswith("mps"):
print("WARNING: Traced UNet only available for CUDA, skipping")
return None
# Download and load the traced unet
unet_file = hf_hub_download(
checkpoint,
subfolder=subfolder,
filename=filename,
)
unet_traced = torch.jit.load(unet_file)
# Wrap it in a torch module
class TracedUNet(torch.nn.Module):
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
def __init__(self):
super().__init__()
self.in_channels = device
self.device = device
self.dtype = dtype
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return self.UNet2DConditionOutput(sample=sample)
return TracedUNet()
@property
def device(self) -> str:
return str(self.vae.device)
@functools.lru_cache()
def embed_text(self, text):
def embed_text(self, text) -> torch.FloatTensor:
"""
Takes in text and turns it into text embeddings.
"""
@ -73,12 +180,10 @@ class RiffusionPipeline(DiffusionPipeline):
return embed
@functools.lru_cache()
def embed_text_weighted(self, text):
def embed_text_weighted(self, text) -> torch.FloatTensor:
"""
Get text embedding with weights.
"""
from .prompt_weighting import get_weighted_text_embeddings
return get_weighted_text_embeddings(
pipe=self,
prompt=text,
@ -93,10 +198,10 @@ class RiffusionPipeline(DiffusionPipeline):
def riffuse(
self,
inputs: InferenceInput,
init_image: PIL.Image.Image,
mask_image: PIL.Image.Image = None,
init_image: Image.Image,
mask_image: T.Optional[Image.Image] = None,
use_reweighting: bool = True,
) -> PIL.Image.Image:
) -> Image.Image:
"""
Runs inference using interpolation with both img2img and text conditioning.
@ -113,8 +218,14 @@ class RiffusionPipeline(DiffusionPipeline):
end = inputs.end
guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
# TODO(hayk): Always generate the seed on CPU?
if self.device.lower().startswith("mps"):
generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
else:
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
# Text encodings
if use_reweighting:
@ -123,25 +234,31 @@ class RiffusionPipeline(DiffusionPipeline):
else:
embed_start = self.embed_text(start.prompt)
embed_end = self.embed_text(end.prompt)
text_embedding = torch.lerp(embed_start, embed_end, alpha)
text_embedding = embed_start + alpha * (embed_end - embed_start)
# Image latents
init_image = preprocess_image(init_image)
init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
init_image_torch = preprocess_image(init_image).to(
device=self.device, dtype=embed_start.dtype
)
init_latent_dist = self.vae.encode(init_image_torch).latent_dist
# TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
# result is so close no matter the seed that it doesn't really add variety.
generator = torch.Generator(device=self.device).manual_seed(start.seed)
if self.device.lower().startswith("mps"):
generator = torch.Generator(device="cpu").manual_seed(start.seed)
else:
generator = torch.Generator(device=self.device).manual_seed(start.seed)
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Prepare mask latent
mask: T.Optional[torch.Tensor] = None
if mask_image:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
else:
mask = None
mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
device=self.device, dtype=embed_start.dtype
)
outputs = self.interpolate_img2img(
text_embeddings=text_embedding,
@ -161,18 +278,18 @@ class RiffusionPipeline(DiffusionPipeline):
@torch.no_grad()
def interpolate_img2img(
self,
text_embeddings: torch.FloatTensor,
init_latents: torch.FloatTensor,
text_embeddings: torch.Tensor,
init_latents: torch.Tensor,
generator_a: torch.Generator,
generator_b: torch.Generator,
interpolate_alpha: float,
mask: T.Optional[torch.FloatTensor] = None,
mask: T.Optional[torch.Tensor] = None,
strength_a: float = 0.8,
strength_b: float = 0.8,
num_inference_steps: T.Optional[int] = 50,
guidance_scale: T.Optional[float] = 7.5,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
num_images_per_prompt: T.Optional[int] = 1,
num_images_per_prompt: int = 1,
eta: T.Optional[float] = 0.0,
output_type: T.Optional[str] = "pil",
**kwargs,
@ -198,11 +315,6 @@ class RiffusionPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
if negative_prompt is None:
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
@ -251,11 +363,11 @@ class RiffusionPipeline(DiffusionPipeline):
noise_b = torch.randn(
init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
)
noise = slerp(interpolate_alpha, noise_a, noise_b)
noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
init_latents_orig = init_latents
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# prepare extra kwargs for the scheduler step, since not all schedulers have the same args
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
@ -295,7 +407,9 @@ class RiffusionPipeline(DiffusionPipeline):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if mask is not None:
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
init_latents_proper = self.scheduler.add_noise(
init_latents_orig, noise, torch.tensor([t])
)
# import ipdb; ipdb.set_trace()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
@ -311,62 +425,42 @@ class RiffusionPipeline(DiffusionPipeline):
return dict(images=image, latents=latents, nsfw_content_detected=False)
def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
def preprocess_image(image: Image.Image) -> torch.Tensor:
"""
Preprocess an image for the model.
"""
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
image = image.resize((w, h), resample=Image.LANCZOS)
image_np = np.array(image).astype(np.float32) / 255.0
image_np = image_np[None].transpose(0, 3, 1, 2)
image_torch = torch.from_numpy(image_np)
return 2.0 * image_torch - 1.0
def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
"""
Preprocess a mask for the model.
"""
# Convert to grayscale
mask = mask.convert("L")
# Resize to integer multiple of 32
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize(
(w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
mask = torch.from_numpy(mask)
w, h = map(lambda x: x - x % 32, (w, h))
mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
return mask
# Convert to numpy array and rescale
mask_np = np.array(mask).astype(np.float32) / 255.0
# Tile and transpose
mask_np = np.tile(mask_np, (4, 1, 1))
mask_np = mask_np[None].transpose(0, 1, 2, 3) # what does this step do?
def slerp(t, v0, v1, dot_threshold=0.9995):
"""
Helper function to spherically interpolate two arrays v1 v2.
"""
# Invert to repaint white and keep black
mask_np = 1 - mask_np # repaint white, keep black
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > dot_threshold:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(input_device)
return v2
return torch.from_numpy(mask_np)

View File

@ -1,8 +1,7 @@
"""
Inference server for the riffusion project.
Flask server that serves the riffusion model as an API.
"""
import base64
import dataclasses
import logging
import io
@ -16,15 +15,13 @@ import flask
from flask_cors import CORS
import PIL
import torch
from huggingface_hub import hf_hub_download
from .audio import wav_bytes_from_spectrogram_image
from .audio import mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput
from .datatypes import InferenceOutput
from .riffusion_pipeline import RiffusionPipeline
from riffusion.datatypes import InferenceInput
from riffusion.datatypes import InferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import base64_util
# Flask app with CORS
app = flask.Flask(__name__)
@ -35,7 +32,7 @@ logging.basicConfig(level=logging.INFO)
logging.getLogger().addHandler(logging.FileHandler("server.log"))
# Global variable for the model pipeline
MODEL = None
PIPELINE: T.Optional[RiffusionPipeline] = None
# Where built-in seed images are stored
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
@ -45,8 +42,9 @@ def run_app(
*,
checkpoint: str = "riffusion/riffusion-model-v1",
no_traced_unet: bool = False,
device: str = "cuda",
host: str = "127.0.0.1",
port: int = 3000,
port: int = 3013,
debug: bool = False,
ssl_certificate: T.Optional[str] = None,
ssl_key: T.Optional[str] = None,
@ -55,8 +53,12 @@ def run_app(
Run a flask API that serves the given riffusion model checkpoint.
"""
# Initialize the model
global MODEL
MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)
global PIPELINE
PIPELINE = RiffusionPipeline.load_checkpoint(
checkpoint=checkpoint,
use_traced_unet=not no_traced_unet,
device=device,
)
args = dict(
debug=debug,
@ -69,51 +71,7 @@ def run_app(
assert ssl_key is not None
args["ssl_context"] = (ssl_certificate, ssl_key)
app.run(**args)
def load_model(checkpoint: str, traced_unet: bool = True):
"""
Load the riffusion model pipeline.
"""
assert torch.cuda.is_available()
model = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
).to("cuda")
# Set the traced unet if desired
if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
# Using traced unet from hf hub
unet_file = hf_hub_download(
checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = model.unet.in_channels
self.device = model.unet.device
self.dtype = torch.float16
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
model.unet = TracedUNet()
model = model.to("cuda")
return model
app.run(**args) # type: ignore
@app.route("/run_inference/", methods=["POST"])
@ -145,7 +103,11 @@ def run_inference():
logging.info(json_data)
return str(exception), 400
response = compute(inputs)
response = compute_request(
inputs=inputs,
seed_images_dir=SEED_IMAGES_DIR,
pipeline=PIPELINE,
)
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")
@ -153,60 +115,73 @@ def run_inference():
return response
def compute(inputs: InferenceInput) -> str:
def compute_request(
inputs: InferenceInput,
pipeline: RiffusionPipeline,
seed_images_dir: str,
) -> T.Union[str, T.Tuple[str, int]]:
"""
Does all the heavy lifting of the request.
Args:
inputs: The input dataclass
pipeline: The riffusion model pipeline
seed_images_dir: The directory where seed images are stored
"""
# Load the seed image by ID
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
if not init_image_path.is_file():
return f"Invalid seed image: {inputs.seed_image_id}", 400
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Load the mask image by ID
mask_image: T.Optional[PIL.Image.Image] = None
if inputs.mask_image_id:
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
if not mask_image_path.is_file():
return f"Invalid mask image: {inputs.mask_image_id}", 400
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
else:
mask_image = None
# Execute the model to get the spectrogram image
image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
image = pipeline.riffuse(
inputs,
init_image=init_image,
mask_image=mask_image,
)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
# Reconstruct audio from the image
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
# TODO(hayk): It may help performance to cache this object
converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
segment = converter.audio_from_spectrogram_image(
image,
apply_filters=True,
)
# Compute the output as base64 encoded strings
image_bytes = image_bytes_from_image(image, mode="JPEG")
# Export audio to MP3 bytes
mp3_bytes = io.BytesIO()
segment.export(mp3_bytes, format="mp3")
mp3_bytes.seek(0)
# Export image to JPEG bytes
image_bytes = io.BytesIO()
image.save(image_bytes, exif=image.getexif(), format="JPEG")
image_bytes.seek(0)
# Assemble the output dataclass
output = InferenceOutput(
image="data:image/jpeg;base64," + base64_encode(image_bytes),
audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
duration_s=duration_s,
image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
duration_s=segment.duration_seconds,
)
return flask.jsonify(dataclasses.asdict(output))
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
"""
Convert a PIL image into bytes of the given image format.
"""
image_bytes = io.BytesIO()
image.save(image_bytes, mode)
image_bytes.seek(0)
return image_bytes
def base64_encode(buffer: io.BytesIO) -> str:
"""
Encode the given buffer as base64.
"""
return base64.encodebytes(buffer.getvalue()).decode("ascii")
return json.dumps(dataclasses.asdict(output))
if __name__ == "__main__":

View File

@ -0,0 +1,201 @@
import numpy as np
import pydub
import torch
import torchaudio
import warnings
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util
from riffusion.util import torch_util
class SpectrogramConverter:
"""
Convert between audio segments and spectrogram tensors using torchaudio.
In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
that represent the amplitude of the frequency at that time bucket (in the frequency domain).
Frequencies are given in the perceptul Mel scale defined by the params. A more specific term
used in some functions is "mel amplitudes".
The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
returns the amplitude, because the phase is chaotic and hard to learn. The function
`audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
approximates the phase information using the Griffin-Lim algorithm.
Each channel in the audio is treated independently, and the spectrogram has a batch dimension
equal to the number of channels in the input audio segment.
Both the Griffin Lim algorithm and the Mel scaling process are lossy.
For more information, see https://pytorch.org/audio/stable/transforms.html
"""
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
self.p = params
self.device = torch_util.check_device(device)
if device.lower().startswith("mps"):
warnings.warn(
"WARNING: MPS does not support audio operations, falling back to CPU for them",
stacklevel=2,
)
self.device = "cpu"
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
self.spectrogram_func = torchaudio.transforms.Spectrogram(
n_fft=params.n_fft,
hop_length=params.hop_length,
win_length=params.win_length,
pad=0,
window_fn=torch.hann_window,
power=None,
normalized=False,
wkwargs=None,
center=True,
pad_mode="reflect",
onesided=True,
).to(self.device)
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
n_fft=params.n_fft,
n_iter=params.num_griffin_lim_iters,
win_length=params.win_length,
hop_length=params.hop_length,
window_fn=torch.hann_window,
power=1.0,
wkwargs=None,
momentum=0.99,
length=None,
rand_init=True,
).to(self.device)
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
self.mel_scaler = torchaudio.transforms.MelScale(
n_mels=params.num_frequencies,
sample_rate=params.sample_rate,
f_min=params.min_frequency,
f_max=params.max_frequency,
n_stft=params.n_fft // 2 + 1,
norm=params.mel_scale_norm,
mel_scale=params.mel_scale_type,
).to(self.device)
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
n_stft=params.n_fft // 2 + 1,
n_mels=params.num_frequencies,
sample_rate=params.sample_rate,
f_min=params.min_frequency,
f_max=params.max_frequency,
max_iter=params.max_mel_iters,
tolerance_loss=1e-5,
tolerance_change=1e-8,
sgdargs=None,
norm=params.mel_scale_norm,
mel_scale=params.mel_scale_type,
).to(self.device)
def spectrogram_from_audio(
self,
audio: pydub.AudioSegment,
) -> np.ndarray:
"""
Compute a spectrogram from an audio segment.
Args:
audio: Audio segment which must match the sample rate of the params
Returns:
spectrogram: (channel, frequency, time)
"""
assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
# Get the samples as a numpy array in (batch, samples) shape
waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
# Convert to floats if necessary
if waveform.dtype != np.float32:
waveform = waveform.astype(np.float32)
waveform_tensor = torch.from_numpy(waveform).to(self.device)
amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
return amplitudes_mel.cpu().numpy()
def audio_from_spectrogram(
self,
spectrogram: np.ndarray,
apply_filters: bool = True,
) -> pydub.AudioSegment:
"""
Reconstruct an audio segment from a spectrogram.
Args:
spectrogram: (batch, frequency, time)
apply_filters: Post-process with normalization and compression
Returns:
audio: Audio segment with channels equal to the batch dimension
"""
# Move to device
amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
# Reconstruct the waveform
waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
# Convert to audio segment
segment = audio_util.audio_from_waveform(
samples=waveform.cpu().numpy(),
sample_rate=self.p.sample_rate,
# Normalize the waveform to the range [-1, 1]
normalize=True,
)
# Optionally apply post-processing filters
if apply_filters:
segment = audio_util.apply_filters(segment)
return segment
def mel_amplitudes_from_waveform(
self,
waveform: torch.Tensor,
) -> torch.Tensor:
"""
Torch-only function to compute Mel-scale amplitudes from a waveform.
Args:
waveform: (batch, samples)
Returns:
amplitudes_mel: (batch, frequency, time)
"""
# Compute the complex-valued spectrogram
spectrogram_complex = self.spectrogram_func(waveform)
# Take the magnitude
amplitudes = torch.abs(spectrogram_complex)
# Convert to mel scale
return self.mel_scaler(amplitudes)
def waveform_from_mel_amplitudes(
self,
amplitudes_mel: torch.Tensor,
) -> torch.Tensor:
"""
Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
Args:
amplitudes_mel: (batch, frequency, time)
Returns:
waveform: (batch, samples)
"""
# Convert from mel scale to linear
amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
# Run the approximate algorithm to compute the phase and recover the waveform
return self.inverse_spectrogram_func(amplitudes_linear)

View File

@ -0,0 +1,91 @@
import numpy as np
from PIL import Image
import pydub
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
class SpectrogramImageConverter:
"""
Convert between spectrogram images and audio segments.
This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
to images and back. The real audio processing lives in SpectrogramConverter.
"""
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
self.p = params
self.device = device
self.converter = SpectrogramConverter(params=params, device=device)
def spectrogram_image_from_audio(
self,
segment: pydub.AudioSegment,
) -> Image.Image:
"""
Compute a spectrogram image from an audio segment.
Args:
segment: Audio segment to convert
Returns:
Spectrogram image (in pillow format)
"""
assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"
if self.p.stereo:
if segment.channels == 1:
print("WARNING: Mono audio but stereo=True, cloning channel")
segment = segment.set_channels(2)
elif segment.channels > 2:
print("WARNING: Multi channel audio, reducing to stereo")
segment = segment.set_channels(2)
else:
if segment.channels > 1:
print("WARNING: Stereo audio but stereo=False, setting to mono")
segment = segment.set_channels(1)
spectrogram = self.converter.spectrogram_from_audio(segment)
image = image_util.image_from_spectrogram(
spectrogram,
power=self.p.power_for_image,
)
# Store conversion params in exif metadata of the image
exif_data = self.p.to_exif()
exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
exif = image.getexif()
exif.update(exif_data.items())
return image
def audio_from_spectrogram_image(
self,
image: Image.Image,
apply_filters: bool = True,
max_value: float = 30e6,
) -> pydub.AudioSegment:
"""
Reconstruct an audio segment from a spectrogram image.
Args:
image: Spectrogram image (in pillow format)
apply_filters: Apply post-processing to improve the reconstructed audio
max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
"""
spectrogram = image_util.spectrogram_from_image(
image,
max_value=max_value,
power=self.p.power_for_image,
stereo=self.p.stereo,
)
segment = self.converter.audio_from_spectrogram(
spectrogram,
apply_filters=apply_filters,
)
return segment

View File

@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
import typing as T
@dataclass(frozen=True)
class SpectrogramParams:
"""
Parameters for the conversion from audio to spectrograms to images and back.
Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
within spectrogram images.
"""
# Whether the audio is stereo or mono
stereo: bool = False
# FFT parameters
sample_rate: int = 44100
step_size_ms: int = 10
window_duration_ms: int = 100
padded_duration_ms: int = 400
# Mel scale parameters
num_frequencies: int = 512
# TODO(hayk): Set these to [20, 20000] for newer models
min_frequency: int = 0
max_frequency: int = 10000
mel_scale_norm: T.Optional[str] = None
mel_scale_type: str = "htk"
max_mel_iters: int = 200
# Griffin Lim parameters
num_griffin_lim_iters: int = 32
# Image parameterization
power_for_image: float = 0.25
class ExifTags(Enum):
"""
Custom EXIF tags for the spectrogram image.
"""
SAMPLE_RATE = 11000
STEREO = 11005
STEP_SIZE_MS = 11010
WINDOW_DURATION_MS = 11020
PADDED_DURATION_MS = 11030
NUM_FREQUENCIES = 11040
MIN_FREQUENCY = 11050
MAX_FREQUENCY = 11060
POWER_FOR_IMAGE = 11070
MAX_VALUE = 11080
@property
def n_fft(self) -> int:
"""
The number of samples in each STFT window, with padding.
"""
return int(self.padded_duration_ms / 1000.0 * self.sample_rate)
@property
def win_length(self) -> int:
"""
The number of samples in each STFT window.
"""
return int(self.window_duration_ms / 1000.0 * self.sample_rate)
@property
def hop_length(self) -> int:
"""
The number of samples between each STFT window.
"""
return int(self.step_size_ms / 1000.0 * self.sample_rate)
def to_exif(self) -> T.Dict[int, T.Any]:
"""
Return a dictionary of EXIF tags for the current values.
"""
return {
self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
self.ExifTags.STEREO.value: self.stereo,
self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
}
@classmethod
def from_exif(cls, exif: T.Mapping[int, T.Any]) -> SpectrogramParams:
"""
Create a SpectrogramParams object from the EXIF tags of the given image.
"""
# TODO(hayk): Handle missing tags
return cls(
sample_rate=exif[cls.ExifTags.SAMPLE_RATE.value],
stereo=bool(exif[cls.ExifTags.STEREO.value]),
step_size_ms=exif[cls.ExifTags.STEP_SIZE_MS.value],
window_duration_ms=exif[cls.ExifTags.WINDOW_DURATION_MS.value],
padded_duration_ms=exif[cls.ExifTags.PADDED_DURATION_MS.value],
num_frequencies=exif[cls.ExifTags.NUM_FREQUENCIES.value],
min_frequency=exif[cls.ExifTags.MIN_FREQUENCY.value],
max_frequency=exif[cls.ExifTags.MAX_FREQUENCY.value],
power_for_image=exif[cls.ExifTags.POWER_FOR_IMAGE.value],
)

View File

View File

@ -0,0 +1,66 @@
"""
Audio utility functions.
"""
import io
import numpy as np
import pydub
from scipy.io import wavfile
def audio_from_waveform(
samples: np.ndarray, sample_rate: int, normalize: bool = False
) -> pydub.AudioSegment:
"""
Convert a numpy array of samples of a waveform to an audio segment.
"""
# Normalize volume to fit in int16
if normalize:
samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))
# Transpose and convert to int16
samples = samples.transpose(1, 0)
samples = samples.astype(np.int16)
# Write to the bytes of a WAV file
wav_bytes = io.BytesIO()
wavfile.write(wav_bytes, sample_rate, samples)
wav_bytes.seek(0)
# Read into pydub
return pydub.AudioSegment.from_wav(wav_bytes)
def apply_filters(segment: pydub.AudioSegment) -> pydub.AudioSegment:
"""
Apply post-processing filters to the audio segment to compress it and
keep at a -10 dBFS level.
"""
# TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
# TODO(hayk): Is this going to make audio unbalanced between sequential clips?
segment = pydub.effects.normalize(
segment,
headroom=0.1,
)
segment = segment.apply_gain(-10 - segment.dBFS)
segment = pydub.effects.compress_dynamic_range(
segment,
threshold=-20.0,
ratio=4.0,
attack=5.0,
release=50.0,
)
desired_db = -12
segment = segment.apply_gain(desired_db - segment.dBFS)
segment = pydub.effects.normalize(
segment,
headroom=0.1,
)
return segment

View File

@ -0,0 +1,9 @@
import base64
import io
def encode(buffer: io.BytesIO) -> str:
"""
Encode the given buffer as base64.
"""
return base64.encodebytes(buffer.getvalue()).decode("ascii")

View File

@ -0,0 +1,60 @@
"""
FFT tools to analyze frequency content of audio segments. This is not code for
dealing with spectrogram images, but for analysis of waveforms.
"""
import struct
import typing as T
import numpy as np
import plotly.graph_objects as go
import pydub
from scipy.fft import rfft, rfftfreq
def plot_ffts(
segments: T.Dict[str, pydub.AudioSegment],
title: str = "FFT",
min_frequency: float = 20,
max_frequency: float = 20000,
) -> None:
"""
Plot an FFT analysis of the given audio segments.
"""
ffts = {name: compute_fft(seg) for name, seg in segments.items()}
fig = go.Figure(
data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
layout={"title": title},
)
fig.update_xaxes(
range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
type="log",
title="Frequency",
)
fig.update_yaxes(title="Value")
fig.show()
def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
"""
Compute the FFT of the given audio segment as a mono signal.
Returns:
frequencies: FFT computed frequencies
amplitudes: Amplitudes of each frequency
"""
# Convert to mono if needed.
if sound.channels > 1:
sound = sound.set_channels(1)
sample_rate = sound.frame_rate
num_samples = int(sound.frame_count())
samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
fft_values = rfft(samples)
amplitudes = np.abs(fft_values)
frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
return frequencies, amplitudes

View File

@ -0,0 +1,118 @@
"""
Module for converting between spectrograms tensors and spectrogram images, as well as
general helpers for operating on pillow images.
"""
import typing as T
import numpy as np
from PIL import Image
from riffusion.spectrogram_params import SpectrogramParams
def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
"""
Compute a spectrogram image from a spectrogram magnitude array.
This is the inverse of spectrogram_from_image, except for discretization error from
quantizing to uint8.
Args:
spectrogram: (channels, frequency, time)
power: A power curve to apply to the spectrogram to preserve contrast
Returns:
image: (frequency, time, channels)
"""
# Rescale to 0-1
max_value = np.max(spectrogram)
data = spectrogram / max_value
# Apply the power curve
data = np.power(data, power)
# Rescale to 0-255
data = data * 255
# Invert
data = 255 - data
# Convert to uint8
data = data.astype(np.uint8)
# Munge channels into a PIL image
if data.shape[0] == 1:
# TODO(hayk): Do we want to write single channel to disk instead?
image = Image.fromarray(data[0], mode="L").convert("RGB")
elif data.shape[0] == 2:
data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
image = Image.fromarray(data, mode="RGB")
else:
raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")
# Flip Y
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
return image
def spectrogram_from_image(
image: Image.Image,
power: float = 0.25,
stereo: bool = False,
max_value: float = 30e6,
) -> np.ndarray:
"""
Compute a spectrogram magnitude array from a spectrogram image.
This is the inverse of image_from_spectrogram, except for discretization error from
quantizing to uint8.
Args:
image: (frequency, time, channels)
power: The power curve applied to the spectrogram
stereo: Whether the spectrogram encodes stereo data
max_value: The max value of the original spectrogram. In practice doesn't matter.
Returns:
spectrogram: (channels, frequency, time)
"""
# Flip Y
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
# Munge channels into a numpy array of (channels, frequency, time)
data = np.array(image).transpose(2, 0, 1)
if stereo:
# Take the G and B channels as done in image_from_spectrogram
data = data[[1, 2], :, :]
else:
data = data[0:1, :, :]
# Convert to floats
data = data.astype(np.float32)
# Invert
data = 255 - data
# Rescale to 0-1
data = data / 255
# Reverse the power curve
data = np.power(data, 1 / power)
# Rescale to max value
data = data * max_value
return data
def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]:
"""
Get the EXIF data from a PIL image as a dict.
"""
exif = pil_image.getexif()
if exif is None or len(exif) == 0:
return {}
return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()}

View File

@ -0,0 +1,48 @@
import warnings
import numpy as np
import torch
def check_device(device: str, backup: str = "cpu") -> str:
"""
Check that the device is valid and available. If not,
"""
cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()
if cuda_not_found or mps_not_found:
warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
return backup
return device
def slerp(
t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995
) -> torch.Tensor:
"""
Helper function to spherically interpolate two arrays v1 v2.
"""
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > dot_threshold:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(input_device)
return v2

0
test/__init__.py Normal file
View File

View File

@ -0,0 +1,99 @@
import typing as T
import numpy as np
from PIL import Image
from riffusion.cli import audio_to_image
from riffusion.spectrogram_params import SpectrogramParams
from .test_case import TestCase
class AudioToImageTest(TestCase):
"""
Test riffusion.cli audio-to-image
"""
@classmethod
def default_params(cls) -> T.Dict:
return dict(
step_size_ms=10,
num_frequencies=512,
# TODO(hayk): Change these to [20, 20000] once a model is updated
min_frequency=0,
max_frequency=10000,
stereo=False,
device=cls.DEVICE,
)
def test_audio_to_image(self) -> None:
"""
Test audio-to-image with default params.
"""
params = self.default_params()
self.helper_test_with_params(params)
def test_stereo(self) -> None:
"""
Test audio-to-image with stereo=True.
"""
params = self.default_params()
params["stereo"] = True
self.helper_test_with_params(params)
def helper_test_with_params(self, params: T.Dict) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
output_dir = self.get_tmp_dir("audio_to_image_")
if params["stereo"]:
stem = f"{audio_path.stem}_stereo"
else:
stem = audio_path.stem
image_path = output_dir / f"{stem}.png"
audio_to_image(audio=str(audio_path), image=str(image_path), **params)
# Check that the image exists
self.assertTrue(image_path.exists())
pil_image = Image.open(image_path)
# Check the image mode
self.assertEqual(pil_image.mode, "RGB")
# Check the image dimensions
duration_ms = 5678
self.assertTrue(str(duration_ms) in audio_path.name)
expected_image_width = round(duration_ms / params["step_size_ms"])
self.assertEqual(pil_image.width, expected_image_width)
self.assertEqual(pil_image.height, params["num_frequencies"])
# Get channels as numpy arrays
channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))]
self.assertEqual(len(channels), 3)
if params["stereo"]:
# Check that the first channel is zero
self.assertTrue(np.all(channels[0] == 0))
else:
# Check that all channels are the same
self.assertTrue(np.all(channels[0] == channels[1]))
self.assertTrue(np.all(channels[0] == channels[2]))
# Check that the image has exif data
exif = pil_image.getexif()
self.assertIsNotNone(exif)
params_from_exif = SpectrogramParams.from_exif(exif)
expected_params = SpectrogramParams(
stereo=params["stereo"],
step_size_ms=params["step_size_ms"],
num_frequencies=params["num_frequencies"],
max_frequency=params["max_frequency"],
)
self.assertTrue(params_from_exif == expected_params)

View File

@ -0,0 +1,71 @@
from pathlib import Path
import pydub
from riffusion.cli import image_to_audio
from .test_case import TestCase
class ImageToAudioTest(TestCase):
"""
Test riffusion.cli image-to-audio
"""
def test_image_to_audio_mono(self) -> None:
self.helper_image_to_audio(
song_dir=self.TEST_DATA_PATH / "tired_traveler",
clip_name="clip_2_start_103694_ms_duration_5678_ms",
stereo=False,
)
def test_image_to_audio_stereo(self) -> None:
self.helper_image_to_audio(
song_dir=self.TEST_DATA_PATH / "tired_traveler",
clip_name="clip_2_start_103694_ms_duration_5678_ms",
stereo=True,
)
def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
if stereo:
image_stem = clip_name + "_stereo"
else:
image_stem = clip_name
image_path = song_dir / "images" / f"{image_stem}.png"
output_dir = self.get_tmp_dir("image_to_audio_")
audio_path = output_dir / f"{image_path.stem}.wav"
image_to_audio(
image=str(image_path),
audio=str(audio_path),
device=self.DEVICE,
)
# Check that the audio exists
self.assertTrue(audio_path.exists())
# Load the reconstructed audio and the original clip
segment = pydub.AudioSegment.from_file(str(audio_path))
expected_segment = pydub.AudioSegment.from_file(
str(song_dir / "clips" / f"{clip_name}.wav")
)
# Check sample rate
self.assertEqual(segment.frame_rate, expected_segment.frame_rate)
# Check duration
actual_duration_ms = round(segment.duration_seconds * 1000)
expected_duration_ms = round(expected_segment.duration_seconds * 1000)
self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
# Check the number of channels
self.assertEqual(expected_segment.channels, 2)
if stereo:
self.assertEqual(segment.channels, 2)
else:
self.assertEqual(segment.channels, 1)
if __name__ == "__main__":
TestCase.main()

65
test/image_util_test.py Normal file
View File

@ -0,0 +1,65 @@
import numpy as np
import pydub
from riffusion.util import image_util
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from .test_case import TestCase
class ImageUtilTest(TestCase):
"""
Test riffusion.util.image_util
"""
def test_spectrogram_to_image_round_trip(self) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
# Load up the audio file
segment = pydub.AudioSegment.from_file(audio_path)
# Convert to mono
segment = segment.set_channels(1)
# Compute a spectrogram with default params
params = SpectrogramParams(sample_rate=segment.frame_rate)
converter = SpectrogramConverter(params=params, device=self.DEVICE)
spectrogram = converter.spectrogram_from_audio(segment)
# Compute the image from the spectrogram
image = image_util.image_from_spectrogram(
spectrogram=spectrogram,
power=params.power_for_image,
)
# Save the max value
max_value = np.max(spectrogram)
# Compute the spectrogram from the image
spectrogram_reversed = image_util.spectrogram_from_image(
image=image,
max_value=max_value,
power=params.power_for_image,
stereo=params.stereo,
)
# Check the shapes
self.assertEqual(spectrogram.shape, spectrogram_reversed.shape)
# Check the max values
self.assertEqual(np.max(spectrogram), np.max(spectrogram_reversed))
# Check the median values
self.assertTrue(
np.allclose(np.median(spectrogram), np.median(spectrogram_reversed), rtol=0.05)
)
# Make sure all values are somewhat similar, but allow for discretization error
# TODO(hayk): Investigate error more closely
self.assertTrue(np.allclose(spectrogram, spectrogram_reversed, rtol=0.15))

24
test/linter_test.py Normal file
View File

@ -0,0 +1,24 @@
from pathlib import Path
import subprocess
from .test_case import TestCase
class LinterTest(TestCase):
"""
Test that ruff, black, and mypy run cleanly.
"""
HOME = Path(__file__).parent.parent
def test_ruff(self) -> None:
code = subprocess.check_call(["ruff", str(self.HOME)])
self.assertEqual(code, 0)
def test_black(self) -> None:
code = subprocess.check_call(["black", "--check", str(self.HOME)])
self.assertEqual(code, 0)
def test_mypy(self) -> None:
code = subprocess.check_call(["mypy", str(self.HOME)])
self.assertEqual(code, 0)

32
test/print_exif_test.py Normal file
View File

@ -0,0 +1,32 @@
import contextlib
import io
from riffusion.cli import print_exif
from .test_case import TestCase
class PrintExifTest(TestCase):
"""
Test riffusion.cli print-exif
"""
def test_print_exif(self) -> None:
"""
Test print-exif.
"""
image_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "images"
/ "clip_2_start_103694_ms_duration_5678_ms.png"
)
# Redirect stdout
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
print_exif(image=str(image_path))
# Check that a couple of values are printed
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())

88
test/sample_clips_test.py Normal file
View File

@ -0,0 +1,88 @@
import typing as T
import pydub
from riffusion.cli import sample_clips
from .test_case import TestCase
class SampleClipsTest(TestCase):
"""
Test riffusion.cli sample-clips
"""
@staticmethod
def default_params() -> T.Dict:
return dict(
num_clips=3,
duration_ms=5678,
mono=False,
extension="wav",
seed=42,
)
def test_sample_clips(self) -> None:
"""
Test sample-clips with default params.
"""
params = self.default_params()
self.helper_test_with_params(params)
def test_mono(self) -> None:
"""
Test sample-clips with mono=True.
"""
params = self.default_params()
params["mono"] = True
params["num_clips"] = 1
self.helper_test_with_params(params)
def test_mp3(self) -> None:
"""
Test sample-clips with extension=mp3.
"""
if pydub.AudioSegment.converter is None:
self.skipTest("skipping, ffmpeg not found")
params = self.default_params()
params["extension"] = "mp3"
params["num_clips"] = 1
self.helper_test_with_params(params)
def helper_test_with_params(self, params: T.Dict) -> None:
"""
Test sample-clips with the given params.
"""
audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
output_dir = self.get_tmp_dir("sample_clips_")
sample_clips(
audio=str(audio_path),
output_dir=str(output_dir),
**params,
)
# For each file in output dir
counter = 0
for clip_path in output_dir.iterdir():
# Check that it has the right extension
self.assertEqual(clip_path.suffix, f".{params['extension']}")
# Check that it has the right duration
segment = pydub.AudioSegment.from_file(clip_path)
self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
# Check that it has the right number of channels
if params["mono"]:
self.assertEqual(segment.channels, 1)
else:
self.assertEqual(segment.channels, 2)
counter += 1
self.assertEqual(counter, params["num_clips"])
if __name__ == "__main__":
TestCase.main()

View File

@ -0,0 +1,86 @@
import dataclasses
import typing as T
import pydub
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util
from .test_case import TestCase
class SpectrogramConverterTest(TestCase):
"""
Test going from audio to spectrogram to audio, without converting to
an image, to check quality loss of the reconstruction.
This test allows comparing multiple sets of spectrogram params by listening to output audio
and by plotting their FFTs.
"""
# TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.
def test_round_trip(self) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")
# Load up the audio file
segment = pydub.AudioSegment.from_file(audio_path)
# Convert to mono if desired
use_stereo = False
if use_stereo:
assert segment.channels == 2
else:
segment = segment.set_channels(1)
# Define named sets of parameters
param_sets: T.Dict[str, SpectrogramParams] = {}
param_sets["default"] = SpectrogramParams(
sample_rate=segment.frame_rate,
stereo=use_stereo,
step_size_ms=10,
min_frequency=20,
max_frequency=20000,
num_frequencies=512,
)
if self.DEBUG:
param_sets["freq_0_to_10k"] = dataclasses.replace(
param_sets["default"],
min_frequency=0,
max_frequency=10000,
)
segments: T.Dict[str, pydub.AudioSegment] = {
"original": segment,
}
for name, params in param_sets.items():
converter = SpectrogramConverter(params=params, device=self.DEVICE)
spectrogram = converter.spectrogram_from_audio(segment)
segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
# Save segments to disk
for name, segment in segments.items():
audio_out = output_dir / f"{name}.wav"
segment.export(audio_out, format="wav")
print(f"Saved {audio_out}")
# Check params
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
self.assertEqual(segments["original"].channels, segments["default"].channels)
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
# If debugging, load up a browser tab plotting the FFTs
if self.DEBUG:
fft_util.plot_ffts(segments)

View File

@ -0,0 +1,97 @@
import dataclasses
import typing as T
from PIL import Image
import pydub
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util
from .test_case import TestCase
class SpectrogramImageConverterTest(TestCase):
"""
Test going from audio to spectrogram images to audio, testing the quality loss of the
end-to-end pipeline.
This test allows comparing multiple sets of spectrogram params by listening to output audio
and by plotting their FFTs.
See spectrogram_converter_test.py for a similar test that does not convert to images.
"""
def test_round_trip(self) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_")
# Load up the audio file
segment = pydub.AudioSegment.from_file(audio_path)
# Convert to mono if desired
use_stereo = False
if use_stereo:
assert segment.channels == 2
else:
segment = segment.set_channels(1)
# Define named sets of parameters
param_sets: T.Dict[str, SpectrogramParams] = {}
param_sets["default"] = SpectrogramParams(
sample_rate=segment.frame_rate,
stereo=use_stereo,
step_size_ms=10,
min_frequency=20,
max_frequency=20000,
num_frequencies=512,
)
if self.DEBUG:
param_sets["freq_0_to_10k"] = dataclasses.replace(
param_sets["default"],
min_frequency=0,
max_frequency=10000,
)
segments: T.Dict[str, pydub.AudioSegment] = {
"original": segment,
}
images: T.Dict[str, Image.Image] = {}
for name, params in param_sets.items():
converter = SpectrogramImageConverter(params=params, device=self.DEVICE)
images[name] = converter.spectrogram_image_from_audio(segment)
segments[name] = converter.audio_from_spectrogram_image(
image=images[name],
apply_filters=True,
)
# Save images to disk
for name, image in images.items():
image_out = output_dir / f"{name}.png"
image.save(image_out, exif=image.getexif(), format="PNG")
print(f"Saved {image_out}")
# Save segments to disk
for name, segment in segments.items():
audio_out = output_dir / f"{name}.wav"
segment.export(audio_out, format="wav")
print(f"Saved {audio_out}")
# Check params
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
self.assertEqual(segments["original"].channels, segments["default"].channels)
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
# If debugging, load up a browser tab plotting the FFTs
if self.DEBUG:
fft_util.plot_ffts(segments)

48
test/test_case.py Normal file
View File

@ -0,0 +1,48 @@
import os
from pathlib import Path
import shutil
import tempfile
import typing as T
import warnings
import unittest
class TestCase(unittest.TestCase):
"""
Base class for tests.
"""
# Where checked-in test data is stored
TEST_DATA_PATH = Path(__file__).resolve().parent / "test_data"
# Whether to run tests in debug mode (e.g. don't clean up temporary directories, show plots)
DEBUG = bool(os.environ.get("RIFFUSION_TEST_DEBUG"))
# Which torch device to use for tests
DEVICE = os.environ.get("RIFFUSION_TEST_DEVICE", "cuda")
@staticmethod
def main(*args: T.Any, **kwargs: T.Any) -> None:
"""
Run the tests.
"""
unittest.main(*args, **kwargs)
@classmethod
def setUpClass(cls):
warnings.filterwarnings("ignore", category=ResourceWarning)
def get_tmp_dir(self, prefix: str) -> Path:
"""
Create a temporary directory.
"""
tmp_dir = tempfile.mkdtemp(prefix=prefix)
# Clean up the temporary directory if not debugging
if not self.DEBUG:
self.addCleanup(lambda: shutil.rmtree(tmp_dir, ignore_errors=True))
dir_path = Path(tmp_dir)
assert dir_path.is_dir()
return dir_path

7
test/test_data/README.md Normal file
View File

@ -0,0 +1,7 @@
# Test Data
### tired_traveler
* Song: Tired traveler on the way to home
* Artist: Andrew Codeman
* Source: https://freemusicarchive.org/

Binary file not shown.

After

Width:  |  Height:  |  Size: 258 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 382 KiB

Binary file not shown.