parent 420674148a
commit 1335afb72f

README.md
@@ -1,51 +1,130 @@
# Riffusion

Riffusion is a technique for real-time music and audio generation with stable diffusion.
Riffusion is a library for real-time music and audio generation with stable diffusion.

Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.

* Inference server: https://github.com/riffusion/riffusion

This repository contains the core riffusion image and audio processing code and supporting apps,
including:

* diffusion pipeline that performs prompt interpolation combined with image conditioning
* package for (approximately) converting between spectrogram images and audio clips
* interactive playground using streamlit
* command-line tool for common tasks
* flask server to provide model inference via API
* various third party integrations
* test suite

Related repositories:
* Web app: https://github.com/riffusion/riffusion-app
* Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1

This repository contains the Python backend that does the model inference and audio processing, including:

## Citation

* a diffusers pipeline that performs prompt interpolation combined with image conditioning
* a module for (approximately) converting between spectrograms and waveforms
* a flask server to provide model inference via API to the next.js app
* a model template, `baseten.py`, for deploying as a Truss

If you build on this work, please cite it as follows:

```
@article{Forsgren_Martiros_2022,
  author = {Forsgren, Seth* and Martiros, Hayk*},
  title = {{Riffusion - Stable diffusion for real-time music generation}},
  url = {https://riffusion.com/about},
  year = {2022}
}
```

## Install

Tested with Python 3.9 and diffusers 0.9.0.
Tested with Python 3.9 and 3.10, and diffusers 0.9.0.

To run this model, you need a GPU with CUDA. To run it in real time, it needs to be able to run stable diffusion with approximately 50 steps in under five seconds.

You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
To run this model in real time, you need a GPU that can run stable diffusion with approximately 50
steps in under five seconds. A 3090 or A10G will do it.
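
If you want a rough sense of whether a GPU meets that budget, a minimal timing sketch along the following lines can help. It uses the diffusers API and the riffusion-model-v1 checkpoint listed above; the prompt is illustrative, and it assumes you have already completed the install steps below.

```python3
import time

import torch
from diffusers import StableDiffusionPipeline

# Load the public riffusion checkpoint in diffusers format (linked above).
pipe = StableDiffusionPipeline.from_pretrained(
    "riffusion/riffusion-model-v1",
    torch_dtype=torch.float16,
).to("cuda")

# Time one 50-step generation at the 512x512 spectrogram resolution.
start = time.time()
pipe("funky synth solo", num_inference_steps=50, width=512, height=512)
elapsed = time.time() - start

print(f"50 steps took {elapsed:.1f} s (aim for under ~5 s for real-time use)")
```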

Install in a virtual Python environment:
```
conda create --name riffusion-inference python=3.9
conda activate riffusion-inference
conda create --name riffusion python=3.9
conda activate riffusion
python -m pip install -r requirements.txt
```

If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
If torchaudio has no audio backend, see
[this issue](https://github.com/riffusion/riffusion/issues/12).

You can open and save WAV files with pure python. For opening and saving non-wav files – like mp3 – you'll need ffmpeg or libav.
You can open and save WAV files with pure python. For opening and saving non-wav files – like mp3 –
you'll need to install ffmpeg with `sudo apt-get install ffmpeg` or `brew install ffmpeg`.

Guides:
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)

## Backends

#### CUDA

`cuda` is the recommended and most performant backend.

To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
[install guide](https://pytorch.org/get-started/locally/) or
[stable wheels](https://download.pytorch.org/whl/torch_stable.html). Check with:

```python3
import torch
torch.cuda.is_available()
```

Also see [this issue](https://github.com/riffusion/riffusion/issues/3) for help.

#### CPU

`cpu` works but is quite slow.

#### MPS

The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
particularly for audio processing. You may need to set
`PYTORCH_ENABLE_MPS_FALLBACK=1`.

In addition, this backend is not deterministic.
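
A minimal sketch of how you might enable the fallback and pick a device in your own scripts (exporting the variable in your shell before launching works just as well):

```python3
import os

# Set the fallback flag before importing torch, so unsupported MPS ops run on CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# Prefer MPS when it is available, otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device: {device}")
```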

## Command-line interface

Riffusion comes with a command line interface for performing common tasks.

See available commands:
```
python -m riffusion.cli -h
```

Get help for a specific command:
```
python -m riffusion.cli image-to-audio -h
```

Execute:
```
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
```

## Streamlit playground

Riffusion also has a streamlit app for interactive use and exploration.
This app is called the Riffusion Playground.

Run with:
```
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
```

And access at http://127.0.0.1:8501/

## Run the model server
Start the Flask server:

Riffusion can be run as a flask server that provides inference via API. Run with:

```
python -m riffusion.server --host 127.0.0.1 --port 3013
```

You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.

Use the `--device` argument to specify the torch device to use.

The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
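
One way a client might call it is with Python's `requests` library. The payload below is only a sketch; the field names are my assumptions, so check the InferenceInput definition linked below for the actual schema:

```python3
import requests

# Indicative payload only; see riffusion/datatypes.py for the real InferenceInput schema.
payload = {
    "alpha": 0.5,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload)
response.raise_for_status()

# Inspect whatever fields InferenceOutput defines (see the datatypes module).
print(response.json().keys())
```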

Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):

@@ -79,15 +158,6 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
}
```

Use the `--device` argument to specify the torch device to use.

`cuda` is recommended.

`cpu` works but is quite slow.

`mps` is supported for inference, but some operations fall back to CPU. You may need to set
PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.

## Test
Tests live in the `test/` directory and are implemented with `unittest`.

@@ -125,15 +195,6 @@ These are configured in `pyproject.toml`.

The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.

## Citation
CI is run through GitHub Actions from `.github/workflows/ci.yml`.

If you build on this work, please cite it as follows:

```
@article{Forsgren_Martiros_2022,
  author = {Forsgren, Seth* and Martiros, Hayk*},
  title = {{Riffusion - Stable diffusion for real-time music generation}},
  url = {https://riffusion.com/about},
  year = {2022}
}
```
Contributions are welcome via pull requests.

@@ -12,11 +12,15 @@ def render_text_to_audio() -> None:
    """
    prompt = st.text_input("Prompt")
    negative_prompt = st.text_input("Negative prompt")
    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
    num_inference_steps = T.cast(int, st.sidebar.number_input("Inference steps", value=50))
    width = T.cast(int, st.sidebar.number_input("Width", value=512))
    height = T.cast(int, st.sidebar.number_input("Height", value=512))
    guidance = st.sidebar.number_input(

    device = streamlit_util.select_device(st.sidebar)

    with st.sidebar.expander("Text to Audio Params", expanded=True):
        seed = T.cast(int, st.number_input("Seed", value=42))
        num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
        width = T.cast(int, st.number_input("Width", value=512))
        height = T.cast(int, st.number_input("Height", value=512))
        guidance = st.number_input(
            "Guidance", value=7.0, help="How much the model listens to the text prompt"
        )

@@ -24,8 +28,6 @@ def render_text_to_audio() -> None:
        st.info("Enter a prompt")
        return

    device = streamlit_util.select_device(st.sidebar)

    image = streamlit_util.run_txt2img(
        prompt=prompt,
        num_inference_steps=num_inference_steps,

@@ -0,0 +1,97 @@
import json
import typing as T
from pathlib import Path

import streamlit as st

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def render_text_to_audio_batch() -> None:
    """
    Render audio from text in batches, reading from a text file.
    """
    json_file = st.file_uploader("JSON file", type=["json"])
    if json_file is None:
        st.info("Upload a JSON file of prompts")
        return

    data = json.loads(json_file.read())

    with st.expander("Data", expanded=False):
        st.json(data)

    params = data["params"]
    entries = data["entries"]

    device = streamlit_util.select_device(st.sidebar)

    show_images = st.sidebar.checkbox("Show Images", True)

    # Optionally specify an output directory
    output_dir = st.sidebar.text_input("Output Directory", "")
    output_path: T.Optional[Path] = None
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    for i, entry in enumerate(entries):
        st.write(f"### Entry {i + 1} / {len(entries)}")

        st.write(f"Prompt: {entry['prompt']}")

        negative_prompt = entry.get("negative_prompt", None)
        st.write(f"Negative prompt: {negative_prompt}")

        image = streamlit_util.run_txt2img(
            prompt=entry["prompt"],
            negative_prompt=negative_prompt,
            seed=params["seed"],
            num_inference_steps=params["num_inference_steps"],
            guidance=params["guidance"],
            width=params["width"],
            height=params["height"],
            device=device,
        )

        if show_images:
            st.image(image)

        # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
        p_spectrogram = SpectrogramParams(
            min_frequency=0,
            max_frequency=10000,
        )

        output_format = "mp3"
        audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
            image=image,
            params=p_spectrogram,
            device=device,
            output_format=output_format,
        )
        st.audio(audio_bytes)

        if output_path:
            prompt_slug = entry["prompt"].replace(" ", "_")
            negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")

            image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
            image.save(image_path, format="JPEG")
            entry["image_path"] = str(image_path)

            audio_path = (
                output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
            )
            audio_path.write_bytes(audio_bytes.getbuffer())
            entry["audio_path"] = str(audio_path)

    if output_path:
        output_json_path = output_path / "index.json"
        output_json_path.write_text(json.dumps(data, indent=4))
        st.info(f"Output written to {str(output_path)}")


if __name__ == "__main__":
    render_text_to_audio_batch()
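
For reference, the batch page above reads a JSON file with a shared `params` block and a list of `entries`, where each entry has a `prompt` and an optional `negative_prompt`. A small sketch that writes such a file (prompt text and file name are just examples):

```python3
import json

# Hypothetical input for the batch page, matching the keys read by the script above.
data = {
    "params": {
        "seed": 42,
        "num_inference_steps": 50,
        "guidance": 7.0,
        "width": 512,
        "height": 512,
    },
    "entries": [
        {"prompt": "acoustic folk fiddle solo"},
        {"prompt": "techno club banger", "negative_prompt": "vocals"},
    ],
}

with open("prompts.json", "w") as f:
    json.dump(data, f, indent=4)
```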