Stable diffusion for real-time music generation
Hayk Martiros 6a5f572374 Add a test that verifies ruff, black, and mypy all pass
This feels good!

Topic: clean_rewrite
2022-12-26 17:47:11 -08:00
integrations Greatly simplify the server and baseten integration 2022-12-26 17:47:11 -08:00
riffusion Command line interface for common operations, plus tests 2022-12-26 17:47:11 -08:00
seed_images adding seed images 2022-12-13 23:33:52 -08:00
test Add a test that verifies ruff, black, and mypy all pass 2022-12-26 17:47:11 -08:00
.gitignore Update python project configuration files 2022-12-26 17:47:11 -08:00
CITATION Update python project configuration files 2022-12-26 17:47:11 -08:00
LICENSE Add license 2022-12-13 02:47:48 +00:00
README.md Update python project configuration files 2022-12-26 17:47:11 -08:00
dev_requirements.txt Update python project configuration files 2022-12-26 17:47:11 -08:00
pyproject.toml Update python project configuration files 2022-12-26 17:47:11 -08:00
requirements.txt Update python project configuration files 2022-12-26 17:47:11 -08:00

README.md

Riffusion

Riffusion is a technique for real-time music and audio generation with stable diffusion.

Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.

This repository contains the Python backend that does the model inference and audio processing, including:

  • a diffusers pipeline that performs prompt interpolation combined with image conditioning
  • a module for (approximately) converting between spectrograms and waveforms
  • a flask server to provide model inference via API to the next.js app
  • a model template titled baseten.py for deploying as a Truss
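
As a rough illustration of what "prompt interpolation" can mean, the sketch below blends two embedding vectors with a weight alpha between 0 and 1, once linearly and once along the arc between them. This is an assumption for illustration only: this README does not say which scheme the pipeline uses, and the function names lerp and slerp are hypothetical, not part of the riffusion package.

```python
import math
from typing import List


def lerp(a: List[float], b: List[float], alpha: float) -> List[float]:
    """Linear interpolation: alpha=0 returns a, alpha=1 returns b."""
    return [(1.0 - alpha) * x + alpha * y for x, y in zip(a, b)]


def slerp(a: List[float], b: List[float], alpha: float) -> List[float]:
    """Spherical interpolation along the arc between a and b."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # The angle is undefined for a zero vector; fall back to lerp.
        return lerp(a, b, alpha)

    cos_theta = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    cos_theta = max(-1.0, min(1.0, cos_theta))  # guard against rounding
    theta = math.acos(cos_theta)
    sin_theta = math.sin(theta)
    if abs(sin_theta) < 1e-6:
        # Nearly parallel or antiparallel vectors; fall back to lerp.
        return lerp(a, b, alpha)

    weight_a = math.sin((1.0 - alpha) * theta) / sin_theta
    weight_b = math.sin(alpha * theta) / sin_theta
    return [weight_a * x + weight_b * y for x, y in zip(a, b)]


if __name__ == "__main__":
    start, end = [1.0, 0.0], [0.0, 1.0]
    print(lerp(start, end, 0.75))
    print(slerp(start, end, 0.75))
```

Spherical interpolation keeps the blended vector at a similar magnitude to its endpoints, which is one reason it is often preferred for noise latents, while linear interpolation is the simpler choice for text embeddings.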

Install

Tested with Python 3.9 and diffusers 0.9.0.

To run this model, you need a GPU with CUDA. To run it in real time, the GPU needs to be able to run stable diffusion with approximately 50 steps in under five seconds.

You need to make sure you have torch and torchaudio installed with CUDA support. See the install guide or stable wheels.

conda create --name riffusion-inference python=3.9
conda activate riffusion-inference
python -m pip install -r requirements.txt

If torchaudio has no audio backend, see this issue.

You can open and save WAV files with pure Python. To open and save non-WAV formats such as MP3, you'll need ffmpeg or libav.
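
As a minimal sketch of WAV handling with only the standard library, the example below writes a short mono sine tone with the wave module and reads its header back. The helper names write_sine_wav and read_wav_info are hypothetical and not part of riffusion; the repository's own audio code may use a different library.

```python
import math
import os
import struct
import tempfile
import wave


def write_sine_wav(path: str, freq_hz: float = 440.0,
                   duration_s: float = 0.5, sample_rate: int = 44100) -> int:
    """Write a mono 16-bit sine tone and return the number of frames."""
    n_frames = int(duration_s * sample_rate)
    frames = b"".join(
        # Half-amplitude signed 16-bit little-endian samples.
        struct.pack("<h", int(0.5 * 32767 * math.sin(
            2.0 * math.pi * freq_hz * i / sample_rate)))
        for i in range(n_frames)
    )
    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)   # mono
        wav.setsampwidth(2)   # 2 bytes = 16 bits per sample
        wav.setframerate(sample_rate)
        wav.writeframes(frames)
    return n_frames


def read_wav_info(path: str):
    """Return (channels, sample_rate, n_frames) for a WAV file."""
    with wave.open(path, "rb") as wav:
        return wav.getnchannels(), wav.getframerate(), wav.getnframes()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        wav_path = os.path.join(tmp, "tone.wav")
        write_sine_wav(wav_path)
        print(read_wav_info(wav_path))
```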

Guides:

Run the model server

Start the Flask server:

python -m riffusion.server --host 127.0.0.1 --port 3013

You can specify --checkpoint with your own directory or Hugging Face model ID in diffusers format.

The model endpoint is now available at http://127.0.0.1:3013/run_inference via POST request.

Example input (see InferenceInput for the API):

{
  "alpha": 0.75,
  "num_inference_steps": 50,
  "seed_image_id": "og_beat",

  "start": {
    "prompt": "church bells on sunday",
    "seed": 42,
    "denoising": 0.75,
    "guidance": 7.0
  },

  "end": {
    "prompt": "jazz with piano",
    "seed": 123,
    "denoising": 0.75,
    "guidance": 7.0
  }
}

Example output (see InferenceOutput for the API):

{
  "image": "< base64 encoded JPEG image >",
  "audio": "< base64 encoded MP3 clip >"
}
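
The request and response above could be exercised from Python roughly as follows. This is a sketch under assumptions: it uses only the standard library, the helper names are hypothetical, and it tolerates an optional "data:...;base64," prefix on the returned strings in case the server sends data URIs. Check InferenceInput and InferenceOutput for the authoritative field definitions.

```python
import base64
import json
import urllib.request

# Endpoint from the README; adjust host and port to match your server.
SERVER_URL = "http://127.0.0.1:3013/run_inference"


def build_payload(start_prompt: str, end_prompt: str, alpha: float = 0.75) -> dict:
    """Build a request body shaped like the example input above."""
    return {
        "alpha": alpha,
        "num_inference_steps": 50,
        "seed_image_id": "og_beat",
        "start": {"prompt": start_prompt, "seed": 42,
                  "denoising": 0.75, "guidance": 7.0},
        "end": {"prompt": end_prompt, "seed": 123,
                "denoising": 0.75, "guidance": 7.0},
    }


def _decode_field(value: str) -> bytes:
    """Decode base64, stripping a data-URI prefix if one is present."""
    if value.startswith("data:") and "," in value:
        value = value.split(",", 1)[1]
    return base64.b64decode(value)


def decode_output(response_body: bytes):
    """Return (jpeg_bytes, mp3_bytes) from the JSON response body."""
    data = json.loads(response_body)
    return _decode_field(data["image"]), _decode_field(data["audio"])


def run_inference(payload: dict, url: str = SERVER_URL, timeout: float = 120.0):
    """POST the payload to a running server and decode the result."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return decode_output(response.read())
```

With the server running, `image, audio = run_inference(build_payload("church bells on sunday", "jazz with piano"))` returns raw bytes that can be written to .jpg and .mp3 files.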

Use the --device argument to specify the torch device to use.

cuda is recommended.

cpu works but is quite slow.

mps is supported for inference, but some operations fall back to CPU. You may need to set PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, results on mps are not deterministic.
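
One way to choose a value for --device is to probe what torch reports as available. The sketch below is an assumption rather than part of riffusion: pick_device is a hypothetical helper, it returns "cpu" when torch is not installed, and it sets PYTORCH_ENABLE_MPS_FALLBACK before torch is imported so the fallback is in effect on Apple hardware.

```python
import os

# Set the fallback flag before torch is imported (assumption: setting it
# early is the safest way to make sure it is picked up).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


def pick_device() -> str:
    """Return "cuda", "mps", or "cpu", in the README's order of preference."""
    try:
        import torch
    except ImportError:
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"

    # torch.backends.mps only exists in newer torch releases.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"

    return "cpu"


if __name__ == "__main__":
    print(pick_device())
```

The printed value can be passed straight to the server, for example python -m riffusion.server --device cpu.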

Test

Tests live in the test/ directory and are implemented with unittest.

To run all tests:

python -m unittest test/*_test.py

To run a single test:

python -m unittest test.audio_to_image_test

To preserve temporary outputs for debugging, set RIFFUSION_TEST_DEBUG:

RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test

To run a single test case:

python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo

To run tests using a specific torch device, set RIFFUSION_TEST_DEVICE. Tests should pass with cpu, cuda, and mps backends.
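
A test helper might read these two environment variables along the following lines. This is a hypothetical sketch, not the repository's test code: the function names, the "cuda" default, and the rule that any non-empty value enables debug mode are all assumptions.

```python
import os


def get_test_device(default: str = "cuda") -> str:
    """Torch device for tests, from RIFFUSION_TEST_DEVICE (default assumed)."""
    return os.environ.get("RIFFUSION_TEST_DEVICE", default)


def debug_enabled() -> bool:
    """True when RIFFUSION_TEST_DEBUG is set to any non-empty value."""
    return bool(os.environ.get("RIFFUSION_TEST_DEBUG"))


if __name__ == "__main__":
    print("device:", get_test_device())
    print("debug:", debug_enabled())
```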

Development

Install additional packages for dev with pip install -r dev_requirements.txt.

  • Linter: ruff
  • Formatter: black
  • Type checker: mypy

These are configured in pyproject.toml.

mypy ., black ., and ruff . must all run clean before a PR is accepted.
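
To run all three checks in one go before opening a PR, a small wrapper such as the one below could be used. It is a sketch under assumptions: run_checks is a hypothetical helper and not the test shipped in this repository, and it calls black with --check so that files are reported rather than rewritten.

```python
import subprocess
import sys

# Commands follow the README. black's --check flag only reports files
# that would be reformatted. Newer ruff releases expect "ruff check ."
# instead of "ruff .".
CHECKS = (
    ("ruff", "."),
    ("black", "--check", "."),
    ("mypy", "."),
)


def run_checks(checks=CHECKS, runner=subprocess.run) -> list:
    """Run each command and return the names of the tools that failed."""
    failed = []
    for command in checks:
        try:
            result = runner(list(command))
        except FileNotFoundError:
            # Tool is not installed; treat it as a failure.
            failed.append(command[0])
            continue
        if result.returncode != 0:
            failed.append(command[0])
    return failed


if __name__ == "__main__":
    failures = run_checks()
    if failures:
        print("Failed checks:", ", ".join(failures))
    sys.exit(1 if failures else 0)
```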

Citation

If you build on this work, please cite it as follows:

@article{Forsgren_Martiros_2022,
  author = {Forsgren, Seth* and Martiros, Hayk*},
  title = {{Riffusion - Stable diffusion for real-time music generation}},
  url = {https://riffusion.com/about},
  year = {2022}
}