revup virtual diff target

2fc9b8d379
d98efa55fe6e64c65f3344a22cad8db1111240ff
450a9b6fad
a7b41d90bd
Hayk Martiros 2023-01-17 02:17:23 +00:00
commit 93d90b3832
75 changed files with 6483 additions and 0 deletions

44
.github/workflows/ci.yml vendored Normal file

@@ -0,0 +1,44 @@
name: CI
run-name: ${{ github.actor }} is running Riffusion CI
on:
push:
branches:
- 'main'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
jobs:
riffusion-ci:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10"]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install system packages
run: |
sudo apt-get update
sudo apt-get install -y ffmpeg libsndfile1
- name: Install pip packages from requirements.txt
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Install pip packages from dev_requirements.txt
run: |
pip install -r dev_requirements.txt
- name: Test with unittest
run: |
RIFFUSION_TEST_DEVICE=cpu python -m unittest test/*_test.py

141
.gitignore vendored Normal file

@@ -0,0 +1,141 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# VSCode
.vscode
# Cog
.cog/
# Random stuff I don't care about
.graveyard/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# OSX cruft
.DS_Store
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

6
CITATION Normal file

@@ -0,0 +1,6 @@
@article{Forsgren_Martiros_2022,
author = {Forsgren, Seth* and Martiros, Hayk*},
title = {{Riffusion - Stable diffusion for real-time music generation}},
url = {https://riffusion.com/about},
year = {2022}
}

16
LICENSE Normal file

@@ -0,0 +1,16 @@
Copyright 2022 Hayk Martiros and Seth Forsgren
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

220
README.md Normal file

@@ -0,0 +1,220 @@
# :guitar: Riffusion
<a href="https://github.com/riffusion/riffusion/actions/workflows/ci.yml?query=branch%3Amain"><img alt="CI status" src="https://github.com/riffusion/riffusion/actions/workflows/ci.yml/badge.svg" /></a>
<img alt="Python 3.9 | 3.10" src="https://img.shields.io/badge/Python-3.9%20%7C%203.10-blue" />
<a href="https://github.com/riffusion/riffusion/tree/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/License-MIT-yellowgreen" /></a>
Riffusion is a library for real-time music and audio generation with stable diffusion.
Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.
This is the core repository for riffusion image and audio processing code.
* Diffusion pipeline that performs prompt interpolation combined with image conditioning
* Conversions between spectrogram images and audio clips
* Command-line interface for common tasks
* Interactive app using streamlit
* Flask server to provide model inference via API
* Various third party integrations
Related repositories:
* Web app: https://github.com/riffusion/riffusion-app
* Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1
## Citation
If you build on this work, please cite it as follows:
```
@article{Forsgren_Martiros_2022,
author = {Forsgren, Seth* and Martiros, Hayk*},
title = {{Riffusion - Stable diffusion for real-time music generation}},
url = {https://riffusion.com/about},
year = {2022}
}
```
## Install
Tested in CI with Python 3.9 and 3.10.
It's highly recommended to set up a virtual Python environment with `conda` or `virtualenv`:
```
conda create --name riffusion python=3.9
conda activate riffusion
```
Install Python dependencies:
```
python -m pip install -r requirements.txt
```
In order to use audio formats other than WAV, [ffmpeg](https://ffmpeg.org/download.html) is required.
```
sudo apt-get install ffmpeg # linux
brew install ffmpeg # mac
conda install -c conda-forge ffmpeg # conda
```
If torchaudio has no backend, you may need to install `libsndfile`. See [this issue](https://github.com/riffusion/riffusion/issues/12).
If you have an issue, try upgrading [diffusers](https://github.com/huggingface/diffusers). Tested with 0.9 - 0.11.
Guides:
* [Simple Install Guide for Windows](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
## Backends
### CPU
`cpu` is supported but is quite slow.
### CUDA
`cuda` is the recommended and most performant backend.
To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
[install guide](https://pytorch.org/get-started/locally/) or
[stable wheels](https://download.pytorch.org/whl/torch_stable.html).
To generate audio in real-time, you need a GPU that can run stable diffusion with approximately 50
steps in under five seconds, such as a 3090 or A10G.
Test availability with:
```python3
import torch
torch.cuda.is_available()
```
### MPS
The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
particularly for audio processing. You may need to set
`PYTORCH_ENABLE_MPS_FALLBACK=1`.
In addition, this backend is not deterministic.
Test availability with:
```python3
import torch
torch.backends.mps.is_available()
```
## Command-line interface
Riffusion comes with a command line interface for performing common tasks.
See available commands:
```
python -m riffusion.cli -h
```
Get help for a specific command:
```
python -m riffusion.cli image-to-audio -h
```
Execute:
```
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
```
## Riffusion Playground
Riffusion contains a [streamlit](https://streamlit.io/) app for interactive use and exploration.
Run with:
```
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
```
And access at http://127.0.0.1:8501/
<img alt="Riffusion Playground" style="width: 600px" src="https://i.imgur.com/OOMKBbT.png" />
## Run the model server
Riffusion can be run as a flask server that provides inference via API. This server enables the [web app](https://github.com/riffusion/riffusion-app) to run locally.
Run with:
```
python -m riffusion.server --host 127.0.0.1 --port 3013
```
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
Use the `--device` argument to specify the torch device to use.
The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
```
{
"alpha": 0.75,
"num_inference_steps": 50,
"seed_image_id": "og_beat",
"start": {
"prompt": "church bells on sunday",
"seed": 42,
"denoising": 0.75,
"guidance": 7.0
},
"end": {
"prompt": "jazz with piano",
"seed": 123,
"denoising": 0.75,
"guidance": 7.0
}
}
```
Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L54) for the API):
```
{
"image": "< base64 encoded JPEG image >",
"audio": "< base64 encoded MP3 clip >"
}
```
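You can call the endpoint from any HTTP client. Below is a minimal sketch of a Python client based on the request and response formats shown above. It assumes the server is running locally on the default host and port, and it uses the `requests` package, which is not part of `requirements.txt`:
```python3
import base64

import requests

# Request body matching the InferenceInput schema shown above
request_body = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=request_body)
response.raise_for_status()
output = response.json()

# The image and audio fields are base64 encoded; strip a data URI prefix if one is present
audio_b64 = output["audio"].split(",")[-1]
image_b64 = output["image"].split(",")[-1]

with open("output.mp3", "wb") as f:
    f.write(base64.b64decode(audio_b64))
with open("spectrogram.jpg", "wb") as f:
    f.write(base64.b64decode(image_b64))
```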
## Tests
Tests live in the `test/` directory and are implemented with `unittest`.
To run all tests:
```
python -m unittest test/*_test.py
```
To run a single test:
```
python -m unittest test.audio_to_image_test
```
To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
```
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
```
To run a single test case within a test:
```
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
```
To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`, for example
`RIFFUSION_TEST_DEVICE=cuda python -m unittest test/*_test.py`. Tests should pass with the
`cpu`, `cuda`, and `mps` backends.
## Development Guide
Install additional packages for dev with `python -m pip install -r dev_requirements.txt`.
* Linter: `ruff`
* Formatter: `black`
* Type checker: `mypy`
These are configured in `pyproject.toml`.
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
CI is run through GitHub Actions from `.github/workflows/ci.yml`.
Contributions are welcome through pull requests.

38
cog.yaml Normal file

@@ -0,0 +1,38 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
build:
# set to true if your model requires a GPU
gpu: true
# a list of ubuntu apt packages to install
system_packages:
- "ffmpeg"
- "libsndfile1"
# python version in the form '3.8' or '3.8.12'
python_version: "3.9"
# a list of packages in the format <package-name>==<version>
python_packages:
- "accelerate==0.15.0"
- "argh==0.26.2"
- "dacite==1.6.0"
- "diffusers==0.10.2"
- "flask_cors==3.0.10"
- "flask==1.1.2"
- "numpy==1.19.4"
- "pillow==9.1.0"
- "pydub==0.25.1"
- "scipy==1.6.3"
- "torch==1.13.0"
- "torchaudio==0.13.0"
- "transformers==4.25.1"
# commands run after the environment is setup
# run:
# - "echo env is ready!"
# - "echo another command if needed"
# predict.py defines how predictions are run on your model
predict: "integrations/cog_riffusion.py:RiffusionPredictor"

7
dev_requirements.txt Normal file

@@ -0,0 +1,7 @@
black
ipdb
mypy
ruff
types-Flask-Cors
types-Pillow
types-requests

23
integrations/README.md Normal file

@@ -0,0 +1,23 @@
# Integrations
This package contains integrations of Riffusion into third party apps and deployments.
## Baseten
[Baseten](https://baseten.com) is a platform for building and deploying machine learning models. See `integrations/baseten.py` for a Truss model implementation that can be adapted for deployment there.
## Replicate
To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
download the model weights:
cog run python -m integrations.cog_riffusion --download_weights
Then you can run predictions:
cog predict -i prompt_a="funky synth solo"
You can also view the model on replicate [here](https://replicate.com/hmartiro/riffusion). Owners
can push an updated version of the model like so:
cog push r8.im/hmartiro/riffusion

0
integrations/__init__.py Normal file

83
integrations/baseten.py Normal file

@@ -0,0 +1,83 @@
"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.
For more on the Truss file format, see https://truss.baseten.co/
"""
import typing as T
import dacite
import torch
from huggingface_hub import snapshot_download
from riffusion.datatypes import InferenceInput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.server import compute_request
class Model:
"""
Baseten Truss model class for riffusion.
See: https://truss.baseten.co/reference/structure#model.py
"""
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._pipeline = None
self._vae = None
self.checkpoint_name = "riffusion/riffusion-model-v1"
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
def load(self):
"""
Load the model. Guaranteed to be called before `predict`.
"""
self._pipeline = RiffusionPipeline.load_checkpoint(
checkpoint=self.checkpoint_name,
use_traced_unet=True,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
def preprocess(self, request: T.Dict) -> T.Dict:
"""
Incorporate pre-processing required by the model if desired here.
These might be feature transformations that are tightly coupled to the model.
"""
return request
def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
"""
This is the main function that is called.
"""
assert self._pipeline is not None, "Model pipeline not loaded"
try:
inputs = dacite.from_dict(InferenceInput, request)
except dacite.exceptions.WrongTypeError as exception:
return str(exception), 400
except dacite.exceptions.MissingValueError as exception:
return str(exception), 400
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
response = compute_request(
inputs=inputs,
pipeline=self._pipeline,
seed_images_dir=self._seed_images_dir,
)
return response
def postprocess(self, request: T.Dict) -> T.Dict:
"""
Incorporate post-processing required by the model if desired here.
"""
return request

158
integrations/cog_riffusion.py Normal file

@@ -0,0 +1,158 @@
"""
Prediction interface for Cog
https://github.com/replicate/cog/blob/main/docs/python.md
"""
import argparse
import os
import shutil
import typing as T
import numpy as np
import PIL
import torch
from cog import BaseModel, BasePredictor, Input, Path
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
MODEL_ID = "riffusion/riffusion-model-v1"
MODEL_CACHE = "riffusion-cache"
# Where built-in seed images are stored
SEED_IMAGES_DIR = Path("./seed_images")
SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val]
SEED_IMAGES.sort()
class Output(BaseModel):
"""
Output class for riffusion predictions
"""
audio: Path
spectrogram: Path
error: T.Optional[str] = None
class RiffusionPredictor(BasePredictor):
"""
Implementation of the Cog predictor object so that we can run riffusion predictions with Cog.
See README & https://github.com/replicate/cog for details
"""
def setup(self, local_files_only=True):
"""
Loads the model onto GPU from local cache.
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = RiffusionPipeline.load_checkpoint(
checkpoint=MODEL_ID,
use_traced_unet=True,
device=self.device,
local_files_only=local_files_only,
cache_dir=MODEL_CACHE,
)
def predict(
self,
prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"),
denoising: float = Input(
description="How much to transform input spectrogram",
default=0.75,
ge=0,
le=1,
),
prompt_b: str = Input(
description="The second prompt to interpolate with the first,"
"leave blank if no interpolation",
default=None,
),
alpha: float = Input(
description="Interpolation alpha if using two prompts."
"A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully",
default=0.5,
ge=0,
le=1,
),
num_inference_steps: int = Input(
description="Number of steps to run the diffusion model", default=50, ge=1
),
seed_image_id: str = Input(
description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES
),
) -> Output:
"""
Runs riffusion inference.
"""
# Load the seed image by ID
init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png")
if not init_image_path.is_file():
return Output(error=f"Invalid seed image: {seed_image_id}")
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Random seeds drawn from the 32-bit signed integer range
seed_a = np.random.randint(0, 2147483647)
seed_b = np.random.randint(0, 2147483647)
start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising)
if not prompt_b: # no transition
prompt_b = prompt_a
alpha = 0
end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising)
riffusion_input = InferenceInput(
start=start,
end=end,
alpha=alpha,
num_inference_steps=num_inference_steps,
seed_image_id=seed_image_id,
)
# Execute the model to get the spectrogram image
image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None)
# Reconstruct audio from the image
params = SpectrogramParams()
converter = SpectrogramImageConverter(params=params, device=self.device)
segment = converter.audio_from_spectrogram_image(image)
if not os.path.exists("out/"):
os.mkdir("out")
out_img_path = "out/spectrogram.jpg"
image.save("out/spectrogram.jpg", exif=image.getexif())
out_wav_path = "out/gen_sound.wav"
segment.export(out_wav_path, format="wav")
return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path))
# TODO(hayk): Can we get rid of the below functions and incorporate into
# RiffusionPipeline.load_checkpoint?
def download_weights():
"""
Clears local cache & downloads riffusion weights
"""
if os.path.exists(MODEL_CACHE):
shutil.rmtree(MODEL_CACHE)
os.makedirs(MODEL_CACHE)
pred = RiffusionPredictor()
pred.setup(local_files_only=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--download_weights", action="store_true", help="Download and cache weights"
)
args = parser.parse_args()
if args.download_weights:
download_weights()

99
pyproject.toml Normal file

@@ -0,0 +1,99 @@
[tool.black]
line-length = 100
[tool.ruff]
line-length = 100
# Which rules to run
select = [
# Pyflakes
"F",
# Pycodestyle
"E",
"W",
# isort
"I001"
]
ignore = []
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
per-file-ignores = {}
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
# Assume Python 3.10.
target-version = "py310"
[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10
[tool.mypy]
python_version = "3.10"
[[tool.mypy.overrides]]
module = "argh.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "cog.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "diffusers.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "lora_diffusion.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "numpy.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "plotly.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "pydub.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.fft.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.io.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "torchaudio.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "transformers.*"
ignore_missing_imports = true

21
requirements.txt Normal file

@@ -0,0 +1,21 @@
accelerate
argh
dacite
demucs
diffusers>=0.9.0
flask
flask_cors
numpy
pillow>=9.1.0
plotly
pydub
pysoundfile
scipy
soundfile
sox
streamlit>=1.10.0
torch
torchaudio
torchvision
transformers
git+https://github.com/cloneofsimo/lora.git

0
riffusion/__init__.py Normal file

187
riffusion/audio_splitter.py Normal file

@@ -0,0 +1,187 @@
import shutil
import subprocess
import tempfile
import typing as T
from pathlib import Path
import numpy as np
import pydub
import torch
import torchaudio
from torchaudio.transforms import Fade
from riffusion.util import audio_util
def split_audio(
segment: pydub.AudioSegment,
model_name: str = "htdemucs_6s",
extension: str = "wav",
jobs: int = 4,
device: str = "cuda",
) -> T.Dict[str, pydub.AudioSegment]:
"""
Split audio into stems using demucs.
"""
tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_"))
# Save the audio to a temporary file
audio_path = tmp_dir / "audio.mp3"
segment.export(audio_path, format="mp3")
# Assemble command
command = [
"demucs",
str(audio_path),
"--name",
model_name,
"--out",
str(tmp_dir),
"--jobs",
str(jobs),
"--device",
device if device != "mps" else "cpu",
]
print(" ".join(command))
if extension == "mp3":
command.append("--mp3")
# Run demucs
subprocess.run(
command,
check=True,
)
# Load the stems
stems = {}
for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"):
stem = pydub.AudioSegment.from_file(stem_path)
stems[stem_path.stem] = stem
# Delete tmp dir
shutil.rmtree(tmp_dir)
return stems
class AudioSplitter:
"""
Split audio into instrument stems like {drums, bass, vocals, etc.}
NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
model in the demucs repo. See the function above. Probably just delete this.
See:
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
"""
def __init__(
self,
segment_length_s: float = 10.0,
overlap_s: float = 0.1,
device: str = "cuda",
):
self.segment_length_s = segment_length_s
self.overlap_s = overlap_s
self.device = device
self.model = self.load_model().to(device)
@staticmethod
def load_model(model_path: str = "models/hdemucs_high_trained.pt") -> torchaudio.models.HDemucs:
"""
Load the trained HDEMUCS pytorch model.
"""
# NOTE(hayk): The sources are baked into the pretrained model and can't be changed
model = torchaudio.models.hdemucs_high(sources=["drums", "bass", "other", "vocals"])
path = torchaudio.utils.download_asset(model_path)
state_dict = torch.load(path)
model.load_state_dict(state_dict)
model.eval()
return model
def split(self, audio: pydub.AudioSegment) -> T.Dict[str, pydub.AudioSegment]:
"""
Split the given audio segment into instrument stems.
"""
if audio.channels == 1:
audio_stereo = audio.set_channels(2)
elif audio.channels == 2:
audio_stereo = audio
else:
raise ValueError(f"Audio must be stereo, but got {audio.channels} channels")
# Get as (samples, channels) float numpy array
waveform_np = np.array(audio_stereo.get_array_of_samples())
waveform_np = waveform_np.reshape(-1, audio_stereo.channels)
waveform_np_float = waveform_np.astype(np.float32)
# To torch and channels-first
waveform = torch.from_numpy(waveform_np_float).to(self.device)
waveform = waveform.transpose(1, 0)
# Normalize
ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()
# Split
sources = self.separate_sources(
waveform[None],
sample_rate=audio.frame_rate,
)[0]
# De-normalize
sources = sources * ref.std() + ref.mean()
# To numpy
sources_np = sources.cpu().numpy().astype(waveform_np.dtype)
# Convert to pydub
stem_segments = [
audio_util.audio_from_waveform(waveform, audio.frame_rate) for waveform in sources_np
]
# Convert back to mono if necessary
if audio.channels == 1:
stem_segments = [stem.set_channels(1) for stem in stem_segments]
return dict(zip(self.model.sources, stem_segments))
def separate_sources(
self,
waveform: torch.Tensor,
sample_rate: int = 44100,
):
"""
Apply model to a given waveform in chunks. Use fade and overlap to smooth the edges.
"""
batch, channels, length = waveform.shape
chunk_len = int(sample_rate * self.segment_length_s * (1 + self.overlap_s))
start = 0
end = chunk_len
overlap_frames = self.overlap_s * sample_rate
fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
final = torch.zeros(batch, len(self.model.sources), channels, length, device=self.device)
# TODO(hayk): Improve this code, which came from the torchaudio docs
while start < length - overlap_frames:
chunk = waveform[:, :, start:end]
with torch.no_grad():
out = self.model.forward(chunk)
out = fade(out)
final[:, :, :, start:end] += out
if start == 0:
fade.fade_in_len = int(overlap_frames)
start += int(chunk_len - overlap_frames)
else:
start += chunk_len
end += chunk_len
if end >= length:
fade.fade_out_len = 0
return final

138
riffusion/cli.py Normal file

@@ -0,0 +1,138 @@
"""
Command line tools for riffusion.
"""
from pathlib import Path
import argh
import numpy as np
import pydub
from PIL import Image
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
def audio_to_image(
*,
audio: str,
image: str,
step_size_ms: int = 10,
num_frequencies: int = 512,
min_frequency: int = 0,
max_frequency: int = 10000,
window_duration_ms: int = 100,
padded_duration_ms: int = 400,
power_for_image: float = 0.25,
stereo: bool = False,
device: str = "cuda",
):
"""
Compute a spectrogram image from a waveform.
"""
segment = pydub.AudioSegment.from_file(audio)
params = SpectrogramParams(
sample_rate=segment.frame_rate,
stereo=stereo,
window_duration_ms=window_duration_ms,
padded_duration_ms=padded_duration_ms,
step_size_ms=step_size_ms,
min_frequency=min_frequency,
max_frequency=max_frequency,
num_frequencies=num_frequencies,
power_for_image=power_for_image,
)
converter = SpectrogramImageConverter(params=params, device=device)
pil_image = converter.spectrogram_image_from_audio(segment)
pil_image.save(image, exif=pil_image.getexif(), format="PNG")
print(f"Wrote {image}")
def print_exif(*, image: str) -> None:
"""
Print the params of a spectrogram image as saved in the exif data.
"""
pil_image = Image.open(image)
exif_data = image_util.exif_from_image(pil_image)
for name, value in exif_data.items():
print(f"{name:<20} = {value:>15}")
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
"""
Reconstruct an audio clip from a spectrogram image.
"""
pil_image = Image.open(image)
# Get parameters from image exif
img_exif = pil_image.getexif()
assert img_exif is not None
try:
params = SpectrogramParams.from_exif(exif=img_exif)
except (KeyError, AttributeError):
print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
params = SpectrogramParams()
converter = SpectrogramImageConverter(params=params, device=device)
segment = converter.audio_from_spectrogram_image(pil_image)
extension = Path(audio).suffix[1:]
segment.export(audio, format=extension)
print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
def sample_clips(
*,
audio: str,
output_dir: str,
num_clips: int = 1,
duration_ms: int = 5000,
mono: bool = False,
extension: str = "wav",
seed: int = -1,
):
"""
Slice an audio file into clips of the given duration.
"""
if seed >= 0:
np.random.seed(seed)
segment = pydub.AudioSegment.from_file(audio)
if mono:
segment = segment.set_channels(1)
output_dir_path = Path(output_dir)
if not output_dir_path.exists():
output_dir_path.mkdir(parents=True)
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
clip_path = output_dir_path / clip_name
clip.export(clip_path, format=extension)
print(f"Wrote {clip_path}")
if __name__ == "__main__":
argh.dispatch_commands(
[
audio_to_image,
image_to_audio,
sample_clips,
print_exif,
]
)

73
riffusion/datatypes.py Normal file

@@ -0,0 +1,73 @@
"""
Data model for the riffusion API.
"""
from __future__ import annotations
import typing as T
from dataclasses import dataclass
@dataclass(frozen=True)
class PromptInput:
"""
Parameters for one end of interpolation.
"""
# Text prompt fed into a CLIP model
prompt: str
# Random seed for denoising
seed: int
# Negative prompt to avoid (optional)
negative_prompt: T.Optional[str] = None
# Denoising strength
denoising: float = 0.75
# Classifier-free guidance strength
guidance: float = 7.0
@dataclass(frozen=True)
class InferenceInput:
"""
Parameters for a single run of the riffusion model, interpolating between
a start and end set of PromptInputs. This is the API required for a request
to the model server.
"""
# Start point of interpolation
start: PromptInput
# End point of interpolation
end: PromptInput
# Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
# uses end fully.
alpha: float
# Number of inner loops of the diffusion model
num_inference_steps: int = 50
# Which seed image to use
seed_image_id: str = "og_beat"
# ID of mask image to use
mask_image_id: T.Optional[str] = None
@dataclass(frozen=True)
class InferenceOutput:
"""
Response from the model inference server.
"""
# base64 encoded spectrogram image as a JPEG
image: str
# base64 encoded audio clip as an MP3
audio: str
# The duration of the audio clip
duration_s: float

3
riffusion/external/README.md vendored Normal file

@@ -0,0 +1,3 @@
# external
This package contains scripts and tools from external sources.

0
riffusion/external/__init__.py vendored Normal file

0
riffusion/external/lora/__init__.py vendored Normal file

@@ -0,0 +1,24 @@
export MODEL_NAME="riffusion/riffusion-model-v1"
export INSTANCE_DIR="/tmp/sample_clips_tdlcqdfi/images"
export OUTPUT_DIR="/home/ubuntu/lora_dreambooth_waterfalls_2k"
accelerate launch \
--num_machines 1 \
--num_processes 8 \
--dynamo_backend=no \
--mixed_precision="fp16" \
riffusion/external/lora/train_lora_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="style of sks" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000
# TODO try mixed_precision=fp16
# TODO try num_processes = 8

49
riffusion/external/lora/train_lora.py vendored Normal file

@@ -0,0 +1,49 @@
from lora_diffusion.cli_lora_pti import train
from lora_diffusion.dataset import STYLE_TEMPLATE
MODEL_NAME = "riffusion/riffusion-model-v1"
INSTANCE_DIR = "/tmp/sample_clips_xzv8p57g/images"
OUTPUT_DIR = "./lora_output_acoustic"
if __name__ == "__main__":
entries = [
"music in the style of {}",
"sound in the style of {}",
"vibe in the style of {}",
"audio in the style of {}",
"groove in the style of {}",
]
for i in range(len(STYLE_TEMPLATE)):
STYLE_TEMPLATE[i] = entries[i % len(entries)]
print(STYLE_TEMPLATE)
train(
pretrained_model_name_or_path=MODEL_NAME,
instance_data_dir=INSTANCE_DIR,
output_dir=OUTPUT_DIR,
train_text_encoder=True,
resolution=512,
train_batch_size=1,
gradient_accumulation_steps=4,
scale_lr=True,
learning_rate_unet=1e-4,
learning_rate_text=1e-5,
learning_rate_ti=5e-4,
color_jitter=False,
lr_scheduler="linear",
lr_warmup_steps=0,
placeholder_tokens="<s1>|<s2>",
use_template="style",
save_steps=100,
max_train_steps_ti=1000,
max_train_steps_tuning=1000,
perform_inversion=True,
clip_ti_decay=True,
weight_decay_ti=0.000,
weight_decay_lora=0.001,
continue_inversion=True,
continue_inversion_lr=1e-4,
device="cuda:0",
lora_rank=1,
use_face_segmentation_condition=False,
)

37
riffusion/external/lora/train_lora.sh vendored Executable file

@@ -0,0 +1,37 @@
export MODEL_NAME="riffusion/riffusion-model-v1"
export INSTANCE_DIR="/tmp/sample_clips_xzv8p57g/images"
export OUTPUT_DIR="./lora_output_acoustic"
# NOTE: train_batch_size started as 1; lora_rank could be 1 or 4.
# Flags left disabled: --color_jitter, --use_template="style",
# --use_face_segmentation_condition, initializer tokens, class prompt.
lora_pti \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--train_text_encoder \
--resolution=512 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--scale_lr \
--learning_rate_unet=1e-4 \
--learning_rate_text=1e-5 \
--learning_rate_ti=5e-4 \
--lr_scheduler="linear" \
--lr_warmup_steps=0 \
--placeholder_tokens="<s>" \
--save_steps=100 \
--max_train_steps_ti=1000 \
--max_train_steps_tuning=1000 \
--perform_inversion=True \
--clip_ti_decay \
--weight_decay_ti=0.000 \
--weight_decay_lora=0.001 \
--continue_inversion \
--continue_inversion_lr=1e-4 \
--device="cuda:0" \
--lora_rank=4

958
riffusion/external/lora/train_lora_dreambooth.py vendored Normal file

@@ -0,0 +1,958 @@
# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
# ruff: noqa
# mypy: ignore-errors
import argparse
import hashlib
import inspect
import itertools
import math
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from lora_diffusion import (
extract_lora_ups_down,
inject_trainable_lora,
safetensors_available,
save_lora_weight,
save_safeloras,
)
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and tokenizes the prompts.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
color_jitter=False,
h_flip=False,
resize=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.resize = resize
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
img_transforms = []
if resize:
img_transforms.append(
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
)
if center_crop:
img_transforms.append(transforms.CenterCrop(size))
if color_jitter:
img_transforms.append(transforms.ColorJitter(0.2, 0.1))
if h_flip:
img_transforms.append(transforms.RandomHorizontalFlip())
self.image_transforms = transforms.Compose(
[*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
return example
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
logger = get_logger(__name__)
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_vae_name_or_path",
type=str,
default=None,
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--prior_loss_weight",
type=float,
default=1.0,
help="The weight of prior preservation loss.",
)
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--output_format",
type=str,
choices=["pt", "safe", "both"],
default="both",
help="The output format of the model predicitions and checkpoints.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution",
)
parser.add_argument(
"--color_jitter",
action="store_true",
help="Whether to apply color jitter to images",
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=4,
help="Batch size (per device) for sampling images.",
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--save_steps",
type=int,
default=500,
help="Save checkpoint every X updates steps.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="Rank of LoRA approximation.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=None,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--learning_rate_text",
type=float,
default=5e-6,
help="Initial learning rate for text encoder (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.",
)
parser.add_argument(
"--adam_beta1",
type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_beta2",
type=float,
default=0.999,
help="The beta2 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the model to the Hub.",
)
parser.add_argument(
"--hub_token",
type=str,
default=None,
help="The token to use to push to the Model Hub.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--local_rank",
type=int,
default=-1,
help="For distributed training: local_rank",
)
parser.add_argument(
"--resume_unet",
type=str,
default=None,
help=("File path for unet lora to resume training."),
)
parser.add_argument(
"--resume_text_encoder",
type=str,
default=None,
help=("File path for text encoder lora to resume training."),
)
parser.add_argument(
"--resize",
type=bool,
default=True,
required=False,
help="Should images be resized to --resolution before training?",
)
parser.add_argument(
"--use_xformers", action="store_true", help="Whether or not to use xformers"
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
if args.class_data_dir is not None:
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
if not safetensors_available:
if args.output_format == "both":
print(
"Safetensors is not available - changing output format to just output PyTorch files"
)
args.output_format = "pt"
elif args.output_format == "safe":
raise ValueError(
"Safetensors is not available - either install it, or change output_format."
)
return args
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
)
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
if (
args.train_text_encoder
and args.gradient_accumulation_steps > 1
and accelerator.num_processes > 1
):
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
if args.seed is not None:
set_seed(args.seed)
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(
sample_dataset, batch_size=args.sample_batch_size
)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process,
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = (
class_images_dir
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
)
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(
args.tokenizer_name,
revision=args.revision,
)
elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
subfolder=None if args.pretrained_vae_name_or_path else "vae",
revision=None if args.pretrained_vae_name_or_path else args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
unet.requires_grad_(False)
unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)
for _up, _down in extract_lora_ups_down(unet):
print("Before training: Unet First Layer lora up", _up.weight.data)
print("Before training: Unet First Layer lora down", _down.weight.data)
break
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
if args.train_text_encoder:
text_encoder_lora_params, _ = inject_trainable_lora(
text_encoder,
target_replace_module=["CLIPAttention"],
r=args.lora_rank,
)
for _up, _down in extract_lora_ups_down(
text_encoder, target_replace_module=["CLIPAttention"]
):
print("Before training: text encoder First Layer lora up", _up.weight.data)
print("Before training: text encoder First Layer lora down", _down.weight.data)
break
if args.use_xformers:
set_use_memory_efficient_attention_xformers(unet, True)
set_use_memory_efficient_attention_xformers(vae, True)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
if args.scale_lr:
args.learning_rate = (
args.learning_rate
* args.gradient_accumulation_steps
* args.train_batch_size
* accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
text_lr = args.learning_rate if args.learning_rate_text is None else args.learning_rate_text
params_to_optimize = (
[
{"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
{
"params": itertools.chain(*text_encoder_lora_params),
"lr": text_lr,
},
]
if args.train_text_encoder
else itertools.chain(*unet_lora_params)
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
color_jitter=args.color_jitter,
resize=args.resize,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=1,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
if args.train_text_encoder:
(
unet,
text_encoder,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth", config=vars(args))
# Train!
total_batch_size = (
args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
)
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num batches each epoch = {len(train_dataloader)}")
print(f" Num Epochs = {args.num_train_epochs}")
print(f" Instantaneous batch size per device = {args.train_batch_size}")
print(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
last_save = 0
for epoch in range(args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(bsz,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = (
F.mse_loss(model_pred.float(), target.float(), reduction="none")
.mean([1, 2, 3])
.mean()
)
# Compute prior loss
prior_loss = F.mse_loss(
model_pred_prior.float(), target_prior.float(), reduction="mean"
)
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
progress_bar.update(1)
optimizer.zero_grad()
global_step += 1
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.save_steps and global_step - last_save >= args.save_steps:
if accelerator.is_main_process:
# Newer versions of accelerate allow the 'keep_fp32_wrapper' arg. Without passing
# it, the models will be unwrapped, and when they are then used for further training,
# we will crash. Pass this, but only to newer versions of accelerate. Fixes
# https://github.com/huggingface/diffusers/issues/1566
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
inspect.signature(accelerator.unwrap_model).parameters.keys()
)
extra_args = (
{"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
)
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet, **extra_args),
text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
revision=args.revision,
)
filename_unet = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
filename_text_encoder = (
f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
)
print(f"save weights {filename_unet}, {filename_text_encoder}")
save_lora_weight(pipeline.unet, filename_unet)
if args.train_text_encoder:
save_lora_weight(
pipeline.text_encoder,
filename_text_encoder,
target_replace_module=["CLIPAttention"],
)
for _up, _down in extract_lora_ups_down(pipeline.unet):
print(
"First Unet Layer's Up Weight is now : ",
_up.weight.data,
)
print(
"First Unet Layer's Down Weight is now : ",
_down.weight.data,
)
break
if args.train_text_encoder:
for _up, _down in extract_lora_ups_down(
pipeline.text_encoder,
target_replace_module=["CLIPAttention"],
):
print(
"First Text Encoder Layer's Up Weight is now : ",
_up.weight.data,
)
print(
"First Text Encoder Layer's Down Weight is now : ",
_down.weight.data,
)
break
last_save = global_step
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
print("\n\nLora TRAINING DONE!\n\n")
if args.output_format == "pt" or args.output_format == "both":
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
if args.train_text_encoder:
save_lora_weight(
pipeline.text_encoder,
args.output_dir + "/lora_weight.text_encoder.pt",
target_replace_module=["CLIPAttention"],
)
if args.output_format == "safe" or args.output_format == "both":
loras = {}
loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"})
if args.train_text_encoder:
loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"})
save_safeloras(loras, args.output_dir + "/lora_weight.safetensors")
if args.push_to_hub:
repo.push_to_hub(
commit_message="End of training",
blocking=False,
auto_lfs_prune=True,
)
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
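As a standalone illustration of the loss computed in the training loop above, here is a hedged sketch of the prior-preservation branch using random tensors in place of real latent predictions; the shapes and the prior_loss_weight value are placeholders chosen only for demonstration (prior_loss_weight stands in for args.prior_loss_weight).

import torch
import torch.nn.functional as F

prior_loss_weight = 1.0  # stand-in for args.prior_loss_weight
# Instance and prior-class predictions are stacked on the batch dimension.
model_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)

# Chunk into the instance half and the prior half, as in the loop above.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

# Per-image instance loss, averaged over the batch.
instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
# Prior-preservation loss, weighted and added to the instance loss.
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = instance_loss + prior_loss_weight * prior_loss
print(float(loss))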

372
riffusion/external/prompt_weighting.py vendored Normal file
View File

@ -0,0 +1,372 @@
"""
This code is taken from the diffusers community pipeline:
https://github.com/huggingface/diffusers/blob/f242eba4fdc5b76dc40d3a9c01ba49b2c74b9796/examples/community/lpw_stable_diffusion.py
License: Apache 2.0
"""
# ruff: noqa
# mypy: ignore-errors
import logging
import re
import typing as T
import torch
from diffusers import StableDiffusionPipeline
logger = logging.getLogger(__name__)
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
logger.warning(
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
)
return tokens, weights
def pad_tokens_and_weights(
tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77
):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][
j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))
]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: T.Optional[bool] = True,
) -> torch.FloatTensor:
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
return text_embeddings
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: T.Union[str, T.List[str]],
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
max_embeddings_multiples: T.Optional[int] = 3,
no_boseos_middle: T.Optional[bool] = False,
skip_parsing: T.Optional[bool] = False,
skip_weighting: T.Optional[bool] = False,
**kwargs,
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `T.List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `T.List[str]`):
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of the text tokens is a multiple of the text encoder capacity, whether to keep the
starting and ending tokens in each middle chunk.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(
pipe, uncond_prompt, max_length - 2
)
else:
prompt_tokens = [
token[1:-1]
for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
)
uncond_weights = torch.tensor(
uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device
)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
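A hedged usage sketch of the weighting helpers above, not part of the vendored file: the checkpoint name and device handling are assumptions for illustration, and the printed shape depends on the prompt length.

import torch
from diffusers import StableDiffusionPipeline
from riffusion.external.prompt_weighting import get_weighted_text_embeddings

# Any Stable Diffusion checkpoint with a CLIP tokenizer and text encoder should work here.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Parentheses raise the weight of the enclosed tokens, as parsed by parse_prompt_attention.
embeddings, uncond = get_weighted_text_embeddings(
    pipe=pipe,
    prompt="a (funky:1.3) jazz solo, lo-fi",
    uncond_prompt="",
    max_embeddings_multiples=3,
)
print(embeddings.shape)  # (1, chunk_length * multiples, hidden_dim) for long prompts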

View File

@ -0,0 +1,477 @@
"""
Riffusion inference pipeline.
"""
from __future__ import annotations
import dataclasses
import functools
import inspect
import typing as T
import numpy as np
import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from riffusion.datatypes import InferenceInput
from riffusion.external.prompt_weighting import get_weighted_text_embeddings
from riffusion.util import torch_util
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class RiffusionPipeline(DiffusionPipeline):
"""
Diffusers pipeline for doing a controlled img2img interpolation for audio generation.
# TODO(hayk): Document more
Part of this code was adapted from the non-img2img interpolation pipeline at:
https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py
Check the documentation for DiffusionPipeline for full information.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: T.Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
@classmethod
def load_checkpoint(
cls,
checkpoint: str,
use_traced_unet: bool = True,
channels_last: bool = False,
dtype: torch.dtype = torch.float16,
device: str = "cuda",
local_files_only: bool = False,
low_cpu_mem_usage: bool = False,
cache_dir: T.Optional[str] = None,
) -> RiffusionPipeline:
"""
Load the riffusion model pipeline.
Args:
checkpoint: Model checkpoint on disk in diffusers format
use_traced_unet: Whether to use the traced unet for speedups
device: Device to load the model on
channels_last: Whether to use channels_last memory format
local_files_only: Don't download, only use local files
low_cpu_mem_usage: Attempt to use less memory on CPU
"""
device = torch_util.check_device(device)
if device == "cpu" or device.lower().startswith("mps"):
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
pipeline = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
# Disable the NSFW filter, causes incorrect false positives
# TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
safety_checker=lambda images, **kwargs: (images, False),
low_cpu_mem_usage=low_cpu_mem_usage,
local_files_only=local_files_only,
cache_dir=cache_dir,
).to(device)
if channels_last:
pipeline.unet.to(memory_format=torch.channels_last)
# Optionally load a traced unet
if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
traced_unet = cls.load_traced_unet(
checkpoint=checkpoint,
subfolder="unet_traced",
filename="unet_traced.pt",
in_channels=pipeline.unet.in_channels,
dtype=dtype,
device=device,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
if traced_unet is not None:
pipeline.unet = traced_unet
model = pipeline.to(device)
return model
@staticmethod
def load_traced_unet(
checkpoint: str,
subfolder: str,
filename: str,
in_channels: int,
dtype: torch.dtype,
device: str = "cuda",
local_files_only=False,
cache_dir: T.Optional[str] = None,
) -> T.Optional[torch.nn.Module]:
"""
Load a traced unet from the huggingface hub. This can improve performance.
"""
if device == "cpu" or device.lower().startswith("mps"):
print("WARNING: Traced UNet only available for CUDA, skipping")
return None
# Download and load the traced unet
unet_file = hf_hub_download(
checkpoint,
subfolder=subfolder,
filename=filename,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
unet_traced = torch.jit.load(unet_file)
# Wrap it in a torch module
class TracedUNet(torch.nn.Module):
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
def __init__(self):
super().__init__()
self.in_channels = in_channels
self.device = device
self.dtype = dtype
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return self.UNet2DConditionOutput(sample=sample)
return TracedUNet()
@property
def device(self) -> str:
return str(self.vae.device)
@functools.lru_cache()
def embed_text(self, text) -> torch.FloatTensor:
"""
Takes in text and turns it into text embeddings.
"""
text_input = self.tokenizer(
text,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
with torch.no_grad():
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
return embed
@functools.lru_cache()
def embed_text_weighted(self, text) -> torch.FloatTensor:
"""
Get text embedding with weights.
"""
return get_weighted_text_embeddings(
pipe=self,
prompt=text,
uncond_prompt=None,
max_embeddings_multiples=3,
no_boseos_middle=False,
skip_parsing=False,
skip_weighting=False,
)[0]
@torch.no_grad()
def riffuse(
self,
inputs: InferenceInput,
init_image: Image.Image,
mask_image: T.Optional[Image.Image] = None,
use_reweighting: bool = True,
) -> Image.Image:
"""
Runs inference using interpolation with both img2img and text conditioning.
Args:
inputs: Parameter dataclass
init_image: Image used for conditioning
mask_image: White pixels in the mask will be replaced by noise and therefore repainted,
while black pixels will be preserved. It will be converted to a single
channel (luminance) before use.
use_reweighting: Use prompt reweighting
"""
alpha = inputs.alpha
start = inputs.start
end = inputs.end
guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
# TODO(hayk): Always generate the seed on CPU?
if self.device.lower().startswith("mps"):
generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
else:
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
# Text encodings
if use_reweighting:
embed_start = self.embed_text_weighted(start.prompt)
embed_end = self.embed_text_weighted(end.prompt)
else:
embed_start = self.embed_text(start.prompt)
embed_end = self.embed_text(end.prompt)
text_embedding = embed_start + alpha * (embed_end - embed_start)
# Image latents
init_image_torch = preprocess_image(init_image).to(
device=self.device, dtype=embed_start.dtype
)
init_latent_dist = self.vae.encode(init_image_torch).latent_dist
# TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
# result is so close no matter the seed that it doesn't really add variety.
if self.device.lower().startswith("mps"):
generator = torch.Generator(device="cpu").manual_seed(start.seed)
else:
generator = torch.Generator(device=self.device).manual_seed(start.seed)
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Prepare mask latent
mask: T.Optional[torch.Tensor] = None
if mask_image:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
device=self.device, dtype=embed_start.dtype
)
outputs = self.interpolate_img2img(
text_embeddings=text_embedding,
init_latents=init_latents,
mask=mask,
generator_a=generator_start,
generator_b=generator_end,
interpolate_alpha=alpha,
strength_a=start.denoising,
strength_b=end.denoising,
num_inference_steps=inputs.num_inference_steps,
guidance_scale=guidance_scale,
)
return outputs["images"][0]
@torch.no_grad()
def interpolate_img2img(
self,
text_embeddings: torch.Tensor,
init_latents: torch.Tensor,
generator_a: torch.Generator,
generator_b: torch.Generator,
interpolate_alpha: float,
mask: T.Optional[torch.Tensor] = None,
strength_a: float = 0.8,
strength_b: float = 0.8,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
num_images_per_prompt: int = 1,
eta: T.Optional[float] = 0.0,
output_type: T.Optional[str] = "pil",
**kwargs,
):
"""
Run img2img denoising from noise that is spherically interpolated between two generators,
with classifier-free guidance and an optional mask that preserves the unmasked regions.
"""
batch_size = text_embeddings.shape[0]
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if negative_prompt is None:
uncond_tokens = [""]
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
else:
uncond_tokens = negative_prompt
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(
batch_size * num_images_per_prompt, dim=0
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents_dtype = text_embeddings.dtype
strength = (1 - interpolate_alpha) * strength_a + interpolate_alpha * strength_b
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor(
[timesteps] * batch_size * num_images_per_prompt, device=self.device
)
# add noise to latents using the timesteps
noise_a = torch.randn(
init_latents.shape, generator=generator_a, device=self.device, dtype=latents_dtype
)
noise_b = torch.randn(
init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
)
noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
init_latents_orig = init_latents
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same args
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
latents = init_latents.clone()
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if mask is not None:
init_latents_proper = self.scheduler.add_noise(
init_latents_orig, noise, torch.tensor([t])
)
latents = (init_latents_proper * mask) + (latents * (1 - mask))
latents = 1.0 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return dict(images=image, latents=latents, nsfw_content_detected=False)
def preprocess_image(image: Image.Image) -> torch.Tensor:
"""
Preprocess an image for the model.
"""
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.LANCZOS)
image_np = np.array(image).astype(np.float32) / 255.0
image_np = image_np[None].transpose(0, 3, 1, 2)
image_torch = torch.from_numpy(image_np)
return 2.0 * image_torch - 1.0
def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
"""
Preprocess a mask for the model.
"""
# Convert to grayscale
mask = mask.convert("L")
# Resize to integer multiple of 32
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h))
mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
# Convert to numpy array and rescale
mask_np = np.array(mask).astype(np.float32) / 255.0
# Tile across the 4 latent channels and add a batch dimension -> (1, 4, h, w) at latent resolution
mask_np = np.tile(mask_np, (4, 1, 1))
mask_np = mask_np[None]
# Invert to repaint white and keep black
mask_np = 1 - mask_np # repaint white, keep black
return torch.from_numpy(mask_np)
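A hedged end-to-end sketch of driving the pipeline above. The InferenceInput and PromptInput keyword arguments mirror how they are constructed elsewhere in this commit, but the exact dataclass defaults and the seed image path on disk are assumptions.

import torch
from PIL import Image
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.riffusion_pipeline import RiffusionPipeline

pipeline = RiffusionPipeline.load_checkpoint(
    checkpoint="riffusion/riffusion-model-v1",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# Hypothetical path to the bundled seed image with id "og_beat".
init_image = Image.open("seed_images/og_beat.png").convert("RGB")

inputs = InferenceInput(
    alpha=0.5,
    num_inference_steps=50,
    seed_image_id="og_beat",
    start=PromptInput(prompt="church bells", seed=42, denoising=0.75, guidance=7.0),
    end=PromptInput(prompt="jazz", seed=42, denoising=0.75, guidance=7.0),
)
spectrogram_image = pipeline.riffuse(inputs, init_image=init_image)
spectrogram_image.save("riffed_spectrogram.png")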

189
riffusion/server.py Normal file
View File

@ -0,0 +1,189 @@
"""
Flask server that serves the riffusion model as an API.
"""
import dataclasses
import io
import json
import logging
import time
import typing as T
from pathlib import Path
import dacite
import flask
import PIL
from flask_cors import CORS
from riffusion.datatypes import InferenceInput, InferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import base64_util
# Flask app with CORS
app = flask.Flask(__name__)
CORS(app)
# Log at the INFO level to both stdout and disk
logging.basicConfig(level=logging.INFO)
logging.getLogger().addHandler(logging.FileHandler("server.log"))
# Global variable for the model pipeline
PIPELINE: T.Optional[RiffusionPipeline] = None
# Where built-in seed images are stored
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
def run_app(
*,
checkpoint: str = "riffusion/riffusion-model-v1",
no_traced_unet: bool = False,
device: str = "cuda",
host: str = "127.0.0.1",
port: int = 3013,
debug: bool = False,
ssl_certificate: T.Optional[str] = None,
ssl_key: T.Optional[str] = None,
):
"""
Run a flask API that serves the given riffusion model checkpoint.
"""
# Initialize the model
global PIPELINE
PIPELINE = RiffusionPipeline.load_checkpoint(
checkpoint=checkpoint,
use_traced_unet=not no_traced_unet,
device=device,
)
args = dict(
debug=debug,
threaded=False,
host=host,
port=port,
)
if ssl_certificate:
assert ssl_key is not None
args["ssl_context"] = (ssl_certificate, ssl_key)
app.run(**args) # type: ignore
@app.route("/run_inference/", methods=["POST"])
def run_inference():
"""
Execute the riffusion model as an API.
Inputs:
Serialized JSON of the InferenceInput dataclass
Returns:
Serialized JSON of the InferenceOutput dataclass
"""
start_time = time.time()
# Parse the payload as JSON
json_data = json.loads(flask.request.data)
# Log the request
logging.info(json_data)
# Parse an InferenceInput dataclass from the payload
try:
inputs = dacite.from_dict(InferenceInput, json_data)
except dacite.exceptions.WrongTypeError as exception:
logging.info(json_data)
return str(exception), 400
except dacite.exceptions.MissingValueError as exception:
logging.info(json_data)
return str(exception), 400
response = compute_request(
inputs=inputs,
seed_images_dir=SEED_IMAGES_DIR,
pipeline=PIPELINE,
)
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")
return response
def compute_request(
inputs: InferenceInput,
pipeline: RiffusionPipeline,
seed_images_dir: str,
) -> T.Union[str, T.Tuple[str, int]]:
"""
Does all the heavy lifting of the request.
Args:
inputs: The input dataclass
pipeline: The riffusion model pipeline
seed_images_dir: The directory where seed images are stored
"""
# Load the seed image by ID
init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
if not init_image_path.is_file():
return f"Invalid seed image: {inputs.seed_image_id}", 400
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Load the mask image by ID
mask_image: T.Optional[PIL.Image.Image] = None
if inputs.mask_image_id:
mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
if not mask_image_path.is_file():
return f"Invalid mask image: {inputs.mask_image_id}", 400
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
# Execute the model to get the spectrogram image
image = pipeline.riffuse(
inputs,
init_image=init_image,
mask_image=mask_image,
)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
# Reconstruct audio from the image
# TODO(hayk): It may help performance a bit to cache this object
converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
segment = converter.audio_from_spectrogram_image(
image,
apply_filters=True,
)
# Export audio to MP3 bytes
mp3_bytes = io.BytesIO()
segment.export(mp3_bytes, format="mp3")
mp3_bytes.seek(0)
# Export image to JPEG bytes
image_bytes = io.BytesIO()
image.save(image_bytes, exif=image.getexif(), format="JPEG")
image_bytes.seek(0)
# Assemble the output dataclass
output = InferenceOutput(
image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
duration_s=segment.duration_seconds,
)
return json.dumps(dataclasses.asdict(output))
if __name__ == "__main__":
import argh
argh.dispatch_command(run_app)
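A hedged client sketch for the endpoint above, assuming the requests package is available. The JSON layout mirrors the InferenceInput and PromptInput usage in this commit; any fields or defaults beyond those are assumptions.

import json
import requests

payload = {
    "alpha": 0.5,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}
resp = requests.post("http://127.0.0.1:3013/run_inference/", data=json.dumps(payload))
resp.raise_for_status()
output = resp.json()
print(output["duration_s"])  # image and audio are returned as base64 data URIs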

View File

@ -0,0 +1,204 @@
import warnings
import numpy as np
import pydub
import torch
import torchaudio
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import audio_util, torch_util
class SpectrogramConverter:
"""
Convert between audio segments and spectrogram tensors using torchaudio.
In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
that represent the amplitude of the frequency at that time bucket (in the frequency domain).
Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
used in some functions is "mel amplitudes".
The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
returns the amplitude, because the phase is chaotic and hard to learn. The function
`audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
approximates the phase information using the Griffin-Lim algorithm.
Each channel in the audio is treated independently, and the spectrogram has a batch dimension
equal to the number of channels in the input audio segment.
Both the Griffin Lim algorithm and the Mel scaling process are lossy.
For more information, see https://pytorch.org/audio/stable/transforms.html
"""
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
self.p = params
self.device = torch_util.check_device(device)
if device.lower().startswith("mps"):
warnings.warn(
"WARNING: MPS does not support audio operations, falling back to CPU for them",
stacklevel=2,
)
self.device = "cpu"
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
self.spectrogram_func = torchaudio.transforms.Spectrogram(
n_fft=params.n_fft,
hop_length=params.hop_length,
win_length=params.win_length,
pad=0,
window_fn=torch.hann_window,
power=None,
normalized=False,
wkwargs=None,
center=True,
pad_mode="reflect",
onesided=True,
).to(self.device)
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
n_fft=params.n_fft,
n_iter=params.num_griffin_lim_iters,
win_length=params.win_length,
hop_length=params.hop_length,
window_fn=torch.hann_window,
power=1.0,
wkwargs=None,
momentum=0.99,
length=None,
rand_init=True,
).to(self.device)
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
self.mel_scaler = torchaudio.transforms.MelScale(
n_mels=params.num_frequencies,
sample_rate=params.sample_rate,
f_min=params.min_frequency,
f_max=params.max_frequency,
n_stft=params.n_fft // 2 + 1,
norm=params.mel_scale_norm,
mel_scale=params.mel_scale_type,
).to(self.device)
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
n_stft=params.n_fft // 2 + 1,
n_mels=params.num_frequencies,
sample_rate=params.sample_rate,
f_min=params.min_frequency,
f_max=params.max_frequency,
max_iter=params.max_mel_iters,
tolerance_loss=1e-5,
tolerance_change=1e-8,
sgdargs=None,
norm=params.mel_scale_norm,
mel_scale=params.mel_scale_type,
).to(self.device)
def spectrogram_from_audio(
self,
audio: pydub.AudioSegment,
) -> np.ndarray:
"""
Compute a spectrogram from an audio segment.
Args:
audio: Audio segment which must match the sample rate of the params
Returns:
spectrogram: (channel, frequency, time)
"""
assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
# Get the samples as a numpy array in (batch, samples) shape
waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
# Convert to floats if necessary
if waveform.dtype != np.float32:
waveform = waveform.astype(np.float32)
waveform_tensor = torch.from_numpy(waveform).to(self.device)
amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
return amplitudes_mel.cpu().numpy()
def audio_from_spectrogram(
self,
spectrogram: np.ndarray,
apply_filters: bool = True,
) -> pydub.AudioSegment:
"""
Reconstruct an audio segment from a spectrogram.
Args:
spectrogram: (batch, frequency, time)
apply_filters: Post-process with normalization and compression
Returns:
audio: Audio segment with channels equal to the batch dimension
"""
# Move to device
amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
# Reconstruct the waveform
waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
# Convert to audio segment
segment = audio_util.audio_from_waveform(
samples=waveform.cpu().numpy(),
sample_rate=self.p.sample_rate,
# Normalize the waveform to the range [-1, 1]
normalize=True,
)
# Optionally apply post-processing filters
if apply_filters:
segment = audio_util.apply_filters(
segment,
compression=False,
)
return segment
def mel_amplitudes_from_waveform(
self,
waveform: torch.Tensor,
) -> torch.Tensor:
"""
Torch-only function to compute Mel-scale amplitudes from a waveform.
Args:
waveform: (batch, samples)
Returns:
amplitudes_mel: (batch, frequency, time)
"""
# Compute the complex-valued spectrogram
spectrogram_complex = self.spectrogram_func(waveform)
# Take the magnitude
amplitudes = torch.abs(spectrogram_complex)
# Convert to mel scale
return self.mel_scaler(amplitudes)
def waveform_from_mel_amplitudes(
self,
amplitudes_mel: torch.Tensor,
) -> torch.Tensor:
"""
Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
Args:
amplitudes_mel: (batch, frequency, time)
Returns:
waveform: (batch, samples)
"""
# Convert from mel scale to linear
amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
# Run the approximate algorithm to compute the phase and recover the waveform
return self.inverse_spectrogram_func(amplitudes_linear)
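A hedged round-trip sketch for the converter above: audio to Mel amplitudes and back via Griffin-Lim. The input path is hypothetical and the file is assumed to be resampled to the params' sample rate.

import pydub
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams

params = SpectrogramParams()
converter = SpectrogramConverter(params=params, device="cuda")  # or "cpu"

segment = pydub.AudioSegment.from_file("clip.wav").set_frame_rate(params.sample_rate)
spectrogram = converter.spectrogram_from_audio(segment)  # (channels, frequency, time)
reconstructed = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
reconstructed.export("clip_reconstructed.wav", format="wav")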

View File

@ -0,0 +1,91 @@
import numpy as np
import pydub
from PIL import Image
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
class SpectrogramImageConverter:
"""
Convert between spectrogram images and audio segments.
This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
to images and back. The real audio processing lives in SpectrogramConverter.
"""
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
self.p = params
self.device = device
self.converter = SpectrogramConverter(params=params, device=device)
def spectrogram_image_from_audio(
self,
segment: pydub.AudioSegment,
) -> Image.Image:
"""
Compute a spectrogram image from an audio segment.
Args:
segment: Audio segment to convert
Returns:
Spectrogram image (in pillow format)
"""
assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"
if self.p.stereo:
if segment.channels == 1:
print("WARNING: Mono audio but stereo=True, cloning channel")
segment = segment.set_channels(2)
elif segment.channels > 2:
print("WARNING: Multi channel audio, reducing to stereo")
segment = segment.set_channels(2)
else:
if segment.channels > 1:
print("WARNING: Stereo audio but stereo=False, setting to mono")
segment = segment.set_channels(1)
spectrogram = self.converter.spectrogram_from_audio(segment)
image = image_util.image_from_spectrogram(
spectrogram,
power=self.p.power_for_image,
)
# Store conversion params in exif metadata of the image
exif_data = self.p.to_exif()
exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
exif = image.getexif()
exif.update(exif_data.items())
return image
def audio_from_spectrogram_image(
self,
image: Image.Image,
apply_filters: bool = True,
max_value: float = 30e6,
) -> pydub.AudioSegment:
"""
Reconstruct an audio segment from a spectrogram image.
Args:
image: Spectrogram image (in pillow format)
apply_filters: Apply post-processing to improve the reconstructed audio
max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
"""
spectrogram = image_util.spectrogram_from_image(
image,
max_value=max_value,
power=self.p.power_for_image,
stereo=self.p.stereo,
)
segment = self.converter.audio_from_spectrogram(
spectrogram,
apply_filters=apply_filters,
)
return segment
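The image-level round trip, sketched under the same assumptions (hypothetical file paths, 44.1 kHz input): the spectrogram image carries its parameters in EXIF, mirroring how the server saves JPEGs.

import pydub
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

params = SpectrogramParams(stereo=False)
converter = SpectrogramImageConverter(params=params, device="cuda")  # or "cpu"

segment = pydub.AudioSegment.from_file("clip.wav").set_frame_rate(params.sample_rate)
image = converter.spectrogram_image_from_audio(segment)
image.save("clip_spectrogram.jpg", exif=image.getexif(), format="JPEG")

roundtrip = converter.audio_from_spectrogram_image(image, apply_filters=True)
roundtrip.export("clip_roundtrip.wav", format="wav")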

View File

@ -0,0 +1,115 @@
from __future__ import annotations
import typing as T
from dataclasses import dataclass
from enum import Enum
@dataclass(frozen=True)
class SpectrogramParams:
"""
Parameters for the conversion from audio to spectrograms to images and back.
Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
within spectrogram images.
To understand what these parameters do and to customize them, read `spectrogram_converter.py`
and the linked torchaudio documentation.
"""
# Whether the audio is stereo or mono
stereo: bool = False
# FFT parameters
sample_rate: int = 44100
step_size_ms: int = 10
window_duration_ms: int = 100
padded_duration_ms: int = 400
# Mel scale parameters
num_frequencies: int = 512
# TODO(hayk): Set these to [20, 20000] for newer models
min_frequency: int = 0
max_frequency: int = 10000
mel_scale_norm: T.Optional[str] = None
mel_scale_type: str = "htk"
max_mel_iters: int = 200
# Griffin Lim parameters
num_griffin_lim_iters: int = 32
# Image parameterization
power_for_image: float = 0.25
class ExifTags(Enum):
"""
Custom EXIF tags for the spectrogram image.
"""
SAMPLE_RATE = 11000
STEREO = 11005
STEP_SIZE_MS = 11010
WINDOW_DURATION_MS = 11020
PADDED_DURATION_MS = 11030
NUM_FREQUENCIES = 11040
MIN_FREQUENCY = 11050
MAX_FREQUENCY = 11060
POWER_FOR_IMAGE = 11070
MAX_VALUE = 11080
@property
def n_fft(self) -> int:
"""
The number of samples in each STFT window, with padding.
"""
return int(self.padded_duration_ms / 1000.0 * self.sample_rate)
@property
def win_length(self) -> int:
"""
The number of samples in each STFT window.
"""
return int(self.window_duration_ms / 1000.0 * self.sample_rate)
@property
def hop_length(self) -> int:
"""
The number of samples between each STFT window.
"""
return int(self.step_size_ms / 1000.0 * self.sample_rate)
def to_exif(self) -> T.Dict[int, T.Any]:
"""
Return a dictionary of EXIF tags for the current values.
"""
return {
self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
self.ExifTags.STEREO.value: self.stereo,
self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
}
@classmethod
def from_exif(cls, exif: T.Mapping[int, T.Any]) -> SpectrogramParams:
"""
Create a SpectrogramParams object from the EXIF tags of the given image.
"""
# TODO(hayk): Handle missing tags
return cls(
sample_rate=exif[cls.ExifTags.SAMPLE_RATE.value],
stereo=bool(exif[cls.ExifTags.STEREO.value]),
step_size_ms=exif[cls.ExifTags.STEP_SIZE_MS.value],
window_duration_ms=exif[cls.ExifTags.WINDOW_DURATION_MS.value],
padded_duration_ms=exif[cls.ExifTags.PADDED_DURATION_MS.value],
num_frequencies=exif[cls.ExifTags.NUM_FREQUENCIES.value],
min_frequency=exif[cls.ExifTags.MIN_FREQUENCY.value],
max_frequency=exif[cls.ExifTags.MAX_FREQUENCY.value],
power_for_image=exif[cls.ExifTags.POWER_FOR_IMAGE.value],
)
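A hedged worked example of the derived STFT sizes for the default parameters above: 400 ms of padding at 44100 Hz gives n_fft = 17640 samples, the 100 ms window gives win_length = 4410, and the 10 ms step gives hop_length = 441.

from riffusion.spectrogram_params import SpectrogramParams

params = SpectrogramParams()
assert params.n_fft == int(0.400 * 44100) == 17640
assert params.win_length == int(0.100 * 44100) == 4410
assert params.hop_length == int(0.010 * 44100) == 441
print(params.n_fft, params.win_length, params.hop_length)  # 17640 4410 441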

View File

@ -0,0 +1,3 @@
# streamlit
This package is an interactive streamlit app for riffusion.

View File

View File

@ -0,0 +1,404 @@
import io
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
from PIL import Image
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
from riffusion.util import audio_util
def render_audio_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":wave: Audio to Audio")
st.write(
"""
Modify existing audio from a text prompt or interpolate between two.
"""
)
with st.expander("Help", False):
st.write(
"""
This tool allows you to upload an audio file of arbitrary length and modify it with
a text prompt. It does this by sweeping over the audio in overlapping clips, doing
img2img style transfer with riffusion, then stitching the clips back together with
cross fading to eliminate seams.
Try a denoising strength of 0.4 for light modification and 0.55 for heavier
modification. The best denoising value depends on how different the prompt is
from the source audio. You can play with the seed to get infinite variations.
Currently the same seed is used for all clips along the track.
If the Interpolation checkbox is enabled, you can enter two sets of prompt, seed,
and denoising values, and the tool smoothly blends between them along the selected
duration of the audio. This is a great way to create a transition.
"""
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
lora_path = st.sidebar.text_input("Lora Path", "")
lora_scale = st.sidebar.number_input("Lora Scale", value=1.0)
with st.sidebar:
num_inference_steps = T.cast(
int,
st.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
guidance = st.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)
scheduler = st.selectbox(
"Scheduler",
options=streamlit_util.SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None
audio_file = st.file_uploader(
"Upload audio",
type=streamlit_util.AUDIO_EXTENSIONS,
label_visibility="collapsed",
)
if not audio_file:
st.info("Upload audio to get started")
return
st.write("#### Original")
st.audio(audio_file)
segment = streamlit_util.load_audio_file(audio_file)
# TODO(hayk): Fix
if segment.frame_rate != 44100:
st.warning("Audio must be 44100Hz. Converting")
segment = segment.set_frame_rate(44100)
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
clip_p = get_clip_params()
start_time_s = clip_p["start_time_s"]
clip_duration_s = clip_p["clip_duration_s"]
overlap_duration_s = clip_p["overlap_duration_s"]
duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s)
increment_s = clip_duration_s - overlap_duration_s
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
write_clip_details(
clip_start_times=clip_start_times,
clip_duration_s=clip_duration_s,
overlap_duration_s=overlap_duration_s,
)
interpolate = st.checkbox(
"Interpolate between two endpoints",
value=False,
help="Interpolate between two prompts, seeds, or denoising values along the"
"duration of the segment",
)
counter = streamlit_util.StreamlitCounter()
with st.form("audio to audio form"):
if interpolate:
left, right = st.columns(2)
with left:
st.write("##### Prompt A")
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
with right:
st.write("##### Prompt B")
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
elif use_magic_mix:
prompt = st.text_input("Prompt", key="prompt_a")
row = st.columns(4)
seed = T.cast(
int,
row[0].number_input(
"Seed",
value=42,
key="seed_a",
),
)
prompt_input_a = PromptInput(
prompt=prompt,
seed=seed,
guidance=guidance,
)
magic_mix_kmin = row[1].number_input("Kmin", value=0.3)
magic_mix_kmax = row[2].number_input("Kmax", value=0.5)
magic_mix_mix_factor = row[3].number_input("Mix Factor", value=0.5)
else:
prompt_input_a = PromptInput(
guidance=guidance,
**get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
)
st.form_submit_button("Riff", type="primary", on_click=counter.increment)
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
show_difference = st.sidebar.checkbox("Show Difference", False)
clip_segments = slice_audio_into_clips(
segment=segment,
clip_start_times=clip_start_times,
clip_duration_s=clip_duration_s,
)
if not prompt_input_a.prompt:
st.info("Enter a prompt")
return
if counter.value == 0:
return
params = SpectrogramParams()
if interpolate:
# TODO(hayk): Make not linspace
alphas = list(np.linspace(0, 1, len(clip_segments)))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]")
result_images: T.List[Image.Image] = []
result_segments: T.List[pydub.AudioSegment] = []
for i, clip_segment in enumerate(clip_segments):
st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")
audio_bytes = io.BytesIO()
clip_segment.export(audio_bytes, format="wav")
init_image = streamlit_util.spectrogram_image_from_audio(
clip_segment,
params=params,
device=device,
)
# TODO(hayk): Roll this into spectrogram_image_from_audio?
init_image_resized = scale_image_to_32_stride(init_image)
progress_callback = None
if show_clip_details:
left, right = st.columns(2)
left.write("##### Source Clip")
left.image(init_image, use_column_width=False)
left.audio(audio_bytes)
right.write("##### Riffed Clip")
empty_bin = right.empty()
with empty_bin.container():
st.info("Riffing...")
progress = st.progress(0.0)
progress_callback = progress.progress
if interpolate:
assert use_magic_mix is False, "Cannot use magic mix and interpolate together"
inputs = InferenceInput(
alpha=float(alphas[i]),
num_inference_steps=num_inference_steps,
seed_image_id="og_beat",
start=prompt_input_a,
end=prompt_input_b,
)
image, audio_bytes = run_interpolation(
inputs=inputs,
init_image=init_image_resized,
device=device,
)
elif use_magic_mix:
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
image = streamlit_util.run_img2img_magic_mix(
prompt=prompt_input_a.prompt,
init_image=init_image_resized,
num_inference_steps=num_inference_steps,
guidance_scale=guidance,
seed=prompt_input_a.seed,
kmin=magic_mix_kmin,
kmax=magic_mix_kmax,
mix_factor=magic_mix_mix_factor,
device=device,
scheduler=scheduler,
)
else:
image = streamlit_util.run_img2img(
prompt=prompt_input_a.prompt,
init_image=init_image_resized,
denoising_strength=prompt_input_a.denoising,
num_inference_steps=num_inference_steps,
guidance_scale=guidance,
negative_prompt=prompt_input_a.negative_prompt,
seed=prompt_input_a.seed,
progress_callback=progress_callback,
device=device,
scheduler=scheduler,
lora_path=lora_path,
lora_scale=lora_scale,
)
# Resize back to original size
image = image.resize(init_image.size, Image.BICUBIC)
result_images.append(image)
if show_clip_details:
empty_bin.empty()
right.image(image, use_column_width=False)
riffed_segment = streamlit_util.audio_segment_from_spectrogram_image(
image=image,
params=params,
device=device,
)
result_segments.append(riffed_segment)
audio_bytes = io.BytesIO()
riffed_segment.export(audio_bytes, format="wav")
if show_clip_details:
right.audio(audio_bytes)
if show_clip_details and show_difference:
diff_np = np.maximum(
0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32)
)
diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
image=diff_image,
params=params,
device=device,
)
audio_bytes = io.BytesIO()
diff_segment.export(audio_bytes, format=extension)
st.audio(audio_bytes)
# Combine clips with a crossfade based on overlap
combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
input_name = Path(audio_file.name).stem
output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)
def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
"""
Render the parameters of slicing audio into clips.
"""
p: T.Dict[str, T.Any] = {}
cols = st.columns(4)
p["start_time_s"] = cols[0].number_input(
"Start Time [s]",
min_value=0.0,
value=0.0,
)
p["duration_s"] = cols[1].number_input(
"Duration [s]",
min_value=0.0,
value=15.0,
)
if advanced:
p["clip_duration_s"] = cols[2].number_input(
"Clip Duration [s]",
min_value=3.0,
max_value=10.0,
value=5.0,
)
else:
p["clip_duration_s"] = 5.0
if advanced:
p["overlap_duration_s"] = cols[3].number_input(
"Overlap Duration [s]",
min_value=0.0,
max_value=10.0,
value=0.2,
)
else:
p["overlap_duration_s"] = 0.2
return p
def write_clip_details(
clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float
):
"""
Write details of the clips to be sliced from an audio segment.
"""
clip_details_text = (
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
f"with overlap {overlap_duration_s}s"
)
with st.expander(clip_details_text):
st.dataframe(
{
"Start Time [s]": clip_start_times,
"End Time [s]": clip_start_times + clip_duration_s,
"Duration [s]": clip_duration_s,
}
)
def slice_audio_into_clips(
segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float
) -> T.List[pydub.AudioSegment]:
"""
Slice an audio segment into a list of clips of a given duration at the given start times.
"""
clip_segments: T.List[pydub.AudioSegment] = []
for i, clip_start_time_s in enumerate(clip_start_times):
clip_start_time_ms = int(clip_start_time_s * 1000)
clip_duration_ms = int(clip_duration_s * 1000)
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
# TODO(hayk): I don't think this is working properly
if i == len(clip_start_times) - 1:
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
if silence_ms > 0:
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
clip_segments.append(clip_segment)
return clip_segments
def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
"""
Scale an image to a size that is a multiple of 32.
"""
closest_width = int(np.ceil(image.width / 32) * 32)
closest_height = int(np.ceil(image.height / 32) * 32)
return image.resize((closest_width, closest_height), Image.BICUBIC)
if __name__ == "__main__":
render_audio_to_audio()
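A hedged worked example of the clip slicing arithmetic used above: with the default 5 s clips and 0.2 s overlap over 15 s of audio, the increment is 4.8 s, so clips start at 0.0, 4.8 and 9.6 seconds.

import numpy as np

start_time_s, duration_s = 0.0, 15.0
clip_duration_s, overlap_duration_s = 5.0, 0.2
increment_s = clip_duration_s - overlap_duration_s
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
print(clip_start_times)  # [0.  4.8 9.6]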

View File

@ -0,0 +1,74 @@
import dataclasses
from pathlib import Path
import streamlit as st
from PIL import Image
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.util.image_util import exif_from_image
def render_image_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":musical_keyboard: Image to Audio")
st.write(
"""
Reconstruct audio from spectrogram images.
"""
)
with st.expander("Help", False):
st.write(
"""
This tool takes an existing spectrogram image and reconstructs it into an audio
waveform. It also displays the EXIF metadata stored inside the image, which can
            contain the parameters used to create the spectrogram image. If no EXIF data is
            present, default parameters are assumed.
"""
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
image_file = st.file_uploader(
"Upload a file",
type=streamlit_util.IMAGE_EXTENSIONS,
label_visibility="collapsed",
)
if not image_file:
st.info("Upload an image file to get started")
return
image = Image.open(image_file)
st.image(image)
with st.expander("Image metadata", expanded=False):
exif = exif_from_image(image)
st.json(exif)
try:
params = SpectrogramParams.from_exif(exif=image.getexif())
except KeyError:
st.info("Could not find spectrogram parameters in exif data. Using defaults.")
params = SpectrogramParams()
with st.expander("Spectrogram Parameters", expanded=False):
st.json(dataclasses.asdict(params))
segment = streamlit_util.audio_segment_from_spectrogram_image(
image=image.copy(),
params=params,
device=device,
)
streamlit_util.display_and_download_audio(
segment,
name=Path(image_file.name).stem,
extension=extension,
)
if __name__ == "__main__":
render_image_to_audio()

View File

@ -0,0 +1,273 @@
import dataclasses
import io
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
from PIL import Image
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_interpolation() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":performing_arts: Interpolation")
st.write(
"""
Interpolate between prompts in the latent space.
"""
)
with st.expander("Help", False):
st.write(
"""
This tool allows specifying two endpoints and generating a long-form interpolation
between them that traverses the latent space. The interpolation is generated by
the method described at https://www.riffusion.com/about. A seed image is used to
set the beat and tempo of the generated audio, and can be set in the sidebar.
            Usually you change either the seed or the prompt, but not both at once. You can
            browse infinite variations of the same prompt by changing the seed.
For example, try going from "church bells" to "jazz" with 10 steps and 0.75 denoising.
This will generate a 50 second clip at 5 seconds per step. Then play with the seeds
or denoising to get different variations.
"""
)
# Sidebar params
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
num_interpolation_steps = T.cast(
int,
st.sidebar.number_input(
"Interpolation steps",
value=4,
min_value=1,
max_value=20,
help="Number of model generations between the two prompts. Controls the duration.",
),
)
num_inference_steps = T.cast(
int,
st.sidebar.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
guidance = st.sidebar.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)
init_image_name = st.sidebar.selectbox(
"Seed image",
# TODO(hayk): Read from directory
options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
index=0,
help="Which seed image to use for img2img. Custom allows uploading your own.",
)
assert init_image_name is not None
if init_image_name == "custom":
init_image_file = st.sidebar.file_uploader(
"Upload a custom seed image",
type=streamlit_util.IMAGE_EXTENSIONS,
label_visibility="collapsed",
)
if init_image_file:
st.sidebar.image(init_image_file)
show_individual_outputs = st.sidebar.checkbox(
"Show individual outputs",
value=False,
help="Show each model output",
)
show_images = st.sidebar.checkbox(
"Show individual images",
value=False,
help="Show each generated image",
)
# Prompt inputs A and B in two columns
with st.form(key="interpolation_form"):
left, right = st.columns(2)
with left:
st.write("##### Prompt A")
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
with right:
st.write("##### Prompt B")
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
st.form_submit_button("Generate", type="primary")
if not prompt_input_a.prompt or not prompt_input_b.prompt:
st.info("Enter both prompts to interpolate between them")
return
alphas = list(np.linspace(0, 1, num_interpolation_steps))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]")
# TODO(hayk): Apply scaling to alphas like this
# T_shifted = T * 2 - 1
# T_sample = (np.abs(T_shifted)**t_scale_power * np.sign(T_shifted) + 1) / 2
# T_sample = T_sample * (t_end - t_start) + t_start
if init_image_name == "custom":
if not init_image_file:
st.info("Upload a custom seed image")
return
init_image = Image.open(init_image_file).convert("RGB")
else:
init_image_path = (
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
)
init_image = Image.open(str(init_image_path)).convert("RGB")
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
image_list: T.List[Image.Image] = []
audio_bytes_list: T.List[io.BytesIO] = []
for i, alpha in enumerate(alphas):
inputs = InferenceInput(
alpha=float(alpha),
num_inference_steps=num_inference_steps,
seed_image_id="og_beat",
start=prompt_input_a,
end=prompt_input_b,
)
if i == 0:
with st.expander("Example input JSON", expanded=False):
st.json(dataclasses.asdict(inputs))
image, audio_bytes = run_interpolation(
inputs=inputs,
init_image=init_image,
device=device,
extension=extension,
)
if show_individual_outputs:
st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
if show_images:
st.image(image)
st.audio(audio_bytes)
image_list.append(image)
audio_bytes_list.append(audio_bytes)
st.write("#### Final Output")
# TODO(hayk): Concatenate with overlap and better blending like in audio to audio
audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
concat_segment = audio_segments[0]
for segment in audio_segments[1:]:
concat_segment = concat_segment.append(segment, crossfade=0)
audio_bytes = io.BytesIO()
concat_segment.export(audio_bytes, format=extension)
audio_bytes.seek(0)
st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
st.audio(audio_bytes)
output_name = (
f"{prompt_input_a.prompt.replace(' ', '_')}_"
f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
)
st.download_button(
output_name,
data=audio_bytes,
file_name=output_name,
mime=f"audio/{extension}",
)
def get_prompt_inputs(
key: str,
include_negative_prompt: bool = False,
cols: bool = False,
) -> T.Dict[str, T.Any]:
"""
Compute prompt inputs from widgets.
"""
p: T.Dict[str, T.Any] = {}
# Optionally use columns
left, right = T.cast(T.Any, st.columns(2) if cols else (st, st))
visibility = "visible" if cols else "collapsed"
p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}")
if include_negative_prompt:
p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}")
p["seed"] = T.cast(
int,
left.number_input(
"Seed",
value=42,
key=f"seed_{key}",
help="Integer used to generate a random result. Vary this to explore alternatives.",
),
)
p["denoising"] = right.number_input(
"Denoising",
value=0.5,
key=f"denoising_{key}",
help="How much to modify the seed image",
)
return p
@st.experimental_memo
def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
) -> T.Tuple[Image.Image, io.BytesIO]:
"""
Cached function for riffusion interpolation.
"""
pipeline = streamlit_util.load_riffusion_checkpoint(
device=device,
# No trace so we can have variable width
no_traced_unet=True,
)
image = pipeline.riffuse(
inputs,
init_image=init_image,
mask_image=None,
)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
# Reconstruct from image to audio
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=params,
device=device,
output_format=extension,
)
return image, audio_bytes
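
# Hedged sketch of the alpha scaling mentioned in the TODO inside render_interpolation:
# it warps evenly spaced alphas toward or away from the endpoints and maps them into
# [t_start, t_end]. The parameter names follow the commented-out pseudocode above and are
# assumptions, not an API used elsewhere in this module.
def scale_alphas(
    alphas: np.ndarray,
    t_scale_power: float = 1.0,
    t_start: float = 0.0,
    t_end: float = 1.0,
) -> np.ndarray:
    shifted = alphas * 2 - 1
    scaled = (np.abs(shifted) ** t_scale_power * np.sign(shifted) + 1) / 2
    return scaled * (t_end - t_start) + t_start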
if __name__ == "__main__":
render_interpolation()

View File

@ -0,0 +1,131 @@
import tempfile
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_sample_clips() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":paperclip: Sample Clips")
st.write(
"""
Export short clips from an audio file.
"""
)
with st.expander("Help", False):
st.write(
"""
            This tool lets you upload an audio file and randomly sample short clips from it.
            It's useful for generating a large number of short clips from a single source
            file. Outputs can be saved to a given directory with a given audio extension.
"""
)
audio_file = st.file_uploader(
"Upload a file",
type=streamlit_util.AUDIO_EXTENSIONS,
label_visibility="collapsed",
)
if not audio_file:
st.info("Upload an audio file to get started")
return
st.audio(audio_file)
segment = pydub.AudioSegment.from_file(audio_file)
st.write(
" \n".join(
[
f"**Duration**: {segment.duration_seconds:.3f} seconds",
f"**Channels**: {segment.channels}",
f"**Sample rate**: {segment.frame_rate} Hz",
f"**Sample width**: {segment.sample_width} bytes",
]
)
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
save_to_disk = st.sidebar.checkbox("Save to Disk", False)
export_as_mono = st.sidebar.checkbox("Export as Mono", False)
compute_spectrograms = st.sidebar.checkbox("Compute Spectrograms", False)
row = st.columns(4)
num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
seed = T.cast(int, row[2].number_input("Seed", value=42))
counter = streamlit_util.StreamlitCounter()
st.button("Sample Clips", type="primary", on_click=counter.increment)
if counter.value == 0:
return
# Optionally pick an output directory
if save_to_disk:
output_dir = tempfile.mkdtemp(prefix="sample_clips_")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
st.info(f"Output directory: `{output_dir}`")
if compute_spectrograms:
images_dir = output_path / "images"
images_dir.mkdir(parents=True, exist_ok=True)
if seed >= 0:
np.random.seed(seed)
if export_as_mono and segment.channels > 1:
segment = segment.set_channels(1)
if save_to_disk:
st.info(f"Writing {num_clips} clip(s) to `{str(output_path)}`")
# TODO(hayk): Share code with riffusion.cli.sample_clips.
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"
st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
streamlit_util.display_and_download_audio(
clip,
name=clip_name,
extension=extension,
)
if save_to_disk:
clip_path = output_path / f"{clip_name}.{extension}"
clip.export(clip_path, format=extension)
if compute_spectrograms:
params = SpectrogramParams()
image = streamlit_util.spectrogram_image_from_audio(
clip,
params=params,
device=device,
)
st.image(image)
if save_to_disk:
image_path = images_dir / f"{clip_name}.jpeg"
image.save(image_path)
if save_to_disk:
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
if __name__ == "__main__":
render_sample_clips()

View File

@ -0,0 +1,105 @@
import typing as T
from pathlib import Path
import pydub
import streamlit as st
from riffusion.audio_splitter import split_audio
from riffusion.streamlit import util as streamlit_util
from riffusion.util import audio_util
def render_split_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":scissors: Audio Splitter")
st.write(
"""
Split audio into individual instrument stems.
"""
)
with st.expander("Help", False):
st.write(
"""
            This tool takes an uploaded audio file of arbitrary length and splits it into
            stems of vocals, drums, bass, and other. It does this using a deep network that
            sweeps over the audio in clips, extracts the stems, and then crossfades the clips
            back together to construct the full-length stems. It's particularly useful in
combination with audio_to_audio, for example to split and preserve vocals while
modifying the rest of the track with a prompt. Or, to pull out drums to add later
in a DAW.
"""
)
device = streamlit_util.select_device(st.sidebar)
extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
extension = st.sidebar.selectbox(
"Output format",
options=extension_options,
index=extension_options.index("mp3"),
)
assert extension is not None
audio_file = st.file_uploader(
"Upload audio",
type=extension_options,
label_visibility="collapsed",
)
stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
recombine = st.sidebar.multiselect(
"Recombine",
options=stem_options,
default=[],
help="Recombine these stems at the end",
)
if not audio_file:
st.info("Upload audio to get started")
return
st.write("#### Original")
st.audio(audio_file)
counter = streamlit_util.StreamlitCounter()
st.button("Split", type="primary", on_click=counter.increment)
if counter.value == 0:
return
segment = streamlit_util.load_audio_file(audio_file)
# Split
stems = split_audio_cached(segment, device=device)
input_name = Path(audio_file.name).stem
# Display each
for name in stem_options:
stem = stems[name.lower()]
st.write(f"#### Stem: {name}")
output_name = f"{input_name}_{name.lower()}"
streamlit_util.display_and_download_audio(stem, output_name, extension=extension)
if recombine:
recombine_lower = [r.lower() for r in recombine]
segments = [s for name, s in stems.items() if name in recombine_lower]
recombined = audio_util.overlay_segments(segments)
# Display
st.write(f"#### Recombined: {', '.join(recombine)}")
output_name = f"{input_name}_{'_'.join(recombine_lower)}"
streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
@st.cache
def split_audio_cached(
segment: pydub.AudioSegment, device: str = "cuda"
) -> T.Dict[str, pydub.AudioSegment]:
return split_audio(segment, device=device)
if __name__ == "__main__":
render_split_audio()

View File

@ -0,0 +1,118 @@
import typing as T
import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_text_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":pencil2: Text to Audio")
st.write(
"""
Generate audio from text prompts.
"""
)
with st.expander("Help", False):
st.write(
"""
This tool runs riffusion in the simplest text to image form to generate an audio
clip from a text prompt. There is no seed image or interpolation here. This mode
            allows for more diversity and creativity than using a seed image, but it also
            gives you less control. Play with the seed to get infinite variations.
"""
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
lora_path = st.sidebar.text_input("Lora Path", "")
lora_scale = st.sidebar.number_input("Lora Scale", value=1.0)
with st.form("Inputs"):
prompt = st.text_input("Prompt")
negative_prompt = st.text_input("Negative prompt")
row = st.columns(4)
num_clips = T.cast(
int,
row[0].number_input(
"Number of clips",
value=1,
min_value=1,
max_value=25,
help="How many outputs to generate (seed gets incremented)",
),
)
starting_seed = T.cast(
int,
row[1].number_input(
"Seed",
value=42,
help="Change this to generate different variations",
),
)
st.form_submit_button("Riff", type="primary")
with st.sidebar:
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
width = T.cast(int, st.number_input("Width", value=512))
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
)
scheduler = st.selectbox(
"Scheduler",
options=streamlit_util.SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None
if not prompt:
st.info("Enter a prompt")
return
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
seed = starting_seed
for i in range(1, num_clips + 1):
st.write(f"#### Riff {i} / {num_clips} - Seed {seed}")
image = streamlit_util.run_txt2img(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance=guidance,
negative_prompt=negative_prompt,
seed=seed,
width=width,
height=512,
device=device,
scheduler=scheduler,
lora_path=lora_path,
lora_scale=lora_scale,
)
st.image(image)
segment = streamlit_util.audio_segment_from_spectrogram_image(
image=image,
params=params,
device=device,
)
streamlit_util.display_and_download_audio(
segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
)
seed += 1
if __name__ == "__main__":
render_text_to_audio()

View File

@ -0,0 +1,147 @@
import json
import typing as T
from pathlib import Path
import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
# Example input json file to process in batch
EXAMPLE_INPUT = """
{
    "params": {
        "seed": 42,
        "num_inference_steps": 50,
        "guidance": 7.0,
        "width": 512
    },
"entries": [
{
"prompt": "Church bells"
},
{
"prompt": "electronic beats",
"negative_prompt": "drums"
},
{
"prompt": "classical violin concerto"
}
]
}
"""
def render_text_to_audio_batch() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":scroll: Text to Audio Batch")
st.write(
"""
Generate audio in batch from a JSON file of text prompts.
"""
)
with st.expander("Help", False):
st.write(
"""
This tool is a batch form of text_to_audio, where the inputs are read in from a JSON
file. The input file contains a global params block and a list of entries with positive
and negative prompts. It's useful for automating a larger set of generations. See the
example inputs below for the format of the file.
"""
)
device = streamlit_util.select_device(st.sidebar)
# Upload a JSON file
json_file = st.file_uploader(
"JSON file",
type=["json"],
label_visibility="collapsed",
)
# Handle the null case
if json_file is None:
st.info("Upload a JSON file containing params and prompts")
with st.expander("Example inputs.json", expanded=False):
st.code(EXAMPLE_INPUT)
return
# Read in and print it
data = json.loads(json_file.read())
with st.expander("Input Data", expanded=False):
st.json(data)
params = data["params"]
entries = data["entries"]
show_images = st.sidebar.checkbox("Show Images", False)
# Optionally specify an output directory
output_dir = st.sidebar.text_input("Output Directory", "")
output_path: T.Optional[Path] = None
if output_dir:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
for i, entry in enumerate(entries):
st.write(f"#### Entry {i + 1} / {len(entries)}")
negative_prompt = entry.get("negative_prompt", None)
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
image = streamlit_util.run_txt2img(
prompt=entry["prompt"],
negative_prompt=negative_prompt,
seed=params.get("seed", 42),
num_inference_steps=params.get("num_inference_steps", 50),
guidance=params.get("guidance", 7.0),
width=params.get("width", 512),
height=512,
device=device,
)
if show_images:
st.image(image)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
p_spectrogram = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
output_format = "mp3"
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=p_spectrogram,
device=device,
output_format=output_format,
)
st.audio(audio_bytes)
if output_path:
prompt_slug = entry["prompt"].replace(" ", "_")
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
image.save(image_path, format="JPEG")
entry["image_path"] = str(image_path)
audio_path = (
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
)
audio_path.write_bytes(audio_bytes.getbuffer())
entry["audio_path"] = str(audio_path)
if output_path:
output_json_path = output_path / "index.json"
output_json_path.write_text(json.dumps(data, indent=4))
st.info(f"Output written to {str(output_path)}")
else:
st.info("Enter output directory in sidebar to save to disk")
if __name__ == "__main__":
render_text_to_audio_batch()

View File

@ -0,0 +1,43 @@
import streamlit as st
def render_main():
st.set_page_config(layout="wide", page_icon="🎸")
st.title(":guitar: Riffusion Playground")
left, right = st.columns(2)
with left:
create_link(":pencil2: Text to Audio", "/text_to_audio")
st.write("Generate audio clips from text prompts.")
create_link(":wave: Audio to Audio", "/audio_to_audio")
st.write("Upload audio and modify with text prompt (interpolation supported).")
create_link(":performing_arts: Interpolation", "/interpolation")
st.write("Interpolate between prompts in the latent space.")
create_link(":scissors: Audio Splitter", "/split_audio")
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
with right:
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
st.write("Generate audio in batch from a JSON file of text prompts.")
create_link(":paperclip: Sample Clips", "/sample_clips")
st.write("Export short clips from an audio file.")
create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
st.write("Reconstruct audio from spectrogram images.")
def create_link(name: str, url: str) -> None:
st.markdown(
f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
unsafe_allow_html=True,
)
if __name__ == "__main__":
render_main()

457
riffusion/streamlit/util.py Normal file
View File

@ -0,0 +1,457 @@
"""
Streamlit utilities (mostly cached wrappers around riffusion code).
"""
import io
import threading
import typing as T
from pathlib import Path
import pydub
import streamlit as st
import torch
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image
from riffusion.audio_splitter import AudioSplitter
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
# TODO(hayk): Add URL params
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
SCHEDULER_OPTIONS = [
"PNDMScheduler",
"DDIMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@st.experimental_singleton
def load_riffusion_checkpoint(
checkpoint: str = "riffusion/riffusion-model-v1",
no_traced_unet: bool = False,
device: str = "cuda",
) -> RiffusionPipeline:
"""
Load the riffusion pipeline.
"""
return RiffusionPipeline.load_checkpoint(
checkpoint=checkpoint,
use_traced_unet=not no_traced_unet,
device=device,
)
@st.experimental_singleton
def load_stable_diffusion_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
scheduler: str = SCHEDULER_OPTIONS[0],
lora_path: T.Optional[str] = None,
lora_scale: float = 1.0,
) -> StableDiffusionPipeline:
"""
Load the riffusion pipeline.
TODO(hayk): Merge this into RiffusionPipeline to just load one model.
"""
if device == "cpu" or device.lower().startswith("mps"):
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
pipeline = StableDiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
if lora_path:
if not Path(lora_path).is_file() or Path(lora_path).is_dir():
raise RuntimeError("Bad lora path")
from lora_diffusion import patch_pipe, tune_lora_scale
patch_pipe(
pipeline,
lora_path,
patch_text=True,
patch_ti=True,
patch_unet=True,
)
tune_lora_scale(pipeline.unet, lora_scale)
return pipeline
def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
"""
Construct a denoising scheduler from a string.
"""
if scheduler == "PNDMScheduler":
from diffusers import PNDMScheduler
return PNDMScheduler.from_config(config)
elif scheduler == "DPMSolverMultistepScheduler":
from diffusers import DPMSolverMultistepScheduler
return DPMSolverMultistepScheduler.from_config(config)
elif scheduler == "DDIMScheduler":
from diffusers import DDIMScheduler
return DDIMScheduler.from_config(config)
elif scheduler == "LMSDiscreteScheduler":
from diffusers import LMSDiscreteScheduler
return LMSDiscreteScheduler.from_config(config)
elif scheduler == "EulerDiscreteScheduler":
from diffusers import EulerDiscreteScheduler
return EulerDiscreteScheduler.from_config(config)
elif scheduler == "EulerAncestralDiscreteScheduler":
from diffusers import EulerAncestralDiscreteScheduler
return EulerAncestralDiscreteScheduler.from_config(config)
else:
raise ValueError(f"Unknown scheduler {scheduler}")
@st.experimental_singleton
def pipeline_lock() -> threading.Lock:
"""
Singleton lock used to prevent concurrent access to any model pipeline.
"""
return threading.Lock()
@st.experimental_singleton
def load_stable_diffusion_img2img_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
scheduler: str = SCHEDULER_OPTIONS[0],
lora_path: T.Optional[str] = None,
lora_scale: float = 1.0,
) -> StableDiffusionImg2ImgPipeline:
"""
Load the image to image pipeline.
TODO(hayk): Merge this into RiffusionPipeline to just load one model.
"""
if device == "cpu" or device.lower().startswith("mps"):
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
# TODO reduce duplication
if lora_path:
if not Path(lora_path).is_file() or Path(lora_path).is_dir():
raise RuntimeError("Bad lora path")
from lora_diffusion import patch_pipe, tune_lora_scale
patch_pipe(
pipeline,
lora_path,
patch_text=True,
patch_ti=True,
patch_unet=True,
)
tune_lora_scale(pipeline.unet, lora_scale)
return pipeline
@st.experimental_memo
def run_txt2img(
prompt: str,
num_inference_steps: int,
guidance: float,
negative_prompt: str,
seed: int,
width: int,
height: int,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
lora_path: T.Optional[str] = None,
lora_scale: float = 1.0,
) -> Image.Image:
"""
Run the text to image pipeline with caching.
"""
with pipeline_lock():
pipeline = load_stable_diffusion_pipeline(
device=device,
scheduler=scheduler,
lora_path=lora_path,
lora_scale=lora_scale,
)
generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)
output = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance,
negative_prompt=negative_prompt or None,
generator=generator,
width=width,
height=height,
)
return output["images"][0]
@st.experimental_singleton
def spectrogram_image_converter(
params: SpectrogramParams,
device: str = "cuda",
) -> SpectrogramImageConverter:
return SpectrogramImageConverter(params=params, device=device)
@st.cache
def spectrogram_image_from_audio(
segment: pydub.AudioSegment,
params: SpectrogramParams,
device: str = "cuda",
) -> Image.Image:
converter = spectrogram_image_converter(params=params, device=device)
return converter.spectrogram_image_from_audio(segment)
@st.experimental_memo
def audio_segment_from_spectrogram_image(
image: Image.Image,
params: SpectrogramParams,
device: str = "cuda",
) -> pydub.AudioSegment:
converter = spectrogram_image_converter(params=params, device=device)
return converter.audio_from_spectrogram_image(image)
@st.experimental_memo
def audio_bytes_from_spectrogram_image(
image: Image.Image,
params: SpectrogramParams,
device: str = "cuda",
output_format: str = "mp3",
) -> io.BytesIO:
segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)
audio_bytes = io.BytesIO()
segment.export(audio_bytes, format=output_format)
return audio_bytes
def select_device(container: T.Any = st.sidebar) -> str:
"""
Dropdown to select a torch device, with an intelligent default.
"""
default_device = "cpu"
if torch.cuda.is_available():
default_device = "cuda"
elif torch.backends.mps.is_available():
default_device = "mps"
device_options = ["cuda", "cpu", "mps"]
    device = container.selectbox(
"Device",
options=device_options,
index=device_options.index(default_device),
help="Which compute device to use. CUDA is recommended.",
)
assert device is not None
return device
def select_audio_extension(container: T.Any = st.sidebar) -> str:
"""
Dropdown to select an audio extension, with an intelligent default.
"""
default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
extension = container.selectbox(
"Output format",
options=AUDIO_EXTENSIONS,
index=AUDIO_EXTENSIONS.index(default),
)
assert extension is not None
return extension
def select_scheduler(container: T.Any = st.sidebar) -> str:
"""
Dropdown to select a scheduler.
"""
    scheduler = container.selectbox(
"Scheduler",
options=SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None
return scheduler
@st.experimental_memo
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
return pydub.AudioSegment.from_file(audio_file)
@st.experimental_singleton
def get_audio_splitter(device: str = "cuda"):
return AudioSplitter(device=device)
@st.experimental_singleton
def load_magic_mix_pipeline(device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0]):
pipeline = DiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
custom_pipeline="magic_mix",
).to(device)
pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
return pipeline
@st.cache
def run_img2img_magic_mix(
prompt: str,
init_image: Image.Image,
num_inference_steps: int,
guidance_scale: float,
seed: int,
kmin: float,
kmax: float,
mix_factor: float,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
):
"""
Run the magic mix pipeline for img2img.
"""
with pipeline_lock():
pipeline = load_magic_mix_pipeline(
device=device,
scheduler=scheduler,
)
return pipeline(
init_image,
prompt=prompt,
kmin=kmin,
kmax=kmax,
mix_factor=mix_factor,
seed=seed,
guidance_scale=guidance_scale,
steps=num_inference_steps,
)
@st.cache
def run_img2img(
prompt: str,
init_image: Image.Image,
denoising_strength: float,
num_inference_steps: int,
guidance_scale: float,
seed: int,
negative_prompt: T.Optional[str] = None,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
lora_path: T.Optional[str] = None,
lora_scale: float = 1.0,
) -> Image.Image:
with pipeline_lock():
pipeline = load_stable_diffusion_img2img_pipeline(
device=device,
scheduler=scheduler,
lora_path=lora_path,
lora_scale=lora_scale,
)
generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)
num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)
def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None:
if progress_callback is not None:
progress_callback(step / num_expected_steps)
result = pipeline(
prompt=prompt,
image=init_image,
strength=denoising_strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt or None,
num_images_per_prompt=1,
generator=generator,
callback=callback,
callback_steps=1,
)
return result.images[0]
class StreamlitCounter:
"""
Simple counter stored in streamlit session state.
"""
def __init__(self, key="_counter"):
self.key = key
if not st.session_state.get(self.key):
st.session_state[self.key] = 0
def increment(self):
st.session_state[self.key] += 1
@property
def value(self):
return st.session_state[self.key]
def display_and_download_audio(
segment: pydub.AudioSegment,
name: str,
extension: str = "mp3",
) -> None:
"""
Display the given audio segment and provide a button to download it with
a proper file name, since st.audio doesn't support that.
"""
mime_type = f"audio/{extension}"
audio_bytes = io.BytesIO()
segment.export(audio_bytes, format=extension)
st.audio(audio_bytes, format=mime_type)
st.download_button(
f"{name}.{extension}",
data=audio_bytes,
file_name=f"{name}.{extension}",
mime=mime_type,
)
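
# Hedged end-to-end sketch of how the playground pages typically combine the helpers in
# this module: generate a spectrogram image from a text prompt, reconstruct audio from it,
# and render the result with a download button. The prompt and parameter values below are
# examples, not defaults used elsewhere in the codebase.
def example_prompt_to_audio(prompt: str = "acoustic folk guitar", device: str = "cuda") -> None:
    params = SpectrogramParams(min_frequency=0, max_frequency=10000)
    image = run_txt2img(
        prompt=prompt,
        num_inference_steps=30,
        guidance=7.0,
        negative_prompt="",
        seed=42,
        width=512,
        height=512,
        device=device,
    )
    segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)
    display_and_download_audio(segment, name=prompt.replace(" ", "_"), extension="mp3")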

View File

View File

@ -0,0 +1,99 @@
"""
Audio utility functions.
"""
import io
import typing as T
import numpy as np
import pydub
from scipy.io import wavfile
def audio_from_waveform(
samples: np.ndarray, sample_rate: int, normalize: bool = False
) -> pydub.AudioSegment:
"""
Convert a numpy array of samples of a waveform to an audio segment.
Args:
samples: (channels, samples) array
"""
# Normalize volume to fit in int16
if normalize:
samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))
# Transpose and convert to int16
samples = samples.transpose(1, 0)
samples = samples.astype(np.int16)
# Write to the bytes of a WAV file
wav_bytes = io.BytesIO()
wavfile.write(wav_bytes, sample_rate, samples)
wav_bytes.seek(0)
# Read into pydub
return pydub.AudioSegment.from_wav(wav_bytes)
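
# Hedged usage sketch: build a one-second stereo 440 Hz tone as a (channels, samples)
# float array and convert it into a pydub segment. The values are illustrative.
def example_sine_segment(sample_rate: int = 44100) -> pydub.AudioSegment:
    t = np.linspace(0, 1.0, sample_rate, endpoint=False)
    tone = np.sin(2 * np.pi * 440.0 * t)
    samples = np.stack([tone, tone])  # shape (channels, samples)
    return audio_from_waveform(samples, sample_rate=sample_rate, normalize=True)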
def apply_filters(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
"""
    Apply post-processing filters to the audio segment to compress it and
    keep it at a consistent dBFS level.
"""
# TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
# TODO(hayk): Is this going to make audio unbalanced between sequential clips?
if compression:
segment = pydub.effects.normalize(
segment,
headroom=0.1,
)
segment = segment.apply_gain(-10 - segment.dBFS)
# TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
segment = pydub.effects.compress_dynamic_range(
segment,
threshold=-20.0,
ratio=4.0,
attack=5.0,
release=50.0,
)
desired_db = -12
segment = segment.apply_gain(desired_db - segment.dBFS)
segment = pydub.effects.normalize(
segment,
headroom=0.1,
)
return segment
def stitch_segments(
segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
) -> pydub.AudioSegment:
"""
Stitch together a sequence of audio segments with a crossfade between each segment.
"""
crossfade_ms = int(crossfade_s * 1000)
combined_segment = segments[0]
for segment in segments[1:]:
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
return combined_segment
def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
"""
Overlay a sequence of audio segments on top of each other.
"""
assert len(segments) > 0
    output: T.Optional[pydub.AudioSegment] = None
for segment in segments:
if output is None:
output = segment
else:
output = output.overlay(segment)
return output
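
# Hedged usage sketch combining the helpers above: crossfade consecutive clips into a
# single segment, then overlay another stem on top of it. The 0.2 s crossfade mirrors the
# default overlap used by the audio-to-audio page.
def example_stitch_and_overlay(
    clips: T.Sequence[pydub.AudioSegment], stem: pydub.AudioSegment
) -> pydub.AudioSegment:
    stitched = stitch_segments(clips, crossfade_s=0.2)
    return overlay_segments([stitched, stem])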

View File

@ -0,0 +1,9 @@
import base64
import io
def encode(buffer: io.BytesIO) -> str:
"""
Encode the given buffer as base64.
"""
return base64.encodebytes(buffer.getvalue()).decode("ascii")
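
def example_encode_segment(segment) -> str:
    """
    Hedged usage sketch: export a pydub AudioSegment to an in-memory MP3 and base64-encode
    it, e.g. to embed the audio in a JSON response. The argument is assumed to be a
    pydub.AudioSegment; pydub itself is not a dependency of this module.
    """
    buffer = io.BytesIO()
    segment.export(buffer, format="mp3")
    return encode(buffer)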

View File

@ -0,0 +1,60 @@
"""
FFT tools to analyze frequency content of audio segments. This is not code for
dealing with spectrogram images, but for analysis of waveforms.
"""
import struct
import typing as T
import numpy as np
import plotly.graph_objects as go
import pydub
from scipy.fft import rfft, rfftfreq
def plot_ffts(
segments: T.Dict[str, pydub.AudioSegment],
title: str = "FFT",
min_frequency: float = 20,
max_frequency: float = 20000,
) -> None:
"""
Plot an FFT analysis of the given audio segments.
"""
ffts = {name: compute_fft(seg) for name, seg in segments.items()}
fig = go.Figure(
data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
layout={"title": title},
)
fig.update_xaxes(
range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
type="log",
title="Frequency",
)
fig.update_yaxes(title="Value")
fig.show()
def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
"""
Compute the FFT of the given audio segment as a mono signal.
Returns:
frequencies: FFT computed frequencies
amplitudes: Amplitudes of each frequency
"""
# Convert to mono if needed.
if sound.channels > 1:
sound = sound.set_channels(1)
sample_rate = sound.frame_rate
num_samples = int(sound.frame_count())
samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
fft_values = rfft(samples)
amplitudes = np.abs(fft_values)
frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
return frequencies, amplitudes
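
# Hedged usage sketch: report the dominant frequency of an audio file using compute_fft.
# The "clip.wav" path is illustrative; any file readable by pydub works.
def example_dominant_frequency(path: str = "clip.wav") -> float:
    segment = pydub.AudioSegment.from_file(path)
    frequencies, amplitudes = compute_fft(segment)
    return float(frequencies[int(np.argmax(amplitudes))])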

View File

@ -0,0 +1,122 @@
"""
Module for converting between spectrogram tensors and spectrogram images, as well as
general helpers for operating on pillow images.
"""
import typing as T
import numpy as np
from PIL import Image
from riffusion.spectrogram_params import SpectrogramParams
def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
"""
Compute a spectrogram image from a spectrogram magnitude array.
This is the inverse of spectrogram_from_image, except for discretization error from
quantizing to uint8.
Args:
spectrogram: (channels, frequency, time)
power: A power curve to apply to the spectrogram to preserve contrast
Returns:
image: (frequency, time, channels)
"""
# Rescale to 0-1
max_value = np.max(spectrogram)
data = spectrogram / max_value
# Apply the power curve
data = np.power(data, power)
# Rescale to 0-255
data = data * 255
# Invert
data = 255 - data
# Convert to uint8
data = data.astype(np.uint8)
# Munge channels into a PIL image
if data.shape[0] == 1:
# TODO(hayk): Do we want to write single channel to disk instead?
image = Image.fromarray(data[0], mode="L").convert("RGB")
elif data.shape[0] == 2:
data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
image = Image.fromarray(data, mode="RGB")
else:
raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")
# Flip Y
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
return image
def spectrogram_from_image(
image: Image.Image,
power: float = 0.25,
stereo: bool = False,
max_value: float = 30e6,
) -> np.ndarray:
"""
Compute a spectrogram magnitude array from a spectrogram image.
This is the inverse of image_from_spectrogram, except for discretization error from
quantizing to uint8.
Args:
image: (frequency, time, channels)
power: The power curve applied to the spectrogram
stereo: Whether the spectrogram encodes stereo data
max_value: The max value of the original spectrogram. In practice doesn't matter.
Returns:
spectrogram: (channels, frequency, time)
"""
# Convert to RGB if single channel
if image.mode in ("P", "L"):
image = image.convert("RGB")
# Flip Y
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
# Munge channels into a numpy array of (channels, frequency, time)
data = np.array(image).transpose(2, 0, 1)
if stereo:
# Take the G and B channels as done in image_from_spectrogram
data = data[[1, 2], :, :]
else:
data = data[0:1, :, :]
# Convert to floats
data = data.astype(np.float32)
# Invert
data = 255 - data
# Rescale to 0-1
data = data / 255
# Reverse the power curve
data = np.power(data, 1 / power)
# Rescale to max value
data = data * max_value
return data
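
# Hedged round-trip sketch: converting a magnitude spectrogram to an image and back should
# be nearly lossless apart from uint8 quantization. The random array stands in for a real
# (channels, frequency, time) spectrogram and its scale is arbitrary.
def example_round_trip() -> np.ndarray:
    spectrogram = np.random.rand(1, 512, 512).astype(np.float32) * 30e6
    image = image_from_spectrogram(spectrogram, power=0.25)
    return spectrogram_from_image(
        image,
        power=0.25,
        stereo=False,
        max_value=float(np.max(spectrogram)),
    )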
def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]:
"""
Get the EXIF data from a PIL image as a dict.
"""
exif = pil_image.getexif()
if exif is None or len(exif) == 0:
return {}
return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()}

View File

@ -0,0 +1,48 @@
import warnings
import numpy as np
import torch
def check_device(device: str, backup: str = "cpu") -> str:
"""
    Check that the device is valid and available. If not, fall back to the backup device.
"""
cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()
if cuda_not_found or mps_not_found:
warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
return backup
return device
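
# Hedged usage example: request CUDA but fall back to CPU when it isn't available, then
# place tensors on whichever device was chosen.
#
#     device = check_device("cuda", backup="cpu")
#     latents = torch.randn(1, 4, 64, 64, device=device)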
def slerp(
t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995
) -> torch.Tensor:
"""
    Helper function to spherically interpolate two tensors v0 and v1.
    """
    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > dot_threshold:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(input_device)
return v2
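
# Hedged usage sketch: spherically interpolate between two latent tensors, as is done when
# traversing the latent space between two prompts. The shapes and the 0.25 blend factor
# are illustrative.
def example_slerp_latents(device: str = "cpu") -> torch.Tensor:
    latents_a = torch.randn(1, 4, 64, 64, device=device)
    latents_b = torch.randn(1, 4, 64, 64, device=device)
    return slerp(0.25, latents_a, latents_b)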

BIN seed_images/agile.png Normal file (binary, 132 KiB, not shown)
BIN seed_images/marim.png Normal file (binary, 128 KiB, not shown)
Binary file not shown (7.1 KiB)
Binary file not shown (1.6 KiB)
Binary file not shown (1.6 KiB)
Binary file not shown (1.6 KiB)
Binary file not shown (222 B)
Binary file not shown (225 B)
BIN seed_images/motorway.png Normal file (binary, 123 KiB, not shown)
BIN seed_images/og_beat.png Normal file (binary, 108 KiB, not shown)
BIN seed_images/vibes.png Normal file (binary, 130 KiB, not shown)

0
test/__init__.py Normal file
View File

View File

@ -0,0 +1,99 @@
import typing as T
import numpy as np
from PIL import Image
from riffusion.cli import audio_to_image
from riffusion.spectrogram_params import SpectrogramParams
from .test_case import TestCase
class AudioToImageTest(TestCase):
"""
Test riffusion.cli audio-to-image
"""
@classmethod
def default_params(cls) -> T.Dict:
return dict(
step_size_ms=10,
num_frequencies=512,
# TODO(hayk): Change these to [20, 20000] once a model is updated
min_frequency=0,
max_frequency=10000,
stereo=False,
device=cls.DEVICE,
)
def test_audio_to_image(self) -> None:
"""
Test audio-to-image with default params.
"""
params = self.default_params()
self.helper_test_with_params(params)
def test_stereo(self) -> None:
"""
Test audio-to-image with stereo=True.
"""
params = self.default_params()
params["stereo"] = True
self.helper_test_with_params(params)
def helper_test_with_params(self, params: T.Dict) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
output_dir = self.get_tmp_dir("audio_to_image_")
if params["stereo"]:
stem = f"{audio_path.stem}_stereo"
else:
stem = audio_path.stem
image_path = output_dir / f"{stem}.png"
audio_to_image(audio=str(audio_path), image=str(image_path), **params)
# Check that the image exists
self.assertTrue(image_path.exists())
pil_image = Image.open(image_path)
# Check the image mode
self.assertEqual(pil_image.mode, "RGB")
# Check the image dimensions
duration_ms = 5678
self.assertTrue(str(duration_ms) in audio_path.name)
expected_image_width = round(duration_ms / params["step_size_ms"])
self.assertEqual(pil_image.width, expected_image_width)
self.assertEqual(pil_image.height, params["num_frequencies"])
# Get channels as numpy arrays
channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))]
self.assertEqual(len(channels), 3)
if params["stereo"]:
# Check that the first channel is zero
self.assertTrue(np.all(channels[0] == 0))
else:
# Check that all channels are the same
self.assertTrue(np.all(channels[0] == channels[1]))
self.assertTrue(np.all(channels[0] == channels[2]))
# Check that the image has exif data
exif = pil_image.getexif()
self.assertIsNotNone(exif)
params_from_exif = SpectrogramParams.from_exif(exif)
expected_params = SpectrogramParams(
stereo=params["stereo"],
step_size_ms=params["step_size_ms"],
num_frequencies=params["num_frequencies"],
max_frequency=params["max_frequency"],
)
self.assertTrue(params_from_exif == expected_params)

View File

@ -0,0 +1,71 @@
from pathlib import Path
import pydub
from riffusion.cli import image_to_audio
from .test_case import TestCase
class ImageToAudioTest(TestCase):
"""
Test riffusion.cli image-to-audio
"""
def test_image_to_audio_mono(self) -> None:
self.helper_image_to_audio(
song_dir=self.TEST_DATA_PATH / "tired_traveler",
clip_name="clip_2_start_103694_ms_duration_5678_ms",
stereo=False,
)
def test_image_to_audio_stereo(self) -> None:
self.helper_image_to_audio(
song_dir=self.TEST_DATA_PATH / "tired_traveler",
clip_name="clip_2_start_103694_ms_duration_5678_ms",
stereo=True,
)
def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
if stereo:
image_stem = clip_name + "_stereo"
else:
image_stem = clip_name
image_path = song_dir / "images" / f"{image_stem}.png"
output_dir = self.get_tmp_dir("image_to_audio_")
audio_path = output_dir / f"{image_path.stem}.wav"
image_to_audio(
image=str(image_path),
audio=str(audio_path),
device=self.DEVICE,
)
# Check that the audio exists
self.assertTrue(audio_path.exists())
# Load the reconstructed audio and the original clip
segment = pydub.AudioSegment.from_file(str(audio_path))
expected_segment = pydub.AudioSegment.from_file(
str(song_dir / "clips" / f"{clip_name}.wav")
)
# Check sample rate
self.assertEqual(segment.frame_rate, expected_segment.frame_rate)
# Check duration
actual_duration_ms = round(segment.duration_seconds * 1000)
expected_duration_ms = round(expected_segment.duration_seconds * 1000)
self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
# Check the number of channels
self.assertEqual(expected_segment.channels, 2)
if stereo:
self.assertEqual(segment.channels, 2)
else:
self.assertEqual(segment.channels, 1)
if __name__ == "__main__":
TestCase.main()

65
test/image_util_test.py Normal file
View File

@ -0,0 +1,65 @@
import numpy as np
import pydub
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import image_util
from .test_case import TestCase
class ImageUtilTest(TestCase):
"""
Test riffusion.util.image_util
"""
def test_spectrogram_to_image_round_trip(self) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
# Load up the audio file
segment = pydub.AudioSegment.from_file(audio_path)
# Convert to mono
segment = segment.set_channels(1)
# Compute a spectrogram with default params
params = SpectrogramParams(sample_rate=segment.frame_rate)
converter = SpectrogramConverter(params=params, device=self.DEVICE)
spectrogram = converter.spectrogram_from_audio(segment)
# Compute the image from the spectrogram
image = image_util.image_from_spectrogram(
spectrogram=spectrogram,
power=params.power_for_image,
)
# Save the max value
max_value = np.max(spectrogram)
# Compute the spectrogram from the image
spectrogram_reversed = image_util.spectrogram_from_image(
image=image,
max_value=max_value,
power=params.power_for_image,
stereo=params.stereo,
)
# Check the shapes
self.assertEqual(spectrogram.shape, spectrogram_reversed.shape)
# Check the max values
self.assertEqual(np.max(spectrogram), np.max(spectrogram_reversed))
# Check the median values
self.assertTrue(
np.allclose(np.median(spectrogram), np.median(spectrogram_reversed), rtol=0.05)
)
# Make sure all values are somewhat similar, but allow for discretization error
# TODO(hayk): Investigate error more closely
self.assertTrue(np.allclose(spectrogram, spectrogram_reversed, rtol=0.15))

24
test/linter_test.py Normal file
View File

@ -0,0 +1,24 @@
import subprocess
from pathlib import Path
from .test_case import TestCase
class LinterTest(TestCase):
"""
Test that ruff, black, and mypy run cleanly.
"""
HOME = Path(__file__).parent.parent
def test_ruff(self) -> None:
code = subprocess.check_call(["ruff", str(self.HOME)])
self.assertEqual(code, 0)
def test_black(self) -> None:
code = subprocess.check_call(["black", "--check", str(self.HOME)])
self.assertEqual(code, 0)
def test_mypy(self) -> None:
code = subprocess.check_call(["mypy", str(self.HOME)])
self.assertEqual(code, 0)

32
test/print_exif_test.py Normal file
View File

@ -0,0 +1,32 @@
import contextlib
import io
from riffusion.cli import print_exif
from .test_case import TestCase
class PrintExifTest(TestCase):
"""
Test riffusion.cli print-exif
"""
def test_print_exif(self) -> None:
"""
Test print-exif.
"""
image_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "images"
/ "clip_2_start_103694_ms_duration_5678_ms.png"
)
# Redirect stdout
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
print_exif(image=str(image_path))
# Check that a couple of values are printed
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())

88
test/sample_clips_test.py Normal file
View File

@ -0,0 +1,88 @@
import typing as T
import pydub
from riffusion.cli import sample_clips
from .test_case import TestCase
class SampleClipsTest(TestCase):
"""
Test riffusion.cli sample-clips
"""
@staticmethod
def default_params() -> T.Dict:
return dict(
num_clips=3,
duration_ms=5678,
mono=False,
extension="wav",
seed=42,
)
def test_sample_clips(self) -> None:
"""
Test sample-clips with default params.
"""
params = self.default_params()
self.helper_test_with_params(params)
def test_mono(self) -> None:
"""
Test sample-clips with mono=True.
"""
params = self.default_params()
params["mono"] = True
params["num_clips"] = 1
self.helper_test_with_params(params)
def test_mp3(self) -> None:
"""
Test sample-clips with extension=mp3.
"""
if pydub.AudioSegment.converter is None:
self.skipTest("skipping, ffmpeg not found")
params = self.default_params()
params["extension"] = "mp3"
params["num_clips"] = 1
self.helper_test_with_params(params)
def helper_test_with_params(self, params: T.Dict) -> None:
"""
Test sample-clips with the given params.
"""
audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
output_dir = self.get_tmp_dir("sample_clips_")
sample_clips(
audio=str(audio_path),
output_dir=str(output_dir),
**params,
)
# For each file in output dir
counter = 0
for clip_path in output_dir.iterdir():
# Check that it has the right extension
self.assertEqual(clip_path.suffix, f".{params['extension']}")
# Check that it has the right duration
segment = pydub.AudioSegment.from_file(clip_path)
self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
# Check that it has the right number of channels
if params["mono"]:
self.assertEqual(segment.channels, 1)
else:
self.assertEqual(segment.channels, 2)
counter += 1
self.assertEqual(counter, params["num_clips"])
if __name__ == "__main__":
TestCase.main()

View File

@ -0,0 +1,86 @@
import dataclasses
import typing as T
import pydub
from riffusion.spectrogram_converter import SpectrogramConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util
from .test_case import TestCase
class SpectrogramConverterTest(TestCase):
"""
Test going from audio to spectrogram to audio, without converting to
an image, to check quality loss of the reconstruction.
This test allows comparing multiple sets of spectrogram params by listening to output audio
and by plotting their FFTs.
"""
# TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.
def test_round_trip(self) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")
# Load up the audio file
segment = pydub.AudioSegment.from_file(audio_path)
# Convert to mono if desired
use_stereo = False
if use_stereo:
assert segment.channels == 2
else:
segment = segment.set_channels(1)
# Define named sets of parameters
param_sets: T.Dict[str, SpectrogramParams] = {}
param_sets["default"] = SpectrogramParams(
sample_rate=segment.frame_rate,
stereo=use_stereo,
step_size_ms=10,
min_frequency=20,
max_frequency=20000,
num_frequencies=512,
)
if self.DEBUG:
param_sets["freq_0_to_10k"] = dataclasses.replace(
param_sets["default"],
min_frequency=0,
max_frequency=10000,
)
segments: T.Dict[str, pydub.AudioSegment] = {
"original": segment,
}
for name, params in param_sets.items():
converter = SpectrogramConverter(params=params, device=self.DEVICE)
spectrogram = converter.spectrogram_from_audio(segment)
segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
# Save segments to disk
for name, segment in segments.items():
audio_out = output_dir / f"{name}.wav"
segment.export(audio_out, format="wav")
print(f"Saved {audio_out}")
# Check params
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
self.assertEqual(segments["original"].channels, segments["default"].channels)
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
# If debugging, load up a browser tab plotting the FFTs
if self.DEBUG:
fft_util.plot_ffts(segments)

View File

@ -0,0 +1,97 @@
import dataclasses
import typing as T
import pydub
from PIL import Image
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import fft_util
from .test_case import TestCase
class SpectrogramImageConverterTest(TestCase):
"""
Test going from audio to spectrogram images to audio, testing the quality loss of the
end-to-end pipeline.
This test allows comparing multiple sets of spectrogram params by listening to output audio
and by plotting their FFTs.
See spectrogram_converter_test.py for a similar test that does not convert to images.
"""
def test_round_trip(self) -> None:
audio_path = (
self.TEST_DATA_PATH
/ "tired_traveler"
/ "clips"
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
)
output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_")
# Load up the audio file
segment = pydub.AudioSegment.from_file(audio_path)
# Convert to mono if desired
use_stereo = False
if use_stereo:
assert segment.channels == 2
else:
segment = segment.set_channels(1)
# Define named sets of parameters
param_sets: T.Dict[str, SpectrogramParams] = {}
param_sets["default"] = SpectrogramParams(
sample_rate=segment.frame_rate,
stereo=use_stereo,
step_size_ms=10,
min_frequency=20,
max_frequency=20000,
num_frequencies=512,
)
if self.DEBUG:
param_sets["freq_0_to_10k"] = dataclasses.replace(
param_sets["default"],
min_frequency=0,
max_frequency=10000,
)
segments: T.Dict[str, pydub.AudioSegment] = {
"original": segment,
}
images: T.Dict[str, Image.Image] = {}
for name, params in param_sets.items():
converter = SpectrogramImageConverter(params=params, device=self.DEVICE)
images[name] = converter.spectrogram_image_from_audio(segment)
segments[name] = converter.audio_from_spectrogram_image(
image=images[name],
apply_filters=True,
)
# Save images to disk
for name, image in images.items():
image_out = output_dir / f"{name}.png"
image.save(image_out, exif=image.getexif(), format="PNG")
print(f"Saved {image_out}")
# Save segments to disk
for name, segment in segments.items():
audio_out = output_dir / f"{name}.wav"
segment.export(audio_out, format="wav")
print(f"Saved {audio_out}")
# Check params
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
self.assertEqual(segments["original"].channels, segments["default"].channels)
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
# If debugging, load up a browser tab plotting the FFTs
if self.DEBUG:
fft_util.plot_ffts(segments)

48
test/test_case.py Normal file
View File

@ -0,0 +1,48 @@
import os
import shutil
import tempfile
import typing as T
import unittest
import warnings
from pathlib import Path
class TestCase(unittest.TestCase):
"""
Base class for tests.
"""
# Where checked-in test data is stored
TEST_DATA_PATH = Path(__file__).resolve().parent / "test_data"
# Whether to run tests in debug mode (e.g. don't clean up temporary directories, show plots)
DEBUG = bool(os.environ.get("RIFFUSION_TEST_DEBUG"))
# Which torch device to use for tests
DEVICE = os.environ.get("RIFFUSION_TEST_DEVICE", "cuda")
@staticmethod
def main(*args: T.Any, **kwargs: T.Any) -> None:
"""
Run the tests.
"""
unittest.main(*args, **kwargs)
@classmethod
def setUpClass(cls):
warnings.filterwarnings("ignore", category=ResourceWarning)
def get_tmp_dir(self, prefix: str) -> Path:
"""
Create a temporary directory.
"""
tmp_dir = tempfile.mkdtemp(prefix=prefix)
# Clean up the temporary directory if not debugging
if not self.DEBUG:
self.addCleanup(lambda: shutil.rmtree(tmp_dir, ignore_errors=True))
dir_path = Path(tmp_dir)
assert dir_path.is_dir()
return dir_path

7
test/test_data/README.md Normal file
View File

@ -0,0 +1,7 @@
# Test Data
### tired_traveler
* Song: Tired traveler on the way to home
* Artist: Andrew Codeman
* Source: https://freemusicarchive.org/

Binary file not shown (258 KiB)
Binary file not shown (382 KiB)
Binary file not shown.