Merge pull request #36 from riffusion/hayk.mart/revup/main/clean_rewrite
Rewrite the codebase to be high quality
This commit is contained in:
commit
d0fe85a4db
|
@ -6,6 +6,9 @@ __pycache__/
|
|||
# C extensions
|
||||
*.so
|
||||
|
||||
# VSCode
|
||||
.vscode
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
@ -27,6 +30,9 @@ share/python-wheels/
|
|||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# OSX cruft
|
||||
.DS_Store
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
@article{Forsgren_Martiros_2022,
|
||||
author = {Forsgren, Seth* and Martiros, Hayk*},
|
||||
title = {{Riffusion - Stable diffusion for real-time music generation}},
|
||||
url = {https://riffusion.com/about},
|
||||
year = {2022}
|
||||
}
|
52
README.md
52
README.md
|
@ -32,14 +32,16 @@ python -m pip install -r requirements.txt
|
|||
|
||||
If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
|
||||
|
||||
You can open and save WAV files with pure python. For opening and saving non-wav files – like mp3 – you'll need ffmpeg or libav.
|
||||
|
||||
Guides:
|
||||
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
|
||||
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
|
||||
|
||||
## Run
|
||||
## Run the model server
|
||||
Start the Flask server:
|
||||
```
|
||||
python -m riffusion.server --port 3013 --host 127.0.0.1
|
||||
python -m riffusion.server --host 127.0.0.1 --port 3013
|
||||
```
|
||||
|
||||
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
|
||||
|
@ -77,6 +79,52 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
|
|||
}
|
||||
```
|
||||
|
||||
Use the `--device` argument to specify the torch device to use.
|
||||
|
||||
`cuda` is recommended.
|
||||
|
||||
`cpu` works but is quite slow.
|
||||
|
||||
`mps` is supported for inference, but some operations fall back to CPU. You may need to set
|
||||
PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.
|
||||
|
||||
## Test
|
||||
Tests live in the `test/` directory and are implemented with `unittest`.
|
||||
|
||||
To run all tests:
|
||||
```
|
||||
python -m unittest test/*_test.py
|
||||
```
|
||||
|
||||
To run a single test:
|
||||
```
|
||||
python -m unittest test.audio_to_image_test
|
||||
```
|
||||
|
||||
To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
|
||||
```
|
||||
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
|
||||
```
|
||||
|
||||
To run a single test case:
|
||||
```
|
||||
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
|
||||
```
|
||||
|
||||
To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with
|
||||
`cpu`, `cuda`, and `mps` backends.
|
||||
|
||||
## Development
|
||||
Install additional packages for dev with `pip install -r dev_requirements.txt`.
|
||||
|
||||
* Linter: `ruff`
|
||||
* Formatter: `black`
|
||||
* Type checker: `mypy`
|
||||
|
||||
These are configured in `pyproject.toml`.
|
||||
|
||||
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
|
||||
|
||||
## Citation
|
||||
|
||||
If you build on this work, please cite it as follows:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
black
|
||||
ipdb
|
||||
isort
|
||||
mypy
|
||||
pylint
|
||||
ruff
|
||||
types-Flask-Cors
|
||||
types-Pillow
|
||||
types-requests
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
# Integrations
|
||||
|
||||
This package contains integrations of Riffusion into third party apps and deployments.
|
|
@ -0,0 +1,84 @@
|
|||
"""
|
||||
This file can be used to build a Truss for deployment with Baseten.
|
||||
If used, it should be renamed to model.py and placed alongside the other
|
||||
files from /riffusion in the standard /model directory of the Truss.
|
||||
|
||||
For more on the Truss file format, see https://truss.baseten.co/
|
||||
"""
|
||||
|
||||
import typing as T
|
||||
|
||||
import torch
|
||||
import dacite
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.server import compute_request
|
||||
from riffusion.datatypes import InferenceInput
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
Baseten Truss model class for riffusion.
|
||||
|
||||
See: https://truss.baseten.co/reference/structure#model.py
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._data_dir = kwargs["data_dir"]
|
||||
self._config = kwargs["config"]
|
||||
self._pipeline = None
|
||||
self._vae = None
|
||||
|
||||
self.checkpoint_name = "riffusion/riffusion-model-v1"
|
||||
|
||||
# Download entire seed image folder from huggingface hub
|
||||
self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
Load the model. Guaranteed to be called before `predict`.
|
||||
"""
|
||||
self._pipeline = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=self.checkpoint_name,
|
||||
use_traced_unet=True,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
def preprocess(self, request: T.Dict) -> T.Dict:
|
||||
"""
|
||||
Incorporate pre-processing required by the model if desired here.
|
||||
|
||||
These might be feature transformations that are tightly coupled to the model.
|
||||
"""
|
||||
return request
|
||||
|
||||
def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
|
||||
"""
|
||||
This is the main function that is called.
|
||||
"""
|
||||
assert self._pipeline is not None, "Model pipeline not loaded"
|
||||
|
||||
try:
|
||||
inputs = dacite.from_dict(InferenceInput, request)
|
||||
except dacite.exceptions.WrongTypeError as exception:
|
||||
return str(exception), 400
|
||||
except dacite.exceptions.MissingValueError as exception:
|
||||
return str(exception), 400
|
||||
|
||||
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
|
||||
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
|
||||
response = compute_request(
|
||||
inputs=inputs,
|
||||
pipeline=self._pipeline,
|
||||
seed_images_dir=self._seed_images_dir,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def postprocess(self, request: T.Dict) -> T.Dict:
|
||||
"""
|
||||
Incorporate post-processing required by the model if desired here.
|
||||
"""
|
||||
return request
|
|
@ -0,0 +1,87 @@
|
|||
[tool.black]
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
# Which rules to run
|
||||
select = [
|
||||
# Pyflakes
|
||||
"F",
|
||||
# Pycodestyle
|
||||
"E",
|
||||
"W",
|
||||
# isort
|
||||
# "I001"
|
||||
]
|
||||
|
||||
ignore = []
|
||||
|
||||
# Exclude a variety of commonly ignored directories.
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
per-file-ignores = {}
|
||||
|
||||
# Allow unused variables when underscore-prefixed.
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
# Assume Python 3.10.
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.mccabe]
|
||||
# Unlike Flake8, default to a complexity level of 10.
|
||||
max-complexity = 10
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "argh.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "diffusers.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "plotly.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "pydub.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "scipy.fft.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "scipy.io.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "torchaudio.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
ignore_missing_imports = true
|
|
@ -6,9 +6,11 @@ flask
|
|||
flask_cors
|
||||
numpy
|
||||
pillow
|
||||
plotly
|
||||
pydub
|
||||
scipy
|
||||
soundfile
|
||||
streamlit
|
||||
torch
|
||||
torchaudio
|
||||
transformers
|
||||
|
|
|
@ -1,213 +0,0 @@
|
|||
"""
|
||||
Audio processing tools to convert between spectrogram images and waveforms.
|
||||
"""
|
||||
import io
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pydub
|
||||
from scipy.io import wavfile
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
|
||||
"""
|
||||
Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
|
||||
"""
|
||||
|
||||
max_volume = 50
|
||||
power_for_image = 0.25
|
||||
Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)
|
||||
|
||||
sample_rate = 44100 # [Hz]
|
||||
clip_duration_ms = 5000 # [ms]
|
||||
|
||||
bins_per_image = 512
|
||||
n_mels = 512
|
||||
|
||||
# FFT parameters
|
||||
window_duration_ms = 100 # [ms]
|
||||
padded_duration_ms = 400 # [ms]
|
||||
step_size_ms = 10 # [ms]
|
||||
|
||||
# Derived parameters
|
||||
num_samples = int(image.width / float(bins_per_image) * clip_duration_ms) * sample_rate
|
||||
n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
|
||||
hop_length = int(step_size_ms / 1000.0 * sample_rate)
|
||||
win_length = int(window_duration_ms / 1000.0 * sample_rate)
|
||||
|
||||
samples = waveform_from_spectrogram(
|
||||
Sxx=Sxx,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
num_samples=num_samples,
|
||||
sample_rate=sample_rate,
|
||||
mel_scale=True,
|
||||
n_mels=n_mels,
|
||||
max_mel_iters=200,
|
||||
num_griffin_lim_iters=32,
|
||||
)
|
||||
|
||||
wav_bytes = io.BytesIO()
|
||||
wavfile.write(wav_bytes, sample_rate, samples.astype(np.int16))
|
||||
wav_bytes.seek(0)
|
||||
|
||||
duration_s = float(len(samples)) / sample_rate
|
||||
|
||||
return wav_bytes, duration_s
|
||||
|
||||
|
||||
def spectrogram_from_image(
|
||||
image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute a spectrogram magnitude array from a spectrogram image.
|
||||
|
||||
TODO(hayk): Add image_from_spectrogram and call this out as the reverse.
|
||||
"""
|
||||
# Convert to a numpy array of floats
|
||||
data = np.array(image).astype(np.float32)
|
||||
|
||||
# Flip Y take a single channel
|
||||
data = data[::-1, :, 0]
|
||||
|
||||
# Invert
|
||||
data = 255 - data
|
||||
|
||||
# Rescale to max volume
|
||||
data = data * max_volume / 255
|
||||
|
||||
# Reverse the power curve
|
||||
data = np.power(data, 1 / power_for_image)
|
||||
|
||||
return data
|
||||
|
||||
def image_from_spectrogram(
|
||||
spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Compute a spectrogram image from a spectrogram magnitude array.
|
||||
"""
|
||||
# Apply the power curve
|
||||
data = np.power(spectrogram, power_for_image)
|
||||
|
||||
# Rescale to 0-1
|
||||
data = data / np.max(data)
|
||||
|
||||
# Rescale to 0-255
|
||||
data = data * 255
|
||||
|
||||
# Invert
|
||||
data = 255 - data
|
||||
|
||||
# Convert to a PIL image
|
||||
image = Image.fromarray(data.astype(np.uint8))
|
||||
|
||||
# Flip Y
|
||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Convert to RGB
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
def spectrogram_from_waveform(
|
||||
waveform: np.ndarray,
|
||||
sample_rate: int,
|
||||
n_fft: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
mel_scale: bool = True,
|
||||
n_mels: int = 512,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute a spectrogram from a waveform.
|
||||
"""
|
||||
|
||||
spectrogram_func = torchaudio.transforms.Spectrogram(
|
||||
n_fft=n_fft,
|
||||
power=None,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
)
|
||||
|
||||
waveform_tensor = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)
|
||||
Sxx_complex = spectrogram_func(waveform_tensor).numpy()[0]
|
||||
|
||||
Sxx_mag = np.abs(Sxx_complex)
|
||||
|
||||
if mel_scale:
|
||||
mel_scaler = torchaudio.transforms.MelScale(
|
||||
n_mels=n_mels,
|
||||
sample_rate=sample_rate,
|
||||
f_min=0,
|
||||
f_max=10000,
|
||||
n_stft=n_fft // 2 + 1,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
)
|
||||
|
||||
Sxx_mag = mel_scaler(torch.from_numpy(Sxx_mag)).numpy()
|
||||
|
||||
return Sxx_mag
|
||||
|
||||
|
||||
def waveform_from_spectrogram(
|
||||
Sxx: np.ndarray,
|
||||
n_fft: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
num_samples: int,
|
||||
sample_rate: int,
|
||||
mel_scale: bool = True,
|
||||
n_mels: int = 512,
|
||||
max_mel_iters: int = 200,
|
||||
num_griffin_lim_iters: int = 32,
|
||||
device: str = "cuda:0",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Reconstruct a waveform from a spectrogram.
|
||||
|
||||
This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim algorithm
|
||||
to approximate the phase.
|
||||
"""
|
||||
Sxx_torch = torch.from_numpy(Sxx).to(device)
|
||||
|
||||
# TODO(hayk): Make this a class that caches the two things
|
||||
|
||||
if mel_scale:
|
||||
mel_inv_scaler = torchaudio.transforms.InverseMelScale(
|
||||
n_mels=n_mels,
|
||||
sample_rate=sample_rate,
|
||||
f_min=0,
|
||||
f_max=10000,
|
||||
n_stft=n_fft // 2 + 1,
|
||||
norm=None,
|
||||
mel_scale="htk",
|
||||
max_iter=max_mel_iters,
|
||||
).to(device)
|
||||
|
||||
Sxx_torch = mel_inv_scaler(Sxx_torch)
|
||||
|
||||
griffin_lim = torchaudio.transforms.GriffinLim(
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
power=1.0,
|
||||
n_iter=num_griffin_lim_iters,
|
||||
).to(device)
|
||||
|
||||
waveform = griffin_lim(Sxx_torch).cpu().numpy()
|
||||
|
||||
return waveform
|
||||
|
||||
|
||||
def mp3_bytes_from_wav_bytes(wav_bytes: io.BytesIO) -> io.BytesIO:
|
||||
mp3_bytes = io.BytesIO()
|
||||
sound = pydub.AudioSegment.from_wav(wav_bytes)
|
||||
sound.export(mp3_bytes, format="mp3")
|
||||
mp3_bytes.seek(0)
|
||||
return mp3_bytes
|
|
@ -1,175 +0,0 @@
|
|||
"""
|
||||
This file can be used to build a Truss for deployment with Baseten.
|
||||
If used, it should be renamed to model.py and placed alongside the other
|
||||
files from /riffusion in the standard /model directory of the Truss.
|
||||
|
||||
For more on the Truss file format, see https://truss.baseten.co/
|
||||
"""
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import json
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import dacite
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
|
||||
from .datatypes import InferenceInput, InferenceOutput
|
||||
from .riffusion_pipeline import RiffusionPipeline
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._data_dir = kwargs["data_dir"]
|
||||
self._config = kwargs["config"]
|
||||
self._model = None
|
||||
self._vae = None
|
||||
|
||||
# Download entire seed image folder from huggingface hub
|
||||
self._seed_images_dir = snapshot_download(
|
||||
"riffusion/riffusion-model-v1", allow_patterns="*.png"
|
||||
)
|
||||
|
||||
def load(self):
|
||||
# Load Riffusion model here and assign to self._model.
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if torch.cuda.is_available() == False:
|
||||
# Use only if you don't have a GPU with fp16 support
|
||||
self._model = RiffusionPipeline.from_pretrained(
|
||||
"riffusion/riffusion-model-v1",
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to(device)
|
||||
else:
|
||||
# Model loading the model with fp16. This will fail if ran without a GPU with fp16 support
|
||||
pipe = RiffusionPipeline.from_pretrained(
|
||||
"riffusion/riffusion-model-v1",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to(device)
|
||||
|
||||
# Deliberately not implementing channels_Last as it resulted in slower inference pipeline
|
||||
# pipe.unet.to(memory_format=torch.channels_last)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.FloatTensor
|
||||
|
||||
# Use traced unet from hf hub
|
||||
unet_file = hf_hub_download(
|
||||
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
class TracedUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = pipe.unet.in_channels
|
||||
self.device = pipe.unet.device
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
pipe.unet = TracedUNet()
|
||||
|
||||
self._model = pipe
|
||||
|
||||
def preprocess(self, request: Dict) -> Dict:
|
||||
"""
|
||||
Incorporate pre-processing required by the model if desired here.
|
||||
|
||||
These might be feature transformations that are tightly coupled to the model.
|
||||
"""
|
||||
return request
|
||||
|
||||
def postprocess(self, request: Dict) -> Dict:
|
||||
"""
|
||||
Incorporate post-processing required by the model if desired here.
|
||||
"""
|
||||
return request
|
||||
|
||||
def predict(self, request: Dict) -> Dict[str, List]:
|
||||
"""
|
||||
This is the main function that is called.
|
||||
"""
|
||||
# Example request:
|
||||
# {"alpha":0.25,"num_inference_steps":50,"seed_image_id":"og_beat","mask_image_id":None,"start":{"prompt":"lo-fi beat for the holidays","seed":906295,"denoising":0.75,"guidance":7},"end":{"prompt":"lo-fi beat for the holidays","seed":906296,"denoising":0.75,"guidance":7}}
|
||||
|
||||
# Parse an InferenceInput dataclass from the payload
|
||||
try:
|
||||
inputs = dacite.from_dict(InferenceInput, request)
|
||||
except dacite.exceptions.WrongTypeError as exception:
|
||||
# logging.info(json_data)
|
||||
return str(exception), 400
|
||||
except dacite.exceptions.MissingValueError as exception:
|
||||
# logging.info(json_data)
|
||||
return str(exception), 400
|
||||
|
||||
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
|
||||
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
|
||||
response = self.compute(inputs)
|
||||
|
||||
return response
|
||||
|
||||
def compute(self, inputs: InferenceInput) -> str:
|
||||
"""
|
||||
Does all the heavy lifting of the request.
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.seed_image_id}.png")
|
||||
|
||||
if not init_image_path.is_file():
|
||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Load the mask image by ID
|
||||
if inputs.mask_image_id:
|
||||
mask_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.mask_image_id}.png")
|
||||
if not mask_image_path.is_file():
|
||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||
else:
|
||||
mask_image = None
|
||||
|
||||
# Execute the model to get the spectrogram image
|
||||
image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image)
|
||||
|
||||
# Reconstruct audio from the image
|
||||
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
|
||||
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
|
||||
|
||||
# Compute the output as base64 encoded strings
|
||||
image_bytes = self.image_bytes_from_image(image, mode="JPEG")
|
||||
|
||||
# Assemble the output dataclass
|
||||
output = InferenceOutput(
|
||||
image="data:image/jpeg;base64," + self.base64_encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes),
|
||||
duration_s=duration_s,
|
||||
)
|
||||
|
||||
return json.dumps(dataclasses.asdict(output))
|
||||
|
||||
def image_bytes_from_image(self, image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
"""
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, mode)
|
||||
image_bytes.seek(0)
|
||||
return image_bytes
|
||||
|
||||
def base64_encode(self, buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
|
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
Command line tools for riffusion.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import argh
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
|
||||
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
|
||||
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
|
||||
def audio_to_image(
|
||||
*,
|
||||
audio: str,
|
||||
image: str,
|
||||
step_size_ms: int = 10,
|
||||
num_frequencies: int = 512,
|
||||
min_frequency: int = 0,
|
||||
max_frequency: int = 10000,
|
||||
window_duration_ms: int = 100,
|
||||
padded_duration_ms: int = 400,
|
||||
power_for_image: float = 0.25,
|
||||
stereo: bool = False,
|
||||
device: str = "cuda",
|
||||
):
|
||||
"""
|
||||
Compute a spectrogram image from a waveform.
|
||||
"""
|
||||
segment = pydub.AudioSegment.from_file(audio)
|
||||
|
||||
params = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=stereo,
|
||||
window_duration_ms=window_duration_ms,
|
||||
padded_duration_ms=padded_duration_ms,
|
||||
step_size_ms=step_size_ms,
|
||||
min_frequency=min_frequency,
|
||||
max_frequency=max_frequency,
|
||||
num_frequencies=num_frequencies,
|
||||
power_for_image=power_for_image,
|
||||
)
|
||||
|
||||
converter = SpectrogramImageConverter(params=params, device=device)
|
||||
|
||||
pil_image = converter.spectrogram_image_from_audio(segment)
|
||||
|
||||
pil_image.save(image, exif=pil_image.getexif(), format="PNG")
|
||||
print(f"Wrote {image}")
|
||||
|
||||
|
||||
def print_exif(*, image: str) -> None:
|
||||
"""
|
||||
Print the params of a spectrogram image as saved in the exif data.
|
||||
"""
|
||||
pil_image = Image.open(image)
|
||||
exif_data = image_util.exif_from_image(pil_image)
|
||||
|
||||
for name, value in exif_data.items():
|
||||
print(f"{name:<20} = {value:>15}")
|
||||
|
||||
|
||||
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
|
||||
"""
|
||||
Reconstruct an audio clip from a spectrogram image.
|
||||
"""
|
||||
pil_image = Image.open(image)
|
||||
|
||||
# Get parameters from image exif
|
||||
img_exif = pil_image.getexif()
|
||||
assert img_exif is not None
|
||||
|
||||
try:
|
||||
params = SpectrogramParams.from_exif(exif=img_exif)
|
||||
except KeyError:
|
||||
print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
|
||||
params = SpectrogramParams()
|
||||
|
||||
converter = SpectrogramImageConverter(params=params, device=device)
|
||||
segment = converter.audio_from_spectrogram_image(pil_image)
|
||||
|
||||
extension = Path(audio).suffix[1:]
|
||||
segment.export(audio, format=extension)
|
||||
|
||||
print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
|
||||
|
||||
|
||||
def sample_clips(
|
||||
*,
|
||||
audio: str,
|
||||
output_dir: str,
|
||||
num_clips: int = 1,
|
||||
duration_ms: int = 5000,
|
||||
mono: bool = False,
|
||||
extension: str = "wav",
|
||||
seed: int = -1,
|
||||
):
|
||||
"""
|
||||
Slice an audio file into clips of the given duration.
|
||||
"""
|
||||
if seed >= 0:
|
||||
np.random.seed(seed)
|
||||
|
||||
segment = pydub.AudioSegment.from_file(audio)
|
||||
|
||||
if mono:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
output_dir_path = Path(output_dir)
|
||||
if not output_dir_path.exists():
|
||||
output_dir_path.mkdir(parents=True)
|
||||
|
||||
# TODO(hayk): Might be a lot easier with pydub
|
||||
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
|
||||
|
||||
segment_duration_ms = int(segment.duration_seconds * 1000)
|
||||
for i in range(num_clips):
|
||||
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
|
||||
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
|
||||
|
||||
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
|
||||
clip_path = output_dir_path / clip_name
|
||||
clip.export(clip_path, format=extension)
|
||||
print(f"Wrote {clip_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argh.dispatch_commands(
|
||||
[
|
||||
audio_to_image,
|
||||
image_to_audio,
|
||||
sample_clips,
|
||||
print_exif,
|
||||
]
|
||||
)
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Data model for the riffusion API.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import typing as T
|
||||
|
@ -58,6 +59,7 @@ class InferenceOutput:
|
|||
"""
|
||||
Response from the model inference server.
|
||||
"""
|
||||
|
||||
# base64 encoded spectrogram image as a JPEG
|
||||
image: str
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
# external
|
||||
|
||||
This package contains scripts and tools from external sources.
|
|
@ -5,10 +5,13 @@ This code is taken from the diffusers community pipeline:
|
|||
|
||||
License: Apache 2.0
|
||||
"""
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
|
||||
import logging
|
||||
import re
|
||||
import typing as T
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
@ -123,7 +126,7 @@ def parse_prompt_attention(text):
|
|||
return res
|
||||
|
||||
|
||||
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
||||
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
|
||||
r"""
|
||||
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||
No padding, starting or ending token is included.
|
||||
|
@ -192,8 +195,8 @@ def get_unweighted_text_embeddings(
|
|||
pipe: StableDiffusionPipeline,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
):
|
||||
no_boseos_middle: T.Optional[bool] = True,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
it should be split into chunks and sent to the text encoder individually.
|
||||
|
@ -232,14 +235,14 @@ def get_unweighted_text_embeddings(
|
|||
|
||||
def get_weighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
prompt: Union[str, List[str]],
|
||||
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
no_boseos_middle: Optional[bool] = False,
|
||||
skip_parsing: Optional[bool] = False,
|
||||
skip_weighting: Optional[bool] = False,
|
||||
prompt: T.Union[str, T.List[str]],
|
||||
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
|
||||
max_embeddings_multiples: T.Optional[int] = 3,
|
||||
no_boseos_middle: T.Optional[bool] = False,
|
||||
skip_parsing: T.Optional[bool] = False,
|
||||
skip_weighting: T.Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
|
||||
r"""
|
||||
Prompts can be assigned with local weights using brackets. For example,
|
||||
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
||||
|
@ -248,9 +251,9 @@ def get_weighted_text_embeddings(
|
|||
Args:
|
||||
pipe (`StableDiffusionPipeline`):
|
||||
Pipe to provide access to the tokenizer and the text encoder.
|
||||
prompt (`str` or `List[str]`):
|
||||
prompt (`str` or `T.List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
uncond_prompt (`str` or `List[str]`):
|
||||
uncond_prompt (`str` or `T.List[str]`):
|
||||
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
||||
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
|
@ -269,8 +272,6 @@ def get_weighted_text_embeddings(
|
|||
|
||||
if not skip_parsing:
|
||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
||||
print(f"tokens: {prompt_tokens}")
|
||||
print(f"weights: {prompt_weights}")
|
||||
|
||||
if uncond_prompt is not None:
|
||||
if isinstance(uncond_prompt, str):
|
|
@ -1,13 +1,15 @@
|
|||
"""
|
||||
Riffusion inference pipeline.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
|
@ -15,9 +17,12 @@ from diffusers.pipeline_utils import DiffusionPipeline
|
|||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import logging
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from .datatypes import InferenceInput
|
||||
from riffusion.datatypes import InferenceInput
|
||||
from riffusion.external.prompt_weighting import get_weighted_text_embeddings
|
||||
from riffusion.util import torch_util
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
@ -56,8 +61,110 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_checkpoint(
|
||||
cls,
|
||||
checkpoint: str,
|
||||
use_traced_unet: bool = True,
|
||||
channels_last: bool = False,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: str = "cuda",
|
||||
) -> RiffusionPipeline:
|
||||
"""
|
||||
Load the riffusion model pipeline.
|
||||
|
||||
Args:
|
||||
checkpoint: Model checkpoint on disk in diffusers format
|
||||
use_traced_unet: Whether to use the traced unet for speedups
|
||||
device: Device to load the model on
|
||||
channels_last: Whether to use channels_last memory format
|
||||
"""
|
||||
device = torch_util.check_device(device)
|
||||
|
||||
if device == "cpu" or device.lower().startswith("mps"):
|
||||
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
|
||||
dtype = torch.float32
|
||||
|
||||
pipeline = RiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=dtype,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
# TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
# Optionally attempt to use less memory
|
||||
low_cpu_mem_usage=False,
|
||||
).to(device)
|
||||
|
||||
if channels_last:
|
||||
pipeline.unet.to(memory_format=torch.channels_last)
|
||||
|
||||
# Optionally load a traced unet
|
||||
if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
|
||||
traced_unet = cls.load_traced_unet(
|
||||
checkpoint=checkpoint,
|
||||
subfolder="unet_traced",
|
||||
filename="unet_traced.pt",
|
||||
in_channels=pipeline.unet.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if traced_unet is not None:
|
||||
pipeline.unet = traced_unet
|
||||
|
||||
model = pipeline.to(device)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def load_traced_unet(
|
||||
checkpoint: str,
|
||||
subfolder: str,
|
||||
filename: str,
|
||||
in_channels: int,
|
||||
dtype: torch.dtype,
|
||||
device: str = "cuda",
|
||||
) -> T.Optional[torch.nn.Module]:
|
||||
"""
|
||||
Load a traced unet from the huggingface hub. This can improve performance.
|
||||
"""
|
||||
if device == "cpu" or device.lower().startswith("mps"):
|
||||
print("WARNING: Traced UNet only available for CUDA, skipping")
|
||||
return None
|
||||
|
||||
# Download and load the traced unet
|
||||
unet_file = hf_hub_download(
|
||||
checkpoint,
|
||||
subfolder=subfolder,
|
||||
filename=filename,
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
# Wrap it in a torch module
|
||||
class TracedUNet(torch.nn.Module):
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.FloatTensor
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = device
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return self.UNet2DConditionOutput(sample=sample)
|
||||
|
||||
return TracedUNet()
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
return str(self.vae.device)
|
||||
|
||||
@functools.lru_cache()
|
||||
def embed_text(self, text):
|
||||
def embed_text(self, text) -> torch.FloatTensor:
|
||||
"""
|
||||
Takes in text and turns it into text embeddings.
|
||||
"""
|
||||
|
@ -73,12 +180,10 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
return embed
|
||||
|
||||
@functools.lru_cache()
|
||||
def embed_text_weighted(self, text):
|
||||
def embed_text_weighted(self, text) -> torch.FloatTensor:
|
||||
"""
|
||||
Get text embedding with weights.
|
||||
"""
|
||||
from .prompt_weighting import get_weighted_text_embeddings
|
||||
|
||||
return get_weighted_text_embeddings(
|
||||
pipe=self,
|
||||
prompt=text,
|
||||
|
@ -93,10 +198,10 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
def riffuse(
|
||||
self,
|
||||
inputs: InferenceInput,
|
||||
init_image: PIL.Image.Image,
|
||||
mask_image: PIL.Image.Image = None,
|
||||
init_image: Image.Image,
|
||||
mask_image: T.Optional[Image.Image] = None,
|
||||
use_reweighting: bool = True,
|
||||
) -> PIL.Image.Image:
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Runs inference using interpolation with both img2img and text conditioning.
|
||||
|
||||
|
@ -113,6 +218,12 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
end = inputs.end
|
||||
|
||||
guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
|
||||
|
||||
# TODO(hayk): Always generate the seed on CPU?
|
||||
if self.device.lower().startswith("mps"):
|
||||
generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
|
||||
generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
|
||||
else:
|
||||
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
|
||||
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
|
||||
|
||||
|
@ -123,25 +234,31 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
else:
|
||||
embed_start = self.embed_text(start.prompt)
|
||||
embed_end = self.embed_text(end.prompt)
|
||||
text_embedding = torch.lerp(embed_start, embed_end, alpha)
|
||||
|
||||
text_embedding = embed_start + alpha * (embed_end - embed_start)
|
||||
|
||||
# Image latents
|
||||
init_image = preprocess_image(init_image)
|
||||
init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
|
||||
init_image_torch = preprocess_image(init_image).to(
|
||||
device=self.device, dtype=embed_start.dtype
|
||||
)
|
||||
init_latent_dist = self.vae.encode(init_image_torch).latent_dist
|
||||
# TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
|
||||
# result is so close no matter the seed that it doesn't really add variety.
|
||||
if self.device.lower().startswith("mps"):
|
||||
generator = torch.Generator(device="cpu").manual_seed(start.seed)
|
||||
else:
|
||||
generator = torch.Generator(device=self.device).manual_seed(start.seed)
|
||||
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# Prepare mask latent
|
||||
mask: T.Optional[torch.Tensor] = None
|
||||
if mask_image:
|
||||
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
|
||||
mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
|
||||
else:
|
||||
mask = None
|
||||
mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
|
||||
device=self.device, dtype=embed_start.dtype
|
||||
)
|
||||
|
||||
outputs = self.interpolate_img2img(
|
||||
text_embeddings=text_embedding,
|
||||
|
@ -161,18 +278,18 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
@torch.no_grad()
|
||||
def interpolate_img2img(
|
||||
self,
|
||||
text_embeddings: torch.FloatTensor,
|
||||
init_latents: torch.FloatTensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
init_latents: torch.Tensor,
|
||||
generator_a: torch.Generator,
|
||||
generator_b: torch.Generator,
|
||||
interpolate_alpha: float,
|
||||
mask: T.Optional[torch.FloatTensor] = None,
|
||||
mask: T.Optional[torch.Tensor] = None,
|
||||
strength_a: float = 0.8,
|
||||
strength_b: float = 0.8,
|
||||
num_inference_steps: T.Optional[int] = 50,
|
||||
guidance_scale: T.Optional[float] = 7.5,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
|
||||
num_images_per_prompt: T.Optional[int] = 1,
|
||||
num_images_per_prompt: int = 1,
|
||||
eta: T.Optional[float] = 0.0,
|
||||
output_type: T.Optional[str] = "pil",
|
||||
**kwargs,
|
||||
|
@ -198,11 +315,6 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
|
@ -251,11 +363,11 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
noise_b = torch.randn(
|
||||
init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
|
||||
)
|
||||
noise = slerp(interpolate_alpha, noise_a, noise_b)
|
||||
noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
|
||||
init_latents_orig = init_latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same args
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
@ -295,7 +407,9 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if mask is not None:
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_orig, noise, torch.tensor([t])
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
|
@ -311,62 +425,42 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
return dict(images=image, latents=latents, nsfw_content_detected=False)
|
||||
|
||||
|
||||
def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
|
||||
def preprocess_image(image: Image.Image) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess an image for the model.
|
||||
"""
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
image = image.resize((w, h), resample=Image.LANCZOS)
|
||||
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
image_np = image_np[None].transpose(0, 3, 1, 2)
|
||||
|
||||
image_torch = torch.from_numpy(image_np)
|
||||
|
||||
return 2.0 * image_torch - 1.0
|
||||
|
||||
|
||||
def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
|
||||
def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess a mask for the model.
|
||||
"""
|
||||
# Convert to grayscale
|
||||
mask = mask.convert("L")
|
||||
|
||||
# Resize to integer multiple of 32
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
mask = mask.resize(
|
||||
(w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
|
||||
)
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
mask = 1 - mask # repaint white, keep black
|
||||
mask = torch.from_numpy(mask)
|
||||
w, h = map(lambda x: x - x % 32, (w, h))
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
|
||||
|
||||
return mask
|
||||
# Convert to numpy array and rescale
|
||||
mask_np = np.array(mask).astype(np.float32) / 255.0
|
||||
|
||||
# Tile and transpose
|
||||
mask_np = np.tile(mask_np, (4, 1, 1))
|
||||
mask_np = mask_np[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
|
||||
def slerp(t, v0, v1, dot_threshold=0.9995):
|
||||
"""
|
||||
Helper function to spherically interpolate two arrays v1 v2.
|
||||
"""
|
||||
# Invert to repaint white and keep black
|
||||
mask_np = 1 - mask_np # repaint white, keep black
|
||||
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > dot_threshold:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
|
||||
return v2
|
||||
return torch.from_numpy(mask_np)
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
"""
|
||||
Inference server for the riffusion project.
|
||||
Flask server that serves the riffusion model as an API.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import logging
|
||||
import io
|
||||
|
@ -16,15 +15,13 @@ import flask
|
|||
|
||||
from flask_cors import CORS
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .audio import wav_bytes_from_spectrogram_image
|
||||
from .audio import mp3_bytes_from_wav_bytes
|
||||
from .datatypes import InferenceInput
|
||||
from .datatypes import InferenceOutput
|
||||
from .riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.datatypes import InferenceInput
|
||||
from riffusion.datatypes import InferenceOutput
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import base64_util
|
||||
|
||||
# Flask app with CORS
|
||||
app = flask.Flask(__name__)
|
||||
|
@ -35,7 +32,7 @@ logging.basicConfig(level=logging.INFO)
|
|||
logging.getLogger().addHandler(logging.FileHandler("server.log"))
|
||||
|
||||
# Global variable for the model pipeline
|
||||
MODEL = None
|
||||
PIPELINE: T.Optional[RiffusionPipeline] = None
|
||||
|
||||
# Where built-in seed images are stored
|
||||
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
|
||||
|
@ -45,8 +42,9 @@ def run_app(
|
|||
*,
|
||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||
no_traced_unet: bool = False,
|
||||
device: str = "cuda",
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3000,
|
||||
port: int = 3013,
|
||||
debug: bool = False,
|
||||
ssl_certificate: T.Optional[str] = None,
|
||||
ssl_key: T.Optional[str] = None,
|
||||
|
@ -55,8 +53,12 @@ def run_app(
|
|||
Run a flask API that serves the given riffusion model checkpoint.
|
||||
"""
|
||||
# Initialize the model
|
||||
global MODEL
|
||||
MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)
|
||||
global PIPELINE
|
||||
PIPELINE = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=checkpoint,
|
||||
use_traced_unet=not no_traced_unet,
|
||||
device=device,
|
||||
)
|
||||
|
||||
args = dict(
|
||||
debug=debug,
|
||||
|
@ -69,51 +71,7 @@ def run_app(
|
|||
assert ssl_key is not None
|
||||
args["ssl_context"] = (ssl_certificate, ssl_key)
|
||||
|
||||
app.run(**args)
|
||||
|
||||
|
||||
def load_model(checkpoint: str, traced_unet: bool = True):
|
||||
"""
|
||||
Load the riffusion model pipeline.
|
||||
"""
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
model = RiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to("cuda")
|
||||
|
||||
# Set the traced unet if desired
|
||||
if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.FloatTensor
|
||||
|
||||
# Using traced unet from hf hub
|
||||
unet_file = hf_hub_download(
|
||||
checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
class TracedUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = model.unet.in_channels
|
||||
self.device = model.unet.device
|
||||
self.dtype = torch.float16
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
model.unet = TracedUNet()
|
||||
|
||||
model = model.to("cuda")
|
||||
|
||||
return model
|
||||
app.run(**args) # type: ignore
|
||||
|
||||
|
||||
@app.route("/run_inference/", methods=["POST"])
|
||||
|
@ -145,7 +103,11 @@ def run_inference():
|
|||
logging.info(json_data)
|
||||
return str(exception), 400
|
||||
|
||||
response = compute(inputs)
|
||||
response = compute_request(
|
||||
inputs=inputs,
|
||||
seed_images_dir=SEED_IMAGES_DIR,
|
||||
pipeline=PIPELINE,
|
||||
)
|
||||
|
||||
# Log the total time
|
||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||
|
@ -153,60 +115,73 @@ def run_inference():
|
|||
return response
|
||||
|
||||
|
||||
def compute(inputs: InferenceInput) -> str:
|
||||
def compute_request(
|
||||
inputs: InferenceInput,
|
||||
pipeline: RiffusionPipeline,
|
||||
seed_images_dir: str,
|
||||
) -> T.Union[str, T.Tuple[str, int]]:
|
||||
"""
|
||||
Does all the heavy lifting of the request.
|
||||
|
||||
Args:
|
||||
inputs: The input dataclass
|
||||
pipeline: The riffusion model pipeline
|
||||
seed_images_dir: The directory where seed images are stored
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
|
||||
init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
|
||||
|
||||
if not init_image_path.is_file():
|
||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Load the mask image by ID
|
||||
mask_image: T.Optional[PIL.Image.Image] = None
|
||||
if inputs.mask_image_id:
|
||||
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
|
||||
mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
|
||||
if not mask_image_path.is_file():
|
||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||
else:
|
||||
mask_image = None
|
||||
|
||||
# Execute the model to get the spectrogram image
|
||||
image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
|
||||
image = pipeline.riffuse(
|
||||
inputs,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
params = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
# Reconstruct audio from the image
|
||||
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
|
||||
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
|
||||
# TODO(hayk): It may help performance to cache this object
|
||||
converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
|
||||
segment = converter.audio_from_spectrogram_image(
|
||||
image,
|
||||
apply_filters=True,
|
||||
)
|
||||
|
||||
# Compute the output as base64 encoded strings
|
||||
image_bytes = image_bytes_from_image(image, mode="JPEG")
|
||||
# Export audio to MP3 bytes
|
||||
mp3_bytes = io.BytesIO()
|
||||
segment.export(mp3_bytes, format="mp3")
|
||||
mp3_bytes.seek(0)
|
||||
|
||||
# Export image to JPEG bytes
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, exif=image.getexif(), format="JPEG")
|
||||
image_bytes.seek(0)
|
||||
|
||||
# Assemble the output dataclass
|
||||
output = InferenceOutput(
|
||||
image="data:image/jpeg;base64," + base64_encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
|
||||
duration_s=duration_s,
|
||||
image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
|
||||
duration_s=segment.duration_seconds,
|
||||
)
|
||||
|
||||
return flask.jsonify(dataclasses.asdict(output))
|
||||
|
||||
|
||||
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
"""
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, mode)
|
||||
image_bytes.seek(0)
|
||||
return image_bytes
|
||||
|
||||
|
||||
def base64_encode(buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
|
||||
return json.dumps(dataclasses.asdict(output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
import numpy as np
|
||||
import pydub
|
||||
import torch
|
||||
import torchaudio
|
||||
import warnings
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import audio_util
|
||||
from riffusion.util import torch_util
|
||||
|
||||
|
||||
class SpectrogramConverter:
|
||||
"""
|
||||
Convert between audio segments and spectrogram tensors using torchaudio.
|
||||
|
||||
In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
|
||||
that represent the amplitude of the frequency at that time bucket (in the frequency domain).
|
||||
Frequencies are given in the perceptul Mel scale defined by the params. A more specific term
|
||||
used in some functions is "mel amplitudes".
|
||||
|
||||
The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
|
||||
returns the amplitude, because the phase is chaotic and hard to learn. The function
|
||||
`audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
|
||||
approximates the phase information using the Griffin-Lim algorithm.
|
||||
|
||||
Each channel in the audio is treated independently, and the spectrogram has a batch dimension
|
||||
equal to the number of channels in the input audio segment.
|
||||
|
||||
Both the Griffin Lim algorithm and the Mel scaling process are lossy.
|
||||
|
||||
For more information, see https://pytorch.org/audio/stable/transforms.html
|
||||
"""
|
||||
|
||||
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
|
||||
self.p = params
|
||||
|
||||
self.device = torch_util.check_device(device)
|
||||
|
||||
if device.lower().startswith("mps"):
|
||||
warnings.warn(
|
||||
"WARNING: MPS does not support audio operations, falling back to CPU for them",
|
||||
stacklevel=2,
|
||||
)
|
||||
self.device = "cpu"
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
|
||||
self.spectrogram_func = torchaudio.transforms.Spectrogram(
|
||||
n_fft=params.n_fft,
|
||||
hop_length=params.hop_length,
|
||||
win_length=params.win_length,
|
||||
pad=0,
|
||||
window_fn=torch.hann_window,
|
||||
power=None,
|
||||
normalized=False,
|
||||
wkwargs=None,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
).to(self.device)
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
|
||||
self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
|
||||
n_fft=params.n_fft,
|
||||
n_iter=params.num_griffin_lim_iters,
|
||||
win_length=params.win_length,
|
||||
hop_length=params.hop_length,
|
||||
window_fn=torch.hann_window,
|
||||
power=1.0,
|
||||
wkwargs=None,
|
||||
momentum=0.99,
|
||||
length=None,
|
||||
rand_init=True,
|
||||
).to(self.device)
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
|
||||
self.mel_scaler = torchaudio.transforms.MelScale(
|
||||
n_mels=params.num_frequencies,
|
||||
sample_rate=params.sample_rate,
|
||||
f_min=params.min_frequency,
|
||||
f_max=params.max_frequency,
|
||||
n_stft=params.n_fft // 2 + 1,
|
||||
norm=params.mel_scale_norm,
|
||||
mel_scale=params.mel_scale_type,
|
||||
).to(self.device)
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
|
||||
self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
|
||||
n_stft=params.n_fft // 2 + 1,
|
||||
n_mels=params.num_frequencies,
|
||||
sample_rate=params.sample_rate,
|
||||
f_min=params.min_frequency,
|
||||
f_max=params.max_frequency,
|
||||
max_iter=params.max_mel_iters,
|
||||
tolerance_loss=1e-5,
|
||||
tolerance_change=1e-8,
|
||||
sgdargs=None,
|
||||
norm=params.mel_scale_norm,
|
||||
mel_scale=params.mel_scale_type,
|
||||
).to(self.device)
|
||||
|
||||
def spectrogram_from_audio(
|
||||
self,
|
||||
audio: pydub.AudioSegment,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute a spectrogram from an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Audio segment which must match the sample rate of the params
|
||||
|
||||
Returns:
|
||||
spectrogram: (channel, frequency, time)
|
||||
"""
|
||||
assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
|
||||
|
||||
# Get the samples as a numpy array in (batch, samples) shape
|
||||
waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
|
||||
|
||||
# Convert to floats if necessary
|
||||
if waveform.dtype != np.float32:
|
||||
waveform = waveform.astype(np.float32)
|
||||
|
||||
waveform_tensor = torch.from_numpy(waveform).to(self.device)
|
||||
amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
|
||||
return amplitudes_mel.cpu().numpy()
|
||||
|
||||
def audio_from_spectrogram(
|
||||
self,
|
||||
spectrogram: np.ndarray,
|
||||
apply_filters: bool = True,
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Reconstruct an audio segment from a spectrogram.
|
||||
|
||||
Args:
|
||||
spectrogram: (batch, frequency, time)
|
||||
apply_filters: Post-process with normalization and compression
|
||||
|
||||
Returns:
|
||||
audio: Audio segment with channels equal to the batch dimension
|
||||
"""
|
||||
# Move to device
|
||||
amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
|
||||
|
||||
# Reconstruct the waveform
|
||||
waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
|
||||
|
||||
# Convert to audio segment
|
||||
segment = audio_util.audio_from_waveform(
|
||||
samples=waveform.cpu().numpy(),
|
||||
sample_rate=self.p.sample_rate,
|
||||
# Normalize the waveform to the range [-1, 1]
|
||||
normalize=True,
|
||||
)
|
||||
|
||||
# Optionally apply post-processing filters
|
||||
if apply_filters:
|
||||
segment = audio_util.apply_filters(segment)
|
||||
|
||||
return segment
|
||||
|
||||
def mel_amplitudes_from_waveform(
|
||||
self,
|
||||
waveform: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Torch-only function to compute Mel-scale amplitudes from a waveform.
|
||||
|
||||
Args:
|
||||
waveform: (batch, samples)
|
||||
|
||||
Returns:
|
||||
amplitudes_mel: (batch, frequency, time)
|
||||
"""
|
||||
# Compute the complex-valued spectrogram
|
||||
spectrogram_complex = self.spectrogram_func(waveform)
|
||||
|
||||
# Take the magnitude
|
||||
amplitudes = torch.abs(spectrogram_complex)
|
||||
|
||||
# Convert to mel scale
|
||||
return self.mel_scaler(amplitudes)
|
||||
|
||||
def waveform_from_mel_amplitudes(
|
||||
self,
|
||||
amplitudes_mel: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
|
||||
|
||||
Args:
|
||||
amplitudes_mel: (batch, frequency, time)
|
||||
|
||||
Returns:
|
||||
waveform: (batch, samples)
|
||||
"""
|
||||
# Convert from mel scale to linear
|
||||
amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
|
||||
|
||||
# Run the approximate algorithm to compute the phase and recover the waveform
|
||||
return self.inverse_spectrogram_func(amplitudes_linear)
|
|
@ -0,0 +1,91 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
|
||||
class SpectrogramImageConverter:
|
||||
"""
|
||||
Convert between spectrogram images and audio segments.
|
||||
|
||||
This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
|
||||
to images and back. The real audio processing lives in SpectrogramConverter.
|
||||
"""
|
||||
|
||||
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
|
||||
self.p = params
|
||||
self.device = device
|
||||
self.converter = SpectrogramConverter(params=params, device=device)
|
||||
|
||||
def spectrogram_image_from_audio(
|
||||
self,
|
||||
segment: pydub.AudioSegment,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Compute a spectrogram image from an audio segment.
|
||||
|
||||
Args:
|
||||
segment: Audio segment to convert
|
||||
|
||||
Returns:
|
||||
Spectrogram image (in pillow format)
|
||||
"""
|
||||
assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"
|
||||
|
||||
if self.p.stereo:
|
||||
if segment.channels == 1:
|
||||
print("WARNING: Mono audio but stereo=True, cloning channel")
|
||||
segment = segment.set_channels(2)
|
||||
elif segment.channels > 2:
|
||||
print("WARNING: Multi channel audio, reducing to stereo")
|
||||
segment = segment.set_channels(2)
|
||||
else:
|
||||
if segment.channels > 1:
|
||||
print("WARNING: Stereo audio but stereo=False, setting to mono")
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
spectrogram = self.converter.spectrogram_from_audio(segment)
|
||||
|
||||
image = image_util.image_from_spectrogram(
|
||||
spectrogram,
|
||||
power=self.p.power_for_image,
|
||||
)
|
||||
|
||||
# Store conversion params in exif metadata of the image
|
||||
exif_data = self.p.to_exif()
|
||||
exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
|
||||
exif = image.getexif()
|
||||
exif.update(exif_data.items())
|
||||
|
||||
return image
|
||||
|
||||
def audio_from_spectrogram_image(
|
||||
self,
|
||||
image: Image.Image,
|
||||
apply_filters: bool = True,
|
||||
max_value: float = 30e6,
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Reconstruct an audio segment from a spectrogram image.
|
||||
|
||||
Args:
|
||||
image: Spectrogram image (in pillow format)
|
||||
apply_filters: Apply post-processing to improve the reconstructed audio
|
||||
max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
|
||||
"""
|
||||
spectrogram = image_util.spectrogram_from_image(
|
||||
image,
|
||||
max_value=max_value,
|
||||
power=self.p.power_for_image,
|
||||
stereo=self.p.stereo,
|
||||
)
|
||||
|
||||
segment = self.converter.audio_from_spectrogram(
|
||||
spectrogram,
|
||||
apply_filters=apply_filters,
|
||||
)
|
||||
|
||||
return segment
|
|
@ -0,0 +1,112 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import typing as T
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpectrogramParams:
|
||||
"""
|
||||
Parameters for the conversion from audio to spectrograms to images and back.
|
||||
|
||||
Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
|
||||
within spectrogram images.
|
||||
"""
|
||||
|
||||
# Whether the audio is stereo or mono
|
||||
stereo: bool = False
|
||||
|
||||
# FFT parameters
|
||||
sample_rate: int = 44100
|
||||
step_size_ms: int = 10
|
||||
window_duration_ms: int = 100
|
||||
padded_duration_ms: int = 400
|
||||
|
||||
# Mel scale parameters
|
||||
num_frequencies: int = 512
|
||||
# TODO(hayk): Set these to [20, 20000] for newer models
|
||||
min_frequency: int = 0
|
||||
max_frequency: int = 10000
|
||||
mel_scale_norm: T.Optional[str] = None
|
||||
mel_scale_type: str = "htk"
|
||||
max_mel_iters: int = 200
|
||||
|
||||
# Griffin Lim parameters
|
||||
num_griffin_lim_iters: int = 32
|
||||
|
||||
# Image parameterization
|
||||
power_for_image: float = 0.25
|
||||
|
||||
class ExifTags(Enum):
|
||||
"""
|
||||
Custom EXIF tags for the spectrogram image.
|
||||
"""
|
||||
|
||||
SAMPLE_RATE = 11000
|
||||
STEREO = 11005
|
||||
STEP_SIZE_MS = 11010
|
||||
WINDOW_DURATION_MS = 11020
|
||||
PADDED_DURATION_MS = 11030
|
||||
|
||||
NUM_FREQUENCIES = 11040
|
||||
MIN_FREQUENCY = 11050
|
||||
MAX_FREQUENCY = 11060
|
||||
|
||||
POWER_FOR_IMAGE = 11070
|
||||
MAX_VALUE = 11080
|
||||
|
||||
@property
|
||||
def n_fft(self) -> int:
|
||||
"""
|
||||
The number of samples in each STFT window, with padding.
|
||||
"""
|
||||
return int(self.padded_duration_ms / 1000.0 * self.sample_rate)
|
||||
|
||||
@property
|
||||
def win_length(self) -> int:
|
||||
"""
|
||||
The number of samples in each STFT window.
|
||||
"""
|
||||
return int(self.window_duration_ms / 1000.0 * self.sample_rate)
|
||||
|
||||
@property
|
||||
def hop_length(self) -> int:
|
||||
"""
|
||||
The number of samples between each STFT window.
|
||||
"""
|
||||
return int(self.step_size_ms / 1000.0 * self.sample_rate)
|
||||
|
||||
def to_exif(self) -> T.Dict[int, T.Any]:
|
||||
"""
|
||||
Return a dictionary of EXIF tags for the current values.
|
||||
"""
|
||||
return {
|
||||
self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
|
||||
self.ExifTags.STEREO.value: self.stereo,
|
||||
self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
|
||||
self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
|
||||
self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
|
||||
self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
|
||||
self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
|
||||
self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
|
||||
self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_exif(cls, exif: T.Mapping[int, T.Any]) -> SpectrogramParams:
|
||||
"""
|
||||
Create a SpectrogramParams object from the EXIF tags of the given image.
|
||||
"""
|
||||
# TODO(hayk): Handle missing tags
|
||||
return cls(
|
||||
sample_rate=exif[cls.ExifTags.SAMPLE_RATE.value],
|
||||
stereo=bool(exif[cls.ExifTags.STEREO.value]),
|
||||
step_size_ms=exif[cls.ExifTags.STEP_SIZE_MS.value],
|
||||
window_duration_ms=exif[cls.ExifTags.WINDOW_DURATION_MS.value],
|
||||
padded_duration_ms=exif[cls.ExifTags.PADDED_DURATION_MS.value],
|
||||
num_frequencies=exif[cls.ExifTags.NUM_FREQUENCIES.value],
|
||||
min_frequency=exif[cls.ExifTags.MIN_FREQUENCY.value],
|
||||
max_frequency=exif[cls.ExifTags.MAX_FREQUENCY.value],
|
||||
power_for_image=exif[cls.ExifTags.POWER_FOR_IMAGE.value],
|
||||
)
|
|
@ -0,0 +1,66 @@
|
|||
"""
|
||||
Audio utility functions.
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
def audio_from_waveform(
|
||||
samples: np.ndarray, sample_rate: int, normalize: bool = False
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Convert a numpy array of samples of a waveform to an audio segment.
|
||||
"""
|
||||
# Normalize volume to fit in int16
|
||||
if normalize:
|
||||
samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))
|
||||
|
||||
# Transpose and convert to int16
|
||||
samples = samples.transpose(1, 0)
|
||||
samples = samples.astype(np.int16)
|
||||
|
||||
# Write to the bytes of a WAV file
|
||||
wav_bytes = io.BytesIO()
|
||||
wavfile.write(wav_bytes, sample_rate, samples)
|
||||
wav_bytes.seek(0)
|
||||
|
||||
# Read into pydub
|
||||
return pydub.AudioSegment.from_wav(wav_bytes)
|
||||
|
||||
|
||||
def apply_filters(segment: pydub.AudioSegment) -> pydub.AudioSegment:
|
||||
"""
|
||||
Apply post-processing filters to the audio segment to compress it and
|
||||
keep at a -10 dBFS level.
|
||||
"""
|
||||
# TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
|
||||
# TODO(hayk): Is this going to make audio unbalanced between sequential clips?
|
||||
|
||||
segment = pydub.effects.normalize(
|
||||
segment,
|
||||
headroom=0.1,
|
||||
)
|
||||
|
||||
segment = segment.apply_gain(-10 - segment.dBFS)
|
||||
|
||||
segment = pydub.effects.compress_dynamic_range(
|
||||
segment,
|
||||
threshold=-20.0,
|
||||
ratio=4.0,
|
||||
attack=5.0,
|
||||
release=50.0,
|
||||
)
|
||||
|
||||
desired_db = -12
|
||||
segment = segment.apply_gain(desired_db - segment.dBFS)
|
||||
|
||||
segment = pydub.effects.normalize(
|
||||
segment,
|
||||
headroom=0.1,
|
||||
)
|
||||
|
||||
return segment
|
|
@ -0,0 +1,9 @@
|
|||
import base64
|
||||
import io
|
||||
|
||||
|
||||
def encode(buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
FFT tools to analyze frequency content of audio segments. This is not code for
|
||||
dealing with spectrogram images, but for analysis of waveforms.
|
||||
"""
|
||||
import struct
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import pydub
|
||||
from scipy.fft import rfft, rfftfreq
|
||||
|
||||
|
||||
def plot_ffts(
|
||||
segments: T.Dict[str, pydub.AudioSegment],
|
||||
title: str = "FFT",
|
||||
min_frequency: float = 20,
|
||||
max_frequency: float = 20000,
|
||||
) -> None:
|
||||
"""
|
||||
Plot an FFT analysis of the given audio segments.
|
||||
"""
|
||||
ffts = {name: compute_fft(seg) for name, seg in segments.items()}
|
||||
|
||||
fig = go.Figure(
|
||||
data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
|
||||
layout={"title": title},
|
||||
)
|
||||
fig.update_xaxes(
|
||||
range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
|
||||
type="log",
|
||||
title="Frequency",
|
||||
)
|
||||
fig.update_yaxes(title="Value")
|
||||
fig.show()
|
||||
|
||||
|
||||
def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute the FFT of the given audio segment as a mono signal.
|
||||
|
||||
Returns:
|
||||
frequencies: FFT computed frequencies
|
||||
amplitudes: Amplitudes of each frequency
|
||||
"""
|
||||
# Convert to mono if needed.
|
||||
if sound.channels > 1:
|
||||
sound = sound.set_channels(1)
|
||||
|
||||
sample_rate = sound.frame_rate
|
||||
|
||||
num_samples = int(sound.frame_count())
|
||||
samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
|
||||
|
||||
fft_values = rfft(samples)
|
||||
amplitudes = np.abs(fft_values)
|
||||
|
||||
frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
|
||||
|
||||
return frequencies, amplitudes
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
Module for converting between spectrograms tensors and spectrogram images, as well as
|
||||
general helpers for operating on pillow images.
|
||||
"""
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
|
||||
def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
|
||||
"""
|
||||
Compute a spectrogram image from a spectrogram magnitude array.
|
||||
|
||||
This is the inverse of spectrogram_from_image, except for discretization error from
|
||||
quantizing to uint8.
|
||||
|
||||
Args:
|
||||
spectrogram: (channels, frequency, time)
|
||||
power: A power curve to apply to the spectrogram to preserve contrast
|
||||
|
||||
Returns:
|
||||
image: (frequency, time, channels)
|
||||
"""
|
||||
# Rescale to 0-1
|
||||
max_value = np.max(spectrogram)
|
||||
data = spectrogram / max_value
|
||||
|
||||
# Apply the power curve
|
||||
data = np.power(data, power)
|
||||
|
||||
# Rescale to 0-255
|
||||
data = data * 255
|
||||
|
||||
# Invert
|
||||
data = 255 - data
|
||||
|
||||
# Convert to uint8
|
||||
data = data.astype(np.uint8)
|
||||
|
||||
# Munge channels into a PIL image
|
||||
if data.shape[0] == 1:
|
||||
# TODO(hayk): Do we want to write single channel to disk instead?
|
||||
image = Image.fromarray(data[0], mode="L").convert("RGB")
|
||||
elif data.shape[0] == 2:
|
||||
data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
|
||||
image = Image.fromarray(data, mode="RGB")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")
|
||||
|
||||
# Flip Y
|
||||
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def spectrogram_from_image(
|
||||
image: Image.Image,
|
||||
power: float = 0.25,
|
||||
stereo: bool = False,
|
||||
max_value: float = 30e6,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute a spectrogram magnitude array from a spectrogram image.
|
||||
|
||||
This is the inverse of image_from_spectrogram, except for discretization error from
|
||||
quantizing to uint8.
|
||||
|
||||
Args:
|
||||
image: (frequency, time, channels)
|
||||
power: The power curve applied to the spectrogram
|
||||
stereo: Whether the spectrogram encodes stereo data
|
||||
max_value: The max value of the original spectrogram. In practice doesn't matter.
|
||||
|
||||
Returns:
|
||||
spectrogram: (channels, frequency, time)
|
||||
"""
|
||||
# Flip Y
|
||||
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Munge channels into a numpy array of (channels, frequency, time)
|
||||
data = np.array(image).transpose(2, 0, 1)
|
||||
if stereo:
|
||||
# Take the G and B channels as done in image_from_spectrogram
|
||||
data = data[[1, 2], :, :]
|
||||
else:
|
||||
data = data[0:1, :, :]
|
||||
|
||||
# Convert to floats
|
||||
data = data.astype(np.float32)
|
||||
|
||||
# Invert
|
||||
data = 255 - data
|
||||
|
||||
# Rescale to 0-1
|
||||
data = data / 255
|
||||
|
||||
# Reverse the power curve
|
||||
data = np.power(data, 1 / power)
|
||||
|
||||
# Rescale to max value
|
||||
data = data * max_value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]:
|
||||
"""
|
||||
Get the EXIF data from a PIL image as a dict.
|
||||
"""
|
||||
exif = pil_image.getexif()
|
||||
|
||||
if exif is None or len(exif) == 0:
|
||||
return {}
|
||||
|
||||
return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()}
|
|
@ -0,0 +1,48 @@
|
|||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def check_device(device: str, backup: str = "cpu") -> str:
|
||||
"""
|
||||
Check that the device is valid and available. If not,
|
||||
"""
|
||||
cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
|
||||
mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()
|
||||
|
||||
if cuda_not_found or mps_not_found:
|
||||
warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
|
||||
return backup
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def slerp(
|
||||
t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper function to spherically interpolate two arrays v1 v2.
|
||||
"""
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > dot_threshold:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
|
||||
return v2
|
|
@ -0,0 +1,99 @@
|
|||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.cli import audio_to_image
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class AudioToImageTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli audio-to-image
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def default_params(cls) -> T.Dict:
|
||||
return dict(
|
||||
step_size_ms=10,
|
||||
num_frequencies=512,
|
||||
# TODO(hayk): Change these to [20, 20000] once a model is updated
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
stereo=False,
|
||||
device=cls.DEVICE,
|
||||
)
|
||||
|
||||
def test_audio_to_image(self) -> None:
|
||||
"""
|
||||
Test audio-to-image with default params.
|
||||
"""
|
||||
params = self.default_params()
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_stereo(self) -> None:
|
||||
"""
|
||||
Test audio-to-image with stereo=True.
|
||||
"""
|
||||
params = self.default_params()
|
||||
params["stereo"] = True
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def helper_test_with_params(self, params: T.Dict) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir("audio_to_image_")
|
||||
|
||||
if params["stereo"]:
|
||||
stem = f"{audio_path.stem}_stereo"
|
||||
else:
|
||||
stem = audio_path.stem
|
||||
|
||||
image_path = output_dir / f"{stem}.png"
|
||||
|
||||
audio_to_image(audio=str(audio_path), image=str(image_path), **params)
|
||||
|
||||
# Check that the image exists
|
||||
self.assertTrue(image_path.exists())
|
||||
|
||||
pil_image = Image.open(image_path)
|
||||
|
||||
# Check the image mode
|
||||
self.assertEqual(pil_image.mode, "RGB")
|
||||
|
||||
# Check the image dimensions
|
||||
duration_ms = 5678
|
||||
self.assertTrue(str(duration_ms) in audio_path.name)
|
||||
expected_image_width = round(duration_ms / params["step_size_ms"])
|
||||
self.assertEqual(pil_image.width, expected_image_width)
|
||||
self.assertEqual(pil_image.height, params["num_frequencies"])
|
||||
|
||||
# Get channels as numpy arrays
|
||||
channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))]
|
||||
self.assertEqual(len(channels), 3)
|
||||
|
||||
if params["stereo"]:
|
||||
# Check that the first channel is zero
|
||||
self.assertTrue(np.all(channels[0] == 0))
|
||||
else:
|
||||
# Check that all channels are the same
|
||||
self.assertTrue(np.all(channels[0] == channels[1]))
|
||||
self.assertTrue(np.all(channels[0] == channels[2]))
|
||||
|
||||
# Check that the image has exif data
|
||||
exif = pil_image.getexif()
|
||||
self.assertIsNotNone(exif)
|
||||
params_from_exif = SpectrogramParams.from_exif(exif)
|
||||
expected_params = SpectrogramParams(
|
||||
stereo=params["stereo"],
|
||||
step_size_ms=params["step_size_ms"],
|
||||
num_frequencies=params["num_frequencies"],
|
||||
max_frequency=params["max_frequency"],
|
||||
)
|
||||
self.assertTrue(params_from_exif == expected_params)
|
|
@ -0,0 +1,71 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.cli import image_to_audio
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class ImageToAudioTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli image-to-audio
|
||||
"""
|
||||
|
||||
def test_image_to_audio_mono(self) -> None:
|
||||
self.helper_image_to_audio(
|
||||
song_dir=self.TEST_DATA_PATH / "tired_traveler",
|
||||
clip_name="clip_2_start_103694_ms_duration_5678_ms",
|
||||
stereo=False,
|
||||
)
|
||||
|
||||
def test_image_to_audio_stereo(self) -> None:
|
||||
self.helper_image_to_audio(
|
||||
song_dir=self.TEST_DATA_PATH / "tired_traveler",
|
||||
clip_name="clip_2_start_103694_ms_duration_5678_ms",
|
||||
stereo=True,
|
||||
)
|
||||
|
||||
def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
|
||||
if stereo:
|
||||
image_stem = clip_name + "_stereo"
|
||||
else:
|
||||
image_stem = clip_name
|
||||
|
||||
image_path = song_dir / "images" / f"{image_stem}.png"
|
||||
output_dir = self.get_tmp_dir("image_to_audio_")
|
||||
audio_path = output_dir / f"{image_path.stem}.wav"
|
||||
|
||||
image_to_audio(
|
||||
image=str(image_path),
|
||||
audio=str(audio_path),
|
||||
device=self.DEVICE,
|
||||
)
|
||||
|
||||
# Check that the audio exists
|
||||
self.assertTrue(audio_path.exists())
|
||||
|
||||
# Load the reconstructed audio and the original clip
|
||||
segment = pydub.AudioSegment.from_file(str(audio_path))
|
||||
expected_segment = pydub.AudioSegment.from_file(
|
||||
str(song_dir / "clips" / f"{clip_name}.wav")
|
||||
)
|
||||
|
||||
# Check sample rate
|
||||
self.assertEqual(segment.frame_rate, expected_segment.frame_rate)
|
||||
|
||||
# Check duration
|
||||
actual_duration_ms = round(segment.duration_seconds * 1000)
|
||||
expected_duration_ms = round(expected_segment.duration_seconds * 1000)
|
||||
self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
|
||||
|
||||
# Check the number of channels
|
||||
self.assertEqual(expected_segment.channels, 2)
|
||||
if stereo:
|
||||
self.assertEqual(segment.channels, 2)
|
||||
else:
|
||||
self.assertEqual(segment.channels, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestCase.main()
|
|
@ -0,0 +1,65 @@
|
|||
import numpy as np
|
||||
import pydub
|
||||
|
||||
from riffusion.util import image_util
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class ImageUtilTest(TestCase):
|
||||
"""
|
||||
Test riffusion.util.image_util
|
||||
"""
|
||||
|
||||
def test_spectrogram_to_image_round_trip(self) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
|
||||
# Load up the audio file
|
||||
segment = pydub.AudioSegment.from_file(audio_path)
|
||||
|
||||
# Convert to mono
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
# Compute a spectrogram with default params
|
||||
params = SpectrogramParams(sample_rate=segment.frame_rate)
|
||||
converter = SpectrogramConverter(params=params, device=self.DEVICE)
|
||||
spectrogram = converter.spectrogram_from_audio(segment)
|
||||
|
||||
# Compute the image from the spectrogram
|
||||
image = image_util.image_from_spectrogram(
|
||||
spectrogram=spectrogram,
|
||||
power=params.power_for_image,
|
||||
)
|
||||
|
||||
# Save the max value
|
||||
max_value = np.max(spectrogram)
|
||||
|
||||
# Compute the spectrogram from the image
|
||||
spectrogram_reversed = image_util.spectrogram_from_image(
|
||||
image=image,
|
||||
max_value=max_value,
|
||||
power=params.power_for_image,
|
||||
stereo=params.stereo,
|
||||
)
|
||||
|
||||
# Check the shapes
|
||||
self.assertEqual(spectrogram.shape, spectrogram_reversed.shape)
|
||||
|
||||
# Check the max values
|
||||
self.assertEqual(np.max(spectrogram), np.max(spectrogram_reversed))
|
||||
|
||||
# Check the median values
|
||||
self.assertTrue(
|
||||
np.allclose(np.median(spectrogram), np.median(spectrogram_reversed), rtol=0.05)
|
||||
)
|
||||
|
||||
# Make sure all values are somewhat similar, but allow for discretization error
|
||||
# TODO(hayk): Investigate error more closely
|
||||
self.assertTrue(np.allclose(spectrogram, spectrogram_reversed, rtol=0.15))
|
|
@ -0,0 +1,24 @@
|
|||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class LinterTest(TestCase):
|
||||
"""
|
||||
Test that ruff, black, and mypy run cleanly.
|
||||
"""
|
||||
|
||||
HOME = Path(__file__).parent.parent
|
||||
|
||||
def test_ruff(self) -> None:
|
||||
code = subprocess.check_call(["ruff", str(self.HOME)])
|
||||
self.assertEqual(code, 0)
|
||||
|
||||
def test_black(self) -> None:
|
||||
code = subprocess.check_call(["black", "--check", str(self.HOME)])
|
||||
self.assertEqual(code, 0)
|
||||
|
||||
def test_mypy(self) -> None:
|
||||
code = subprocess.check_call(["mypy", str(self.HOME)])
|
||||
self.assertEqual(code, 0)
|
|
@ -0,0 +1,32 @@
|
|||
import contextlib
|
||||
import io
|
||||
|
||||
from riffusion.cli import print_exif
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class PrintExifTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli print-exif
|
||||
"""
|
||||
|
||||
def test_print_exif(self) -> None:
|
||||
"""
|
||||
Test print-exif.
|
||||
"""
|
||||
image_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "images"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.png"
|
||||
)
|
||||
|
||||
# Redirect stdout
|
||||
stdout = io.StringIO()
|
||||
with contextlib.redirect_stdout(stdout):
|
||||
print_exif(image=str(image_path))
|
||||
|
||||
# Check that a couple of values are printed
|
||||
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
|
||||
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())
|
|
@ -0,0 +1,88 @@
|
|||
import typing as T
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.cli import sample_clips
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SampleClipsTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli sample-clips
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def default_params() -> T.Dict:
|
||||
return dict(
|
||||
num_clips=3,
|
||||
duration_ms=5678,
|
||||
mono=False,
|
||||
extension="wav",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
def test_sample_clips(self) -> None:
|
||||
"""
|
||||
Test sample-clips with default params.
|
||||
"""
|
||||
params = self.default_params()
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_mono(self) -> None:
|
||||
"""
|
||||
Test sample-clips with mono=True.
|
||||
"""
|
||||
params = self.default_params()
|
||||
params["mono"] = True
|
||||
params["num_clips"] = 1
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_mp3(self) -> None:
|
||||
"""
|
||||
Test sample-clips with extension=mp3.
|
||||
"""
|
||||
if pydub.AudioSegment.converter is None:
|
||||
self.skipTest("skipping, ffmpeg not found")
|
||||
|
||||
params = self.default_params()
|
||||
params["extension"] = "mp3"
|
||||
params["num_clips"] = 1
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def helper_test_with_params(self, params: T.Dict) -> None:
|
||||
"""
|
||||
Test sample-clips with the given params.
|
||||
"""
|
||||
audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
|
||||
output_dir = self.get_tmp_dir("sample_clips_")
|
||||
|
||||
sample_clips(
|
||||
audio=str(audio_path),
|
||||
output_dir=str(output_dir),
|
||||
**params,
|
||||
)
|
||||
|
||||
# For each file in output dir
|
||||
counter = 0
|
||||
for clip_path in output_dir.iterdir():
|
||||
# Check that it has the right extension
|
||||
self.assertEqual(clip_path.suffix, f".{params['extension']}")
|
||||
|
||||
# Check that it has the right duration
|
||||
segment = pydub.AudioSegment.from_file(clip_path)
|
||||
self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
|
||||
|
||||
# Check that it has the right number of channels
|
||||
if params["mono"]:
|
||||
self.assertEqual(segment.channels, 1)
|
||||
else:
|
||||
self.assertEqual(segment.channels, 2)
|
||||
|
||||
counter += 1
|
||||
|
||||
self.assertEqual(counter, params["num_clips"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestCase.main()
|
|
@ -0,0 +1,86 @@
|
|||
import dataclasses
|
||||
import typing as T
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import fft_util
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SpectrogramConverterTest(TestCase):
|
||||
"""
|
||||
Test going from audio to spectrogram to audio, without converting to
|
||||
an image, to check quality loss of the reconstruction.
|
||||
|
||||
This test allows comparing multiple sets of spectrogram params by listening to output audio
|
||||
and by plotting their FFTs.
|
||||
"""
|
||||
|
||||
# TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.
|
||||
|
||||
def test_round_trip(self) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")
|
||||
|
||||
# Load up the audio file
|
||||
segment = pydub.AudioSegment.from_file(audio_path)
|
||||
|
||||
# Convert to mono if desired
|
||||
use_stereo = False
|
||||
if use_stereo:
|
||||
assert segment.channels == 2
|
||||
else:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
# Define named sets of parameters
|
||||
param_sets: T.Dict[str, SpectrogramParams] = {}
|
||||
|
||||
param_sets["default"] = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=use_stereo,
|
||||
step_size_ms=10,
|
||||
min_frequency=20,
|
||||
max_frequency=20000,
|
||||
num_frequencies=512,
|
||||
)
|
||||
|
||||
if self.DEBUG:
|
||||
param_sets["freq_0_to_10k"] = dataclasses.replace(
|
||||
param_sets["default"],
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
segments: T.Dict[str, pydub.AudioSegment] = {
|
||||
"original": segment,
|
||||
}
|
||||
for name, params in param_sets.items():
|
||||
converter = SpectrogramConverter(params=params, device=self.DEVICE)
|
||||
spectrogram = converter.spectrogram_from_audio(segment)
|
||||
segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
|
||||
|
||||
# Save segments to disk
|
||||
for name, segment in segments.items():
|
||||
audio_out = output_dir / f"{name}.wav"
|
||||
segment.export(audio_out, format="wav")
|
||||
print(f"Saved {audio_out}")
|
||||
|
||||
# Check params
|
||||
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
|
||||
self.assertEqual(segments["original"].channels, segments["default"].channels)
|
||||
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
|
||||
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
|
||||
|
||||
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
|
||||
|
||||
# If debugging, load up a browser tab plotting the FFTs
|
||||
if self.DEBUG:
|
||||
fft_util.plot_ffts(segments)
|
|
@ -0,0 +1,97 @@
|
|||
import dataclasses
|
||||
import typing as T
|
||||
|
||||
from PIL import Image
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import fft_util
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SpectrogramImageConverterTest(TestCase):
|
||||
"""
|
||||
Test going from audio to spectrogram images to audio, testing the quality loss of the
|
||||
end-to-end pipeline.
|
||||
|
||||
This test allows comparing multiple sets of spectrogram params by listening to output audio
|
||||
and by plotting their FFTs.
|
||||
|
||||
See spectrogram_converter_test.py for a similar test that does not convert to images.
|
||||
"""
|
||||
|
||||
def test_round_trip(self) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_")
|
||||
|
||||
# Load up the audio file
|
||||
segment = pydub.AudioSegment.from_file(audio_path)
|
||||
|
||||
# Convert to mono if desired
|
||||
use_stereo = False
|
||||
if use_stereo:
|
||||
assert segment.channels == 2
|
||||
else:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
# Define named sets of parameters
|
||||
param_sets: T.Dict[str, SpectrogramParams] = {}
|
||||
|
||||
param_sets["default"] = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=use_stereo,
|
||||
step_size_ms=10,
|
||||
min_frequency=20,
|
||||
max_frequency=20000,
|
||||
num_frequencies=512,
|
||||
)
|
||||
|
||||
if self.DEBUG:
|
||||
param_sets["freq_0_to_10k"] = dataclasses.replace(
|
||||
param_sets["default"],
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
segments: T.Dict[str, pydub.AudioSegment] = {
|
||||
"original": segment,
|
||||
}
|
||||
images: T.Dict[str, Image.Image] = {}
|
||||
for name, params in param_sets.items():
|
||||
converter = SpectrogramImageConverter(params=params, device=self.DEVICE)
|
||||
images[name] = converter.spectrogram_image_from_audio(segment)
|
||||
segments[name] = converter.audio_from_spectrogram_image(
|
||||
image=images[name],
|
||||
apply_filters=True,
|
||||
)
|
||||
|
||||
# Save images to disk
|
||||
for name, image in images.items():
|
||||
image_out = output_dir / f"{name}.png"
|
||||
image.save(image_out, exif=image.getexif(), format="PNG")
|
||||
print(f"Saved {image_out}")
|
||||
|
||||
# Save segments to disk
|
||||
for name, segment in segments.items():
|
||||
audio_out = output_dir / f"{name}.wav"
|
||||
segment.export(audio_out, format="wav")
|
||||
print(f"Saved {audio_out}")
|
||||
|
||||
# Check params
|
||||
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
|
||||
self.assertEqual(segments["original"].channels, segments["default"].channels)
|
||||
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
|
||||
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
|
||||
|
||||
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
|
||||
|
||||
# If debugging, load up a browser tab plotting the FFTs
|
||||
if self.DEBUG:
|
||||
fft_util.plot_ffts(segments)
|
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import tempfile
|
||||
import typing as T
|
||||
import warnings
|
||||
import unittest
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
"""
|
||||
Base class for tests.
|
||||
"""
|
||||
|
||||
# Where checked-in test data is stored
|
||||
TEST_DATA_PATH = Path(__file__).resolve().parent / "test_data"
|
||||
|
||||
# Whether to run tests in debug mode (e.g. don't clean up temporary directories, show plots)
|
||||
DEBUG = bool(os.environ.get("RIFFUSION_TEST_DEBUG"))
|
||||
|
||||
# Which torch device to use for tests
|
||||
DEVICE = os.environ.get("RIFFUSION_TEST_DEVICE", "cuda")
|
||||
|
||||
@staticmethod
|
||||
def main(*args: T.Any, **kwargs: T.Any) -> None:
|
||||
"""
|
||||
Run the tests.
|
||||
"""
|
||||
unittest.main(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning)
|
||||
|
||||
def get_tmp_dir(self, prefix: str) -> Path:
|
||||
"""
|
||||
Create a temporary directory.
|
||||
"""
|
||||
tmp_dir = tempfile.mkdtemp(prefix=prefix)
|
||||
|
||||
# Clean up the temporary directory if not debugging
|
||||
if not self.DEBUG:
|
||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir, ignore_errors=True))
|
||||
|
||||
dir_path = Path(tmp_dir)
|
||||
assert dir_path.is_dir()
|
||||
|
||||
return dir_path
|
|
@ -0,0 +1,7 @@
|
|||
# Test Data
|
||||
|
||||
### tired_traveler
|
||||
|
||||
* Song: Tired traveler on the way to home
|
||||
* Artist: Andrew Codeman
|
||||
* Source: https://freemusicarchive.org/
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 258 KiB |
Binary file not shown.
After Width: | Height: | Size: 382 KiB |
Binary file not shown.
Loading…
Reference in New Issue