Add batch text to audio processing

Topic: streamlit_app
This commit is contained in:
Hayk Martiros 2022-12-26 21:32:42 -08:00
parent 420674148a
commit 1335afb72f
3 changed files with 206 additions and 46 deletions

README.md

@@ -1,51 +1,130 @@
# Riffusion
-Riffusion is a technique for real-time music and audio generation with stable diffusion.
+Riffusion is a library for real-time music and audio generation with stable diffusion.
Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.
-* Inference server: https://github.com/riffusion/riffusion
+This repository contains the core riffusion image and audio processing code and supporting apps,
+including:
+* diffusion pipeline that performs prompt interpolation combined with image conditioning
+* package for (approximately) converting between spectrogram images and audio clips
+* interactive playground using streamlit
+* command-line tool for common tasks
+* flask server to provide model inference via API
+* various third party integrations
+* test suite
+Related repositories:
* Web app: https://github.com/riffusion/riffusion-app
* Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1
-This repository contains the Python backend that does the model inference and audio processing, including:
-* a diffusers pipeline that performs prompt interpolation combined with image conditioning
-* a module for (approximately) converting between spectrograms and waveforms
-* a flask server to provide model inference via API to the next.js app
-* a model template titled baseten.py for deploying as a Truss
-## Citation
-If you build on this work, please cite it as follows:
-```
-@article{Forsgren_Martiros_2022,
-    author = {Forsgren, Seth* and Martiros, Hayk*},
-    title = {{Riffusion - Stable diffusion for real-time music generation}},
-    url = {https://riffusion.com/about},
-    year = {2022}
-}
-```
## Install
-Tested with Python 3.9 and diffusers 0.9.0.
+Tested with Python 3.9 + 3.10 and diffusers 0.9.0.
-To run this model, you need a GPU with CUDA. To run it in real time, it needs to be able to run stable diffusion with approximately 50 steps in under five seconds.
-You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
+To run this model in real time, you need a GPU that can run stable diffusion with approximately 50
+steps in under five seconds. A 3090 or A10G will do it.
Install in a virtual Python environment:
```
-conda create --name riffusion-inference python=3.9
-conda activate riffusion-inference
+conda create --name riffusion python=3.9
+conda activate riffusion
python -m pip install -r requirements.txt
```
-If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
+If torchaudio has no audio backend, see
+[this issue](https://github.com/riffusion/riffusion/issues/12).
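A quick way to check for a backend (a minimal sketch; torchaudio's backend API has shifted across versions):
```python3
import torchaudio

# An empty list means torchaudio has no audio backend installed
print(torchaudio.list_audio_backends())
```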
-You can open and save WAV files with pure python. For opening and saving non-wav files like mp3 you'll need ffmpeg or libav.
+You can open and save WAV files with pure python. For opening and saving non-wav files like mp3
+you'll need to install ffmpeg with `sudo apt-get install ffmpeg` or `brew install ffmpeg`.
Guides:
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
## Backends
#### CUDA
`cuda` is the recommended and most performant backend.
To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
[install guide](https://pytorch.org/get-started/locally/) or
[stable wheels](https://download.pytorch.org/whl/torch_stable.html). Check with:
```python3
import torch
torch.cuda.is_available()
```
Also see [this issue](https://github.com/riffusion/riffusion/issues/3) for help.
#### CPU
`cpu` works but is quite slow.
#### MPS
The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
particularly for audio processing. You may need to set
PYTORCH_ENABLE_MPS_FALLBACK=1.
In addition, this backend is not deterministic.
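As a sketch of selecting this backend (an assumption about your setup; the fallback variable must be set before torch is imported):
```python3
import os

# Assumption: allow ops unsupported on MPS to fall back to CPU.
# This must be set before torch is imported.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# Prefer MPS on Apple Silicon, otherwise fall back to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
```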
## Command-line interface
Riffusion comes with a command line interface for performing common tasks.
See available commands:
```
python -m riffusion.cli -h
```
Get help for a specific command:
```
python -m riffusion.cli image-to-audio -h
```
Execute:
```
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
```
## Streamlit playground
Riffusion also has a streamlit app, the Riffusion Playground, for interactive use and exploration.
Run with:
```
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
```
And access at http://127.0.0.1:8501/
## Run the model server
-Start the Flask server:
+Riffusion can be run as a flask server that provides inference via API. Run with:
```
python -m riffusion.server --host 127.0.0.1 --port 3013
```
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
Use the `--device` argument to specify the torch device to use.
The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
@@ -79,15 +158,6 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
}
```
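As a minimal sketch of calling this endpoint from Python (the payload values here are illustrative; see InferenceInput in `riffusion/datatypes.py` for the authoritative fields):
```python3
import requests

# Illustrative payload following the InferenceInput schema
payload = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload)
response.raise_for_status()
output = response.json()  # see InferenceOutput for the returned fields
```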
-Use the `--device` argument to specify the torch device to use.
-`cuda` is recommended.
-`cpu` works but is quite slow.
-`mps` is supported for inference, but some operations fall back to CPU. You may need to set
-PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.
## Test
Tests live in the `test/` directory and are implemented with `unittest`.
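For example, to run the full suite (assuming the `*_test.py` naming convention in `test/`):
```
python -m unittest discover -s test -p '*_test.py'
```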
@@ -125,15 +195,6 @@ These are configured in `pyproject.toml`.
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
CI is run through GitHub Actions from `.github/workflows/ci.yml`.
Contributions are welcome through opening pull requests.
+## Citation
+If you build on this work, please cite it as follows:
+```
+@article{Forsgren_Martiros_2022,
+    author = {Forsgren, Seth* and Martiros, Hayk*},
+    title = {{Riffusion - Stable diffusion for real-time music generation}},
+    url = {https://riffusion.com/about},
+    year = {2022}
+}
+```

riffusion/streamlit/pages/text_to_audio.py

@@ -12,11 +12,15 @@ def render_text_to_audio() -> None:
"""
prompt = st.text_input("Prompt")
negative_prompt = st.text_input("Negative prompt")
seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
num_inference_steps = T.cast(int, st.sidebar.number_input("Inference steps", value=50))
width = T.cast(int, st.sidebar.number_input("Width", value=512))
height = T.cast(int, st.sidebar.number_input("Height", value=512))
guidance = st.sidebar.number_input(
device = streamlit_util.select_device(st.sidebar)
with st.sidebar.expander("Text to Audio Params", expanded=True):
seed = T.cast(int, st.number_input("Seed", value=42))
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
width = T.cast(int, st.number_input("Width", value=512))
height = T.cast(int, st.number_input("Height", value=512))
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
)
@@ -24,8 +28,6 @@ def render_text_to_audio() -> None:
st.info("Enter a prompt")
return
device = streamlit_util.select_device(st.sidebar)
image = streamlit_util.run_txt2img(
prompt=prompt,
num_inference_steps=num_inference_steps,

riffusion/streamlit/pages/text_to_audio_batch.py

@@ -0,0 +1,97 @@
import json
import typing as T
from pathlib import Path
import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_text_to_audio_batch() -> None:
    """
    Render audio from text in batches, reading entries from a JSON file.
    """
    json_file = st.file_uploader("JSON file", type=["json"])
    if json_file is None:
        st.info("Upload a JSON file of prompts")
        return

    data = json.loads(json_file.read())

    with st.expander("Data", expanded=False):
        st.json(data)
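    # Assumed input format, e.g.:
    # {
    #   "params": {"seed": 42, "num_inference_steps": 50, "guidance": 7.0,
    #              "width": 512, "height": 512},
    #   "entries": [
    #     {"prompt": "jazz with piano"},
    #     {"prompt": "church bells", "negative_prompt": "drums"}
    #   ]
    # }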
    params = data["params"]
    entries = data["entries"]

    device = streamlit_util.select_device(st.sidebar)
    show_images = st.sidebar.checkbox("Show Images", True)

    # Optionally specify an output directory
    output_dir = st.sidebar.text_input("Output Directory", "")
    output_path: T.Optional[Path] = None
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
    for i, entry in enumerate(entries):
        st.write(f"### Entry {i + 1} / {len(entries)}")
        st.write(f"Prompt: {entry['prompt']}")

        negative_prompt = entry.get("negative_prompt", None)
        st.write(f"Negative prompt: {negative_prompt}")

        image = streamlit_util.run_txt2img(
            prompt=entry["prompt"],
            negative_prompt=negative_prompt,
            seed=params["seed"],
            num_inference_steps=params["num_inference_steps"],
            guidance=params["guidance"],
            width=params["width"],
            height=params["height"],
            device=device,
        )

        if show_images:
            st.image(image)
        # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
        p_spectrogram = SpectrogramParams(
            min_frequency=0,
            max_frequency=10000,
        )

        output_format = "mp3"
        audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
            image=image,
            params=p_spectrogram,
            device=device,
            output_format=output_format,
        )
        st.audio(audio_bytes)
        if output_path:
            prompt_slug = entry["prompt"].replace(" ", "_")
            negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")

            image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
            image.save(image_path, format="JPEG")
            entry["image_path"] = str(image_path)

            audio_path = (
                output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
            )
            audio_path.write_bytes(audio_bytes.getbuffer())
            entry["audio_path"] = str(audio_path)

    if output_path:
        output_json_path = output_path / "index.json"
        output_json_path.write_text(json.dumps(data, indent=4))
        st.info(f"Output written to {str(output_path)}")
if __name__ == "__main__":
    render_text_to_audio_batch()