diff --git a/README.md b/README.md
index 30864df..e07fb3c 100644
--- a/README.md
+++ b/README.md
@@ -1,51 +1,130 @@
 # Riffusion
-Riffusion is a technique for real-time music and audio generation with stable diffusion.
+Riffusion is a library for real-time music and audio generation with stable diffusion.
 
 Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.
 
-* Inference server: https://github.com/riffusion/riffusion
+This repository contains the core riffusion image and audio processing code and supporting apps,
+including:
+
+ * diffusion pipeline that performs prompt interpolation combined with image conditioning
+ * package for (approximately) converting between spectrogram images and audio clips
+ * interactive playground using streamlit
+ * command-line tool for common tasks
+ * flask server to provide model inference via API
+ * various third party integrations
+ * test suite
+
+Related repositories:
 * Web app: https://github.com/riffusion/riffusion-app
 * Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1
 
-This repository contains the Python backend does the model inference and audio processing, including:
+## Citation
 
- * a diffusers pipeline that performs prompt interpolation combined with image conditioning
- * a module for (approximately) converting between spectrograms and waveforms
- * a flask server to provide model inference via API to the next.js app
- * a model template titled baseten.py for deploying as a Truss
+If you build on this work, please cite it as follows:
 
+```
+@article{Forsgren_Martiros_2022,
+  author = {Forsgren, Seth* and Martiros, Hayk*},
+  title = {{Riffusion - Stable diffusion for real-time music generation}},
+  url = {https://riffusion.com/about},
+  year = {2022}
+}
+```
+
 ## Install
 
-Tested with Python 3.9 and diffusers 0.9.0.
+Tested with Python 3.9 + 3.10 and diffusers 0.9.0.
 
-To run this model, you need a GPU with CUDA. To run it in real time, it needs to be able to run stable diffusion with approximately 50 steps in under five seconds.
-
-You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
+To run this model in real time, you need a GPU that can run stable diffusion with approximately 50
+steps in under five seconds. A 3090 or A10G will do it.
 
+Install in a virtual Python environment:
 ```
-conda create --name riffusion-inference python=3.9
-conda activate riffusion-inference
+conda create --name riffusion python=3.9
+conda activate riffusion
 python -m pip install -r requirements.txt
 ```
 
-If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
+If torchaudio has no audio backend, see
+[this issue](https://github.com/riffusion/riffusion/issues/12).
 
-You can open and save WAV files with pure python. For opening and saving non-wav files – like mp3 – you'll need ffmpeg or libav.
+You can open and save WAV files with pure python. For opening and saving non-wav files – like mp3 –
+you'll need to install ffmpeg with `sudo apt-get install ffmpeg` or `brew install ffmpeg`.
 
 Guides:
-* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
 * [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
 
+## Backends
+
+#### CUDA
+`cuda` is the recommended and most performant backend.
+
+To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
+[install guide](https://pytorch.org/get-started/locally/) or
+[stable wheels](https://download.pytorch.org/whl/torch_stable.html). Check with:
+
+```python3
+import torch
+torch.cuda.is_available()
+```
+
+Also see [this issue](https://github.com/riffusion/riffusion/issues/3) for help.
+
+#### CPU
+`cpu` works but is quite slow.
+
+#### MPS
+The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
+particularly for audio processing. You may need to set
+PYTORCH_ENABLE_MPS_FALLBACK=1.
+
+In addition, this backend is not deterministic.
+
+## Command-line interface
+
+Riffusion comes with a command line interface for performing common tasks.
+
+See available commands:
+```
+python -m riffusion.cli -h
+```
+
+Get help for a specific command:
+```
+python -m riffusion.cli image-to-audio -h
+```
+
+Execute:
+```
+python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
+```
+
+## Streamlit playground
+
+Riffusion also has a streamlit app for interactive use and exploration.
+This app is called the Riffusion Playground.
+
+Run with:
+```
+python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
+```
+
+And access at http://127.0.0.1:8501/
+
 ## Run the model server
-Start the Flask server:
+
+Riffusion can be run as a flask server that provides inference via API. Run with:
+
 ```
 python -m riffusion.server --host 127.0.0.1 --port 3013
 ```
 
 You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
 
+Use the `--device` argument to specify the torch device to use.
+
 The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
 
 Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
@@ -79,15 +158,6 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
 }
 ```
 
-Use the `--device` argument to specify the torch device to use.
-
-`cuda` is recommended.
-
-`cpu` works but is quite slow.
-
-`mps` is supported for inference, but some operations fall back to CPU. You may need to set
-PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.
-
 ## Test
 
 Tests live in the `test/` directory and are implemented with `unittest`.
@@ -125,15 +195,6 @@ These are configured in `pyproject.toml`.
 
 The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
 
-## Citation
+CI is run through GitHub Actions from `.github/workflows/ci.yml`.
 
-If you build on this work, please cite it as follows:
-
-```
-@article{Forsgren_Martiros_2022,
-  author = {Forsgren, Seth* and Martiros, Hayk*},
-  title = {{Riffusion - Stable diffusion for real-time music generation}},
-  url = {https://riffusion.com/about},
-  year = {2022}
-}
-```
+Contributions are welcome through opening pull requests.
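The `## Run the model server` section added above describes POSTing to `/run_inference`. Below is a minimal sketch of calling that endpoint, assuming the server is already running locally on port 3013, the `requests` package is installed, and `input.json` is a hypothetical file containing a payload that matches `InferenceInput` in `riffusion/datatypes.py`:

```python
# Illustrative sketch, not part of the diff: POST a request to the local model server.
# Assumes `python -m riffusion.server --host 127.0.0.1 --port 3013` is already running,
# `requests` is installed, and input.json (hypothetical) matches InferenceInput.
import json

import requests

with open("input.json") as f:
    payload = json.load(f)

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload)
response.raise_for_status()

# The response body is JSON matching InferenceOutput in riffusion/datatypes.py.
output = response.json()
print(sorted(output.keys()))
```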
diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py
index c696085..7d328fb 100644
--- a/riffusion/streamlit/pages/text_to_audio.py
+++ b/riffusion/streamlit/pages/text_to_audio.py
@@ -12,20 +12,22 @@ def render_text_to_audio() -> None:
     """
     prompt = st.text_input("Prompt")
     negative_prompt = st.text_input("Negative prompt")
-    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
-    num_inference_steps = T.cast(int, st.sidebar.number_input("Inference steps", value=50))
-    width = T.cast(int, st.sidebar.number_input("Width", value=512))
-    height = T.cast(int, st.sidebar.number_input("Height", value=512))
-    guidance = st.sidebar.number_input(
-        "Guidance", value=7.0, help="How much the model listens to the text prompt"
-    )
+
+    device = streamlit_util.select_device(st.sidebar)
+
+    with st.sidebar.expander("Text to Audio Params", expanded=True):
+        seed = T.cast(int, st.number_input("Seed", value=42))
+        num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
+        width = T.cast(int, st.number_input("Width", value=512))
+        height = T.cast(int, st.number_input("Height", value=512))
+        guidance = st.number_input(
+            "Guidance", value=7.0, help="How much the model listens to the text prompt"
+        )
 
     if not prompt:
         st.info("Enter a prompt")
         return
 
-    device = streamlit_util.select_device(st.sidebar)
-
     image = streamlit_util.run_txt2img(
         prompt=prompt,
         num_inference_steps=num_inference_steps,
diff --git a/riffusion/streamlit/pages/text_to_audio_batch.py b/riffusion/streamlit/pages/text_to_audio_batch.py
new file mode 100644
index 0000000..139df62
--- /dev/null
+++ b/riffusion/streamlit/pages/text_to_audio_batch.py
@@ -0,0 +1,97 @@
+import json
+import typing as T
+from pathlib import Path
+
+import streamlit as st
+
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.streamlit import util as streamlit_util
+
+
+def render_text_to_audio_batch() -> None:
+    """
+    Render audio from text in batches, reading from a JSON file.
+ """ + json_file = st.file_uploader("JSON file", type=["json"]) + if json_file is None: + st.info("Upload a JSON file of prompts") + return + + data = json.loads(json_file.read()) + + with st.expander("Data", expanded=False): + st.json(data) + + params = data["params"] + entries = data["entries"] + + device = streamlit_util.select_device(st.sidebar) + + show_images = st.sidebar.checkbox("Show Images", True) + + # Optionally specify an output directory + output_dir = st.sidebar.text_input("Output Directory", "") + output_path: T.Optional[Path] = None + if output_dir: + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + for i, entry in enumerate(entries): + st.write(f"### Entry {i + 1} / {len(entries)}") + + st.write(f"Prompt: {entry['prompt']}") + + negative_prompt = entry.get("negative_prompt", None) + st.write(f"Negative prompt: {negative_prompt}") + + image = streamlit_util.run_txt2img( + prompt=entry["prompt"], + negative_prompt=negative_prompt, + seed=params["seed"], + num_inference_steps=params["num_inference_steps"], + guidance=params["guidance"], + width=params["width"], + height=params["height"], + device=device, + ) + + if show_images: + st.image(image) + + # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained + p_spectrogram = SpectrogramParams( + min_frequency=0, + max_frequency=10000, + ) + + output_format = "mp3" + audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image( + image=image, + params=p_spectrogram, + device=device, + output_format=output_format, + ) + st.audio(audio_bytes) + + if output_path: + prompt_slug = entry["prompt"].replace(" ", "_") + negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_") + + image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg" + image.save(image_path, format="JPEG") + entry["image_path"] = str(image_path) + + audio_path = ( + output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}" + ) + audio_path.write_bytes(audio_bytes.getbuffer()) + entry["audio_path"] = str(audio_path) + + if output_path: + output_json_path = output_path / "index.json" + output_json_path.write_text(json.dumps(data, indent=4)) + st.info(f"Output written to {str(output_path)}") + + +if __name__ == "__main__": + render_text_to_audio_batch()