Merge pull request #40 from riffusion/hayk.mart/revup/main/streamlit_app
Streamlit app for interactive use of the model

This commit is contained in commit 7b55a966ba.

README.md

@@ -1,51 +1,130 @@

# Riffusion

Riffusion is a library for real-time music and audio generation with stable diffusion.

Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.

This repository contains the core riffusion image and audio processing code and supporting apps,
including:

* diffusion pipeline that performs prompt interpolation combined with image conditioning
* package for (approximately) converting between spectrogram images and audio clips
* interactive playground using streamlit
* command-line tool for common tasks
* flask server to provide model inference via API
* various third party integrations
* test suite

Related repositories:
* Web app: https://github.com/riffusion/riffusion-app
* Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1

## Citation

If you build on this work, please cite it as follows:

```
@article{Forsgren_Martiros_2022,
  author = {Forsgren, Seth* and Martiros, Hayk*},
  title = {{Riffusion - Stable diffusion for real-time music generation}},
  url = {https://riffusion.com/about},
  year = {2022}
}
```

## Install

Tested with Python 3.9 + 3.10 and diffusers 0.9.0.

To run this model in real time, you need a GPU that can run stable diffusion with approximately 50
steps in under five seconds. A 3090 or A10G will do it.

Install in a virtual Python environment:
```
conda create --name riffusion python=3.9
conda activate riffusion
python -m pip install -r requirements.txt
```

If torchaudio has no audio backend, see
[this issue](https://github.com/riffusion/riffusion/issues/12).

You can open and save WAV files with pure python. For opening and saving non-wav files – like mp3 –
you'll need to install ffmpeg with `sudo apt-get install ffmpeg` or `brew install ffmpeg`.
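
A quick way to check which audio backends torchaudio can see (illustrative; assumes torchaudio is
already installed in the environment):

```python3
import torchaudio

# An empty list means no audio backend is available and loading/saving audio will fail.
print(torchaudio.list_audio_backends())
```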

Guides:
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)

## Backends

#### CUDA
`cuda` is the recommended and most performant backend.

To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
[install guide](https://pytorch.org/get-started/locally/) or
[stable wheels](https://download.pytorch.org/whl/torch_stable.html). Check with:

```python3
import torch
torch.cuda.is_available()
```

Also see [this issue](https://github.com/riffusion/riffusion/issues/3) for help.

#### CPU
`cpu` works but is quite slow.

#### MPS
The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
particularly for audio processing. You may need to set `PYTORCH_ENABLE_MPS_FALLBACK=1`.

In addition, this backend is not deterministic.
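
For example, one way to launch the inference server on Apple Silicon with the fallback enabled
(illustrative; `--device` is described under "Run the model server" below):

```
export PYTORCH_ENABLE_MPS_FALLBACK=1
python -m riffusion.server --host 127.0.0.1 --port 3013 --device mps
```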
|
||||
|
||||
## Command-line interface
|
||||
|
||||
Riffusion comes with a command line interface for performing common tasks.
|
||||
|
||||
See available commands:
|
||||
```
|
||||
python -m riffusion-cli -h
|
||||
```
|
||||
|
||||
Get help for a specific command:
|
||||
```
|
||||
python -m riffusion.cli image-to-audio -h
|
||||
```
|
||||
|
||||
Execute:
|
||||
```
|
||||
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
|
||||
```

## Streamlit playground

Riffusion also has a streamlit app for interactive use and exploration.
This app is called the Riffusion Playground.

Run with:
```
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
```

And access at http://127.0.0.1:8501/

## Run the model server

Riffusion can be run as a flask server that provides inference via API. Run with:

```
python -m riffusion.server --host 127.0.0.1 --port 3013
```

You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.

Use the `--device` argument to specify the torch device to use.

The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.

Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):

@@ -79,15 +158,6 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
}
```
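
For reference, a minimal sketch of calling the endpoint from Python. The field names mirror
`InferenceInput` and `PromptInput` as used by the interpolation app later in this diff; the values
and the `requests` dependency are illustrative, not part of this change:

```python3
import requests

# Illustrative payload; see riffusion/datatypes.py for the authoritative schema.
payload = {
    "alpha": 0.5,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz guitar", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload)
response.raise_for_status()
print(response.json().keys())
```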

## Test
Tests live in the `test/` directory and are implemented with `unittest`.
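
One possible way to run the full suite (illustrative; the elided lines of the README may document
the project's preferred invocation):

```
python -m unittest discover -s test -p '*_test.py'
```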

@@ -106,7 +176,7 @@ To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:

```
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
```

To run a single test case within a test:
```
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
```

@@ -125,15 +195,6 @@ These are configured in `pyproject.toml`.

The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
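
These checks can be run locally before opening a PR:

```
mypy .
black .
ruff .
```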

CI is run through GitHub Actions from `.github/workflows/ci.yml`.

Contributions are welcome through opening pull requests.

@@ -116,9 +116,6 @@ def sample_clips(
    if not output_dir_path.exists():
        output_dir_path.mkdir(parents=True)

    # TODO(hayk): Might be a lot easier with pydub
    # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file

    segment_duration_ms = int(segment.duration_seconds * 1000)
    for i in range(num_clips):
        clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)

@@ -0,0 +1,3 @@

# streamlit

This package is an interactive streamlit app for riffusion.
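
It is launched through the playground entrypoint, as documented in the top-level README:

```
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
```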

@@ -0,0 +1,58 @@

import dataclasses

import streamlit as st
from PIL import Image

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.util.image_util import exif_from_image


def render_image_to_audio() -> None:
    st.set_page_config(layout="wide", page_icon="🎸")

    st.subheader(":musical_keyboard: Image to Audio")
    st.write(
        """
        Reconstruct audio from spectrogram images.
        """
    )

    device = streamlit_util.select_device(st.sidebar)

    image_file = st.file_uploader(
        "Upload a file",
        type=["png", "jpg", "jpeg"],
        label_visibility="collapsed",
    )
    if not image_file:
        st.info("Upload an image file to get started")
        return

    image = Image.open(image_file)
    st.image(image)

    with st.expander("Image metadata", expanded=False):
        exif = exif_from_image(image)
        st.json(exif)

    try:
        params = SpectrogramParams.from_exif(exif=image.getexif())
    except KeyError:
        st.info("Could not find spectrogram parameters in exif data. Using defaults.")
        params = SpectrogramParams()

    with st.expander("Spectrogram Parameters", expanded=False):
        st.json(dataclasses.asdict(params))

    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image.copy(),
        params=params,
        device=device,
        output_format="mp3",
    )
    st.audio(audio_bytes)


if __name__ == "__main__":
    render_image_to_audio()

@@ -0,0 +1,197 @@

import dataclasses
import io
import typing as T
from pathlib import Path

import numpy as np
import pydub
import streamlit as st
from PIL import Image

from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def render_interpolation_demo() -> None:
    st.set_page_config(layout="wide", page_icon="🎸")

    st.subheader(":performing_arts: Interpolation")
    st.write(
        """
        Interpolate between prompts in the latent space.
        """
    )

    # Sidebar params

    device = streamlit_util.select_device(st.sidebar)

    num_interpolation_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Interpolation steps",
            value=4,
            min_value=1,
            max_value=20,
            help="Number of model generations between the two prompts. Controls the duration.",
        ),
    )

    num_inference_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Steps per sample", value=50, help="Number of denoising steps per model run"
        ),
    )

    init_image_name = st.sidebar.selectbox(
        "Seed image",
        # TODO(hayk): Read from directory
        options=["og_beat", "agile", "marim", "motorway", "vibes"],
        index=0,
        help="Which seed image to use for img2img",
    )
    assert init_image_name is not None

    show_individual_outputs = st.sidebar.checkbox(
        "Show individual outputs",
        value=False,
        help="Show each model output",
    )
    show_images = st.sidebar.checkbox(
        "Show individual images",
        value=False,
        help="Show each generated image",
    )

    # Prompt inputs A and B in two columns

    left, right = st.columns(2)

    with left.expander("Input A", expanded=True):
        prompt_input_a = get_prompt_inputs(key="a")

    with right.expander("Input B", expanded=True):
        prompt_input_b = get_prompt_inputs(key="b")

    if not prompt_input_a.prompt or not prompt_input_b.prompt:
        st.info("Enter both prompts to interpolate between them")
        return

    alphas = list(np.linspace(0, 1, num_interpolation_steps))
    alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
    st.write(f"**Alphas** : [{alphas_str}]")

    # TODO(hayk): Upload your own seed image.

    init_image_path = (
        Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
    )
    init_image = Image.open(str(init_image_path)).convert("RGB")

    # TODO(hayk): Move this code into a shared place and add to riffusion.cli
    image_list: T.List[Image.Image] = []
    audio_bytes_list: T.List[io.BytesIO] = []
    for i, alpha in enumerate(alphas):
        inputs = InferenceInput(
            alpha=float(alpha),
            num_inference_steps=num_inference_steps,
            seed_image_id="og_beat",
            start=prompt_input_a,
            end=prompt_input_b,
        )

        if i == 0:
            with st.expander("Example input JSON", expanded=False):
                st.json(dataclasses.asdict(inputs))

        image, audio_bytes = run_interpolation(
            inputs=inputs,
            init_image=init_image,
            device=device,
        )

        if show_individual_outputs:
            st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
            if show_images:
                st.image(image)
            st.audio(audio_bytes)

        image_list.append(image)
        audio_bytes_list.append(audio_bytes)

    st.write("#### Final Output")

    # TODO(hayk): Concatenate with better blending
    audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
    concat_segment = audio_segments[0]
    for segment in audio_segments[1:]:
        concat_segment = concat_segment.append(segment, crossfade=0)

    audio_bytes = io.BytesIO()
    concat_segment.export(audio_bytes, format="mp3")
    audio_bytes.seek(0)

    st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
    st.audio(audio_bytes)


def get_prompt_inputs(key: str) -> PromptInput:
    """
    Compute prompt inputs from widgets.
    """
    prompt = st.text_input("Prompt", label_visibility="collapsed", key=f"prompt_{key}")
    seed = T.cast(int, st.number_input("Seed", value=42, key=f"seed_{key}"))
    denoising = st.number_input(
        "Denoising", value=0.75, key=f"denoising_{key}", help="How much to modify the seed image"
    )
    guidance = st.number_input(
        "Guidance",
        value=7.0,
        key=f"guidance_{key}",
        help="How much the model listens to the text prompt",
    )

    return PromptInput(
        prompt=prompt,
        seed=seed,
        denoising=denoising,
        guidance=guidance,
    )


@st.experimental_memo
def run_interpolation(
    inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
) -> T.Tuple[Image.Image, io.BytesIO]:
    """
    Cached function for riffusion interpolation.
    """
    pipeline = streamlit_util.load_riffusion_checkpoint(device=device)

    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=None,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct from image to audio
    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
        output_format="mp3",
    )

    return image, audio_bytes


if __name__ == "__main__":
    render_interpolation_demo()

@@ -0,0 +1,85 @@

import tempfile
import typing as T
from pathlib import Path

import numpy as np
import pydub
import streamlit as st


def render_sample_clips() -> None:
    st.set_page_config(layout="wide", page_icon="🎸")

    st.subheader(":scissors: Sample Clips")
    st.write(
        """
        Export short clips from an audio file.
        """
    )

    audio_file = st.file_uploader(
        "Upload a file",
        type=["wav", "mp3", "ogg"],
        label_visibility="collapsed",
    )
    if not audio_file:
        st.info("Upload an audio file to get started")
        return

    st.audio(audio_file)

    segment = pydub.AudioSegment.from_file(audio_file)
    st.write(
        "  \n".join(
            [
                f"**Duration**: {segment.duration_seconds:.3f} seconds",
                f"**Channels**: {segment.channels}",
                f"**Sample rate**: {segment.frame_rate} Hz",
                f"**Sample width**: {segment.sample_width} bytes",
            ]
        )
    )

    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
    duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
    export_as_mono = st.sidebar.checkbox("Export as Mono", False)
    num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
    extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
    assert extension is not None

    # Optionally specify an output directory
    output_dir = st.text_input("Output Directory")
    if not output_dir:
        tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
        st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
        return

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    if seed >= 0:
        np.random.seed(seed)

    if export_as_mono and segment.channels > 1:
        segment = segment.set_channels(1)

    # TODO(hayk): Share code with riffusion.cli.sample_clips.
    segment_duration_ms = int(segment.duration_seconds * 1000)
    for i in range(num_clips):
        clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
        clip = segment[clip_start_ms : clip_start_ms + duration_ms]

        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"

        st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")

        clip_path = output_path / clip_name
        clip.export(clip_path, format=extension)

        st.audio(str(clip_path))

    st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")


if __name__ == "__main__":
    render_sample_clips()

@@ -0,0 +1,66 @@

import typing as T

import streamlit as st

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def render_text_to_audio() -> None:
    st.set_page_config(layout="wide", page_icon="🎸")

    st.subheader(":pencil2: Text to Audio")
    st.write(
        """
        Generate audio from text prompts. \nRuns the model directly without a seed image or
        interpolation.
        """
    )

    device = streamlit_util.select_device(st.sidebar)

    prompt = st.text_input("Prompt")
    negative_prompt = st.text_input("Negative prompt")

    with st.sidebar.expander("Text to Audio Params", expanded=True):
        seed = T.cast(int, st.number_input("Seed", value=42))
        num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
        width = T.cast(int, st.number_input("Width", value=512))
        guidance = st.number_input(
            "Guidance", value=7.0, help="How much the model listens to the text prompt"
        )

    if not prompt:
        st.info("Enter a prompt")
        return

    image = streamlit_util.run_txt2img(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance=guidance,
        negative_prompt=negative_prompt,
        seed=seed,
        width=width,
        height=512,
        device=device,
    )

    st.image(image)

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
        output_format="mp3",
    )
    st.audio(audio_bytes)


if __name__ == "__main__":
    render_text_to_audio()

@@ -0,0 +1,138 @@

import json
import typing as T
from pathlib import Path

import streamlit as st

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util

# Example input json file to process in batch
EXAMPLE_INPUT = """
{
    "params": {
        "seed": 42,
        "num_inference_steps": 50,
        "guidance": 7.0,
        "width": 512
    },
    "entries": [
        {
            "prompt": "Church bells"
        },
        {
            "prompt": "electronic beats",
            "negative_prompt": "drums"
        },
        {
            "prompt": "classical violin concerto"
        }
    ]
}
"""


def render_text_to_audio_batch() -> None:
    st.set_page_config(layout="wide", page_icon="🎸")

    st.subheader(":scroll: Text to Audio Batch")
    st.write(
        """
        Generate audio in batch from a JSON file of text prompts. \nThe input
        file contains a global params block and a list of entries with positive and negative
        prompts.
        """
    )
    device = streamlit_util.select_device(st.sidebar)

    # Upload a JSON file
    json_file = st.file_uploader(
        "JSON file",
        type=["json"],
        label_visibility="collapsed",
    )

    # Handle the null case
    if json_file is None:
        st.info("Upload a JSON file containing params and prompts")
        with st.expander("Example inputs.json", expanded=False):
            st.code(EXAMPLE_INPUT)
        return

    # Read in and print it
    data = json.loads(json_file.read())
    with st.expander("Input Data", expanded=False):
        st.json(data)

    params = data["params"]
    entries = data["entries"]

    show_images = st.sidebar.checkbox("Show Images", False)

    # Optionally specify an output directory
    output_dir = st.sidebar.text_input("Output Directory", "")
    output_path: T.Optional[Path] = None
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    for i, entry in enumerate(entries):
        st.write(f"#### Entry {i + 1} / {len(entries)}")

        negative_prompt = entry.get("negative_prompt", None)

        st.write(f"**Prompt**: {entry['prompt']}  \n" + f"**Negative prompt**: {negative_prompt}")

        image = streamlit_util.run_txt2img(
            prompt=entry["prompt"],
            negative_prompt=negative_prompt,
            seed=params.get("seed", 42),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance=params.get("guidance", 7.0),
            width=params.get("width", 512),
            height=512,
            device=device,
        )

        if show_images:
            st.image(image)

        # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
        p_spectrogram = SpectrogramParams(
            min_frequency=0,
            max_frequency=10000,
        )

        output_format = "mp3"
        audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
            image=image,
            params=p_spectrogram,
            device=device,
            output_format=output_format,
        )
        st.audio(audio_bytes)

        if output_path:
            prompt_slug = entry["prompt"].replace(" ", "_")
            negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")

            image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
            image.save(image_path, format="JPEG")
            entry["image_path"] = str(image_path)

            audio_path = (
                output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
            )
            audio_path.write_bytes(audio_bytes.getbuffer())
            entry["audio_path"] = str(audio_path)

    if output_path:
        output_json_path = output_path / "index.json"
        output_json_path.write_text(json.dumps(data, indent=4))
        st.info(f"Output written to {str(output_path)}")
    else:
        st.info("Enter output directory in sidebar to save to disk")


if __name__ == "__main__":
    render_text_to_audio_batch()

@@ -0,0 +1,37 @@

import streamlit as st


def render_main():
    st.set_page_config(layout="wide", page_icon="🎸")
    st.header(":guitar: Riffusion Playground")
    st.write("Interactive app for common riffusion tasks.")

    left, right = st.columns(2)

    with left:
        create_link(":performing_arts: Interpolation", "/interpolation")
        st.write("Interpolate between prompts in the latent space.")

        create_link(":pencil2: Text to Audio", "/text_to_audio")
        st.write("Generate audio from text prompts.")

        create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
        st.write("Generate audio in batch from a JSON file of text prompts.")

    with right:
        create_link(":scissors: Sample Clips", "/sample_clips")
        st.write("Export short clips from an audio file.")

        create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
        st.write("Reconstruct audio from spectrogram images.")


def create_link(name: str, url: str) -> None:
    st.markdown(
        f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    render_main()

@@ -0,0 +1,131 @@

"""
Streamlit utilities (mostly cached wrappers around riffusion code).
"""
import io
import typing as T

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

# TODO(hayk): Add URL params


@st.experimental_singleton
def load_riffusion_checkpoint(
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> RiffusionPipeline:
    """
    Load the riffusion pipeline.
    """
    return RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )


@st.experimental_singleton
def load_stable_diffusion_pipeline(
    checkpoint: str = "riffusion/riffusion-model-v1",
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
) -> StableDiffusionPipeline:
    """
    Load the riffusion pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    return StableDiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)


@st.experimental_memo
def run_txt2img(
    prompt: str,
    num_inference_steps: int,
    guidance: float,
    negative_prompt: str,
    seed: int,
    width: int,
    height: int,
    device: str = "cuda",
) -> Image.Image:
    """
    Run the text to image pipeline with caching.
    """
    pipeline = load_stable_diffusion_pipeline(device=device)

    generator_device = "cpu" if device.lower().startswith("mps") else device
    generator = torch.Generator(device=generator_device).manual_seed(seed)

    output = pipeline(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance,
        negative_prompt=negative_prompt or None,
        generator=generator,
        width=width,
        height=height,
    )

    return output["images"][0]


@st.experimental_singleton
def spectrogram_image_converter(
    params: SpectrogramParams,
    device: str = "cuda",
) -> SpectrogramImageConverter:
    return SpectrogramImageConverter(params=params, device=device)


@st.experimental_memo
def audio_bytes_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
    output_format: str = "mp3",
) -> io.BytesIO:
    converter = spectrogram_image_converter(params=params, device=device)
    segment = converter.audio_from_spectrogram_image(image)

    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=output_format)
    audio_bytes.seek(0)

    return audio_bytes


def select_device(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a torch device, with an intelligent default.
    """
    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    elif torch.backends.mps.is_available():
        default_device = "mps"

    device_options = ["cuda", "cpu", "mps"]
    device = container.selectbox(
        "Device", options=device_options, index=device_options.index(default_device)
    )
    assert device is not None

    return device