Merge pull request #40 from riffusion/hayk.mart/revup/main/streamlit_app

Streamlit app for interactive use of the model
Hayk Martiros 2022-12-27 00:43:07 -08:00 committed by GitHub
commit 7b55a966ba
11 changed files with 814 additions and 41 deletions

README.md

@@ -1,51 +1,130 @@
# Riffusion
Riffusion is a technique for real-time music and audio generation with stable diffusion.
Riffusion is a library for real-time music and audio generation with stable diffusion.
Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.
* Inference server: https://github.com/riffusion/riffusion
This repository contains the core riffusion image and audio processing code and supporting apps,
including:
* diffusion pipeline that performs prompt interpolation combined with image conditioning
* package for (approximately) converting between spectrogram images and audio clips
* interactive playground using streamlit
* command-line tool for common tasks
* flask server to provide model inference via API
* various third party integrations
* test suite
Related repositories:
* Web app: https://github.com/riffusion/riffusion-app
* Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1
This repository contains the Python backend that does the model inference and audio processing, including:
## Citation
* a diffusers pipeline that performs prompt interpolation combined with image conditioning
* a module for (approximately) converting between spectrograms and waveforms
* a flask server to provide model inference via API to the next.js app
* a model template titled baseten.py for deploying as a Truss
If you build on this work, please cite it as follows:
```
@article{Forsgren_Martiros_2022,
author = {Forsgren, Seth* and Martiros, Hayk*},
title = {{Riffusion - Stable diffusion for real-time music generation}},
url = {https://riffusion.com/about},
year = {2022}
}
```
## Install
Tested with Python 3.9 and diffusers 0.9.0.
Tested with Python 3.9 + 3.10 and diffusers 0.9.0.
To run this model, you need a GPU with CUDA. To run it in real time, it needs to be able to run stable diffusion with approximately 50 steps in under five seconds.
You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
To run this model in real time, you need a GPU that can run stable diffusion with approximately 50
steps in under five seconds. A 3090 or A10G will do it.
Install in a virtual Python environment:
```
conda create --name riffusion-inference python=3.9
conda activate riffusion-inference
conda create --name riffusion python=3.9
conda activate riffusion
python -m pip install -r requirements.txt
```
If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
If torchaudio has no audio backend, see
[this issue](https://github.com/riffusion/riffusion/issues/12).
You can open and save WAV files with pure python. For opening and saving non-wav files like mp3 you'll need ffmpeg or libav.
You can open and save WAV files with pure python. For opening and saving non-wav files like mp3
you'll need to install ffmpeg with `sudo apt-get install ffmpeg` or `brew install ffmpeg`.
Guides:
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
## Backends
#### CUDA
`cuda` is the recommended and most performant backend.
To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
[install guide](https://pytorch.org/get-started/locally/) or
[stable wheels](https://download.pytorch.org/whl/torch_stable.html). Check with:
```python3
import torch
torch.cuda.is_available()
```
Also see [this issue](https://github.com/riffusion/riffusion/issues/3) for help.
#### CPU
`cpu` works but is quite slow.
#### MPS
The `mps` backend on Apple Silicon is supported for inference, but some operations fall back to CPU,
particularly for audio processing. You may need to set `PYTORCH_ENABLE_MPS_FALLBACK=1`.
In addition, this backend is not deterministic.
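As with CUDA, you can check MPS availability from Python. The following is a minimal sketch (the fallback variable is set before importing torch, which is the safest option):
```python3
import os

# Allow unsupported MPS ops to fall back to CPU; set before importing torch.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

print(torch.backends.mps.is_available())
```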
## Command-line interface
Riffusion comes with a command line interface for performing common tasks.
See available commands:
```
python -m riffusion.cli -h
```
Get help for a specific command:
```
python -m riffusion.cli image-to-audio -h
```
Execute:
```
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
```
## Streamlit playground
Riffusion also has a streamlit app for interactive use and exploration.
This app is called the Riffusion Playground.
Run with:
```
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
```
And access at http://127.0.0.1:8501/
## Run the model server
Start the Flask server:
Riffusion can be run as a flask server that provides inference via API. Run with:
```
python -m riffusion.server --host 127.0.0.1 --port 3013
```
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
Use the `--device` argument to specify the torch device to use.
The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
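As an illustrative sketch (not part of the official README example), you can issue a request from Python with the `requests` package. The payload fields below mirror the `InferenceInput` and `PromptInput` dataclasses in `riffusion/datatypes.py`; the exact accepted format may differ:
```python3
import requests

# Hypothetical client for the server started above; adjust host/port as needed.
payload = {
    "alpha": 0.5,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz saxophone", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload)
response.raise_for_status()
result = response.json()  # see InferenceOutput for the response fields
```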
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
@@ -79,15 +158,6 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
}
```
Use the `--device` argument to specify the torch device to use.
`cuda` is recommended.
`cpu` works but is quite slow.
`mps` is supported for inference, but some operations fall back to CPU. You may need to set
PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.
## Test
Tests live in the `test/` directory and are implemented with `unittest`.
@@ -106,7 +176,7 @@ To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
```
To run a single test case:
To run a single test case within a test module:
```
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
```
@@ -125,15 +195,6 @@ These are configured in `pyproject.toml`.
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
## Citation
CI is run through GitHub Actions from `.github/workflows/ci.yml`.
If you build on this work, please cite it as follows:
```
@article{Forsgren_Martiros_2022,
author = {Forsgren, Seth* and Martiros, Hayk*},
title = {{Riffusion - Stable diffusion for real-time music generation}},
url = {https://riffusion.com/about},
year = {2022}
}
```
Contributions are welcome through opening pull requests.


@@ -116,9 +116,6 @@ def sample_clips(
if not output_dir_path.exists():
output_dir_path.mkdir(parents=True)
# TODO(hayk): Might be a lot easier with pydub
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)


@@ -0,0 +1,3 @@
# streamlit
This package is an interactive streamlit app for riffusion.


@@ -0,0 +1,58 @@
import dataclasses
import streamlit as st
from PIL import Image
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.util.image_util import exif_from_image
def render_image_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":musical_keyboard: Image to Audio")
st.write(
"""
Reconstruct audio from spectrogram images.
"""
)
device = streamlit_util.select_device(st.sidebar)
image_file = st.file_uploader(
"Upload a file",
type=["png", "jpg", "jpeg"],
label_visibility="collapsed",
)
if not image_file:
st.info("Upload an image file to get started")
return
image = Image.open(image_file)
st.image(image)
with st.expander("Image metadata", expanded=False):
exif = exif_from_image(image)
st.json(exif)
try:
params = SpectrogramParams.from_exif(exif=image.getexif())
except KeyError:
st.info("Could not find spectrogram parameters in exif data. Using defaults.")
params = SpectrogramParams()
with st.expander("Spectrogram Parameters", expanded=False):
st.json(dataclasses.asdict(params))
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image.copy(),
params=params,
device=device,
output_format="mp3",
)
st.audio(audio_bytes)
if __name__ == "__main__":
render_image_to_audio()


@@ -0,0 +1,197 @@
import dataclasses
import io
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
from PIL import Image
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_interpolation_demo() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":performing_arts: Interpolation")
st.write(
"""
Interpolate between prompts in the latent space.
"""
)
# Sidebar params
device = streamlit_util.select_device(st.sidebar)
num_interpolation_steps = T.cast(
int,
st.sidebar.number_input(
"Interpolation steps",
value=4,
min_value=1,
max_value=20,
help="Number of model generations between the two prompts. Controls the duration.",
),
)
num_inference_steps = T.cast(
int,
st.sidebar.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
init_image_name = st.sidebar.selectbox(
"Seed image",
# TODO(hayk): Read from directory
options=["og_beat", "agile", "marim", "motorway", "vibes"],
index=0,
help="Which seed image to use for img2img",
)
assert init_image_name is not None
show_individual_outputs = st.sidebar.checkbox(
"Show individual outputs",
value=False,
help="Show each model output",
)
show_images = st.sidebar.checkbox(
"Show individual images",
value=False,
help="Show each generated image",
)
# Prompt inputs A and B in two columns
left, right = st.columns(2)
with left.expander("Input A", expanded=True):
prompt_input_a = get_prompt_inputs(key="a")
with right.expander("Input B", expanded=True):
prompt_input_b = get_prompt_inputs(key="b")
if not prompt_input_a.prompt or not prompt_input_b.prompt:
st.info("Enter both prompts to interpolate between them")
return
alphas = list(np.linspace(0, 1, num_interpolation_steps))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]")
# TODO(hayk): Upload your own seed image.
init_image_path = (
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
)
init_image = Image.open(str(init_image_path)).convert("RGB")
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
image_list: T.List[Image.Image] = []
audio_bytes_list: T.List[io.BytesIO] = []
for i, alpha in enumerate(alphas):
inputs = InferenceInput(
alpha=float(alpha),
num_inference_steps=num_inference_steps,
seed_image_id="og_beat",
start=prompt_input_a,
end=prompt_input_b,
)
if i == 0:
with st.expander("Example input JSON", expanded=False):
st.json(dataclasses.asdict(inputs))
image, audio_bytes = run_interpolation(
inputs=inputs,
init_image=init_image,
device=device,
)
if show_individual_outputs:
st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
if show_images:
st.image(image)
st.audio(audio_bytes)
image_list.append(image)
audio_bytes_list.append(audio_bytes)
st.write("#### Final Output")
# TODO(hayk): Concatenate with better blending
audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
concat_segment = audio_segments[0]
for segment in audio_segments[1:]:
concat_segment = concat_segment.append(segment, crossfade=0)
audio_bytes = io.BytesIO()
concat_segment.export(audio_bytes, format="mp3")
audio_bytes.seek(0)
st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
st.audio(audio_bytes)
def get_prompt_inputs(key: str) -> PromptInput:
"""
Compute prompt inputs from widgets.
"""
prompt = st.text_input("Prompt", label_visibility="collapsed", key=f"prompt_{key}")
seed = T.cast(int, st.number_input("Seed", value=42, key=f"seed_{key}"))
denoising = st.number_input(
"Denoising", value=0.75, key=f"denoising_{key}", help="How much to modify the seed image"
)
guidance = st.number_input(
"Guidance",
value=7.0,
key=f"guidance_{key}",
help="How much the model listens to the text prompt",
)
return PromptInput(
prompt=prompt,
seed=seed,
denoising=denoising,
guidance=guidance,
)
@st.experimental_memo
def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
) -> T.Tuple[Image.Image, io.BytesIO]:
"""
Cached function for riffusion interpolation.
"""
pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
image = pipeline.riffuse(
inputs,
init_image=init_image,
mask_image=None,
)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
# Reconstruct from image to audio
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=params,
device=device,
output_format="mp3",
)
return image, audio_bytes
if __name__ == "__main__":
render_interpolation_demo()


@@ -0,0 +1,85 @@
import tempfile
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
def render_sample_clips() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":scissors: Sample Clips")
st.write(
"""
Export short clips from an audio file.
"""
)
audio_file = st.file_uploader(
"Upload a file",
type=["wav", "mp3", "ogg"],
label_visibility="collapsed",
)
if not audio_file:
st.info("Upload an audio file to get started")
return
st.audio(audio_file)
segment = pydub.AudioSegment.from_file(audio_file)
st.write(
" \n".join(
[
f"**Duration**: {segment.duration_seconds:.3f} seconds",
f"**Channels**: {segment.channels}",
f"**Sample rate**: {segment.frame_rate} Hz",
f"**Sample width**: {segment.sample_width} bytes",
]
)
)
seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
export_as_mono = st.sidebar.checkbox("Export as Mono", False)
num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
assert extension is not None
# Optionally specify an output directory
output_dir = st.text_input("Output Directory")
if not output_dir:
tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
return
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
if seed >= 0:
np.random.seed(seed)
if export_as_mono and segment.channels > 1:
segment = segment.set_channels(1)
# TODO(hayk): Share code with riffusion.cli.sample_clips.
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
clip_path = output_path / clip_name
clip.export(clip_path, format=extension)
st.audio(str(clip_path))
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
if __name__ == "__main__":
render_sample_clips()


@@ -0,0 +1,66 @@
import typing as T
import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_text_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":pencil2: Text to Audio")
st.write(
"""
Generate audio from text prompts. \nRuns the model directly without a seed image or
interpolation.
"""
)
device = streamlit_util.select_device(st.sidebar)
prompt = st.text_input("Prompt")
negative_prompt = st.text_input("Negative prompt")
with st.sidebar.expander("Text to Audio Params", expanded=True):
seed = T.cast(int, st.number_input("Seed", value=42))
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
width = T.cast(int, st.number_input("Width", value=512))
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
)
if not prompt:
st.info("Enter a prompt")
return
image = streamlit_util.run_txt2img(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance=guidance,
negative_prompt=negative_prompt,
seed=seed,
width=width,
height=512,
device=device,
)
st.image(image)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=params,
device=device,
output_format="mp3",
)
st.audio(audio_bytes)
if __name__ == "__main__":
render_text_to_audio()


@@ -0,0 +1,138 @@
import json
import typing as T
from pathlib import Path
import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
# Example input json file to process in batch
EXAMPLE_INPUT = """
{
"params": {
"seed": 42,
"num_inference_steps": 50,
"guidance": 7.0,
"width": 512,
},
"entries": [
{
"prompt": "Church bells"
},
{
"prompt": "electronic beats",
"negative_prompt": "drums"
},
{
"prompt": "classical violin concerto"
}
]
}
"""
def render_text_to_audio_batch() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":scroll: Text to Audio Batch")
st.write(
"""
Generate audio in batch from a JSON file of text prompts. \nThe input
file contains a global params block and a list of entries with positive and negative
prompts.
"""
)
device = streamlit_util.select_device(st.sidebar)
# Upload a JSON file
json_file = st.file_uploader(
"JSON file",
type=["json"],
label_visibility="collapsed",
)
# Handle the null case
if json_file is None:
st.info("Upload a JSON file containing params and prompts")
with st.expander("Example inputs.json", expanded=False):
st.code(EXAMPLE_INPUT)
return
# Read in and print it
data = json.loads(json_file.read())
with st.expander("Input Data", expanded=False):
st.json(data)
params = data["params"]
entries = data["entries"]
show_images = st.sidebar.checkbox("Show Images", False)
# Optionally specify an output directory
output_dir = st.sidebar.text_input("Output Directory", "")
output_path: T.Optional[Path] = None
if output_dir:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
for i, entry in enumerate(entries):
st.write(f"#### Entry {i + 1} / {len(entries)}")
negative_prompt = entry.get("negative_prompt", None)
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
image = streamlit_util.run_txt2img(
prompt=entry["prompt"],
negative_prompt=negative_prompt,
seed=params.get("seed", 42),
num_inference_steps=params.get("num_inference_steps", 50),
guidance=params.get("guidance", 7.0),
width=params.get("width", 512),
height=512,
device=device,
)
if show_images:
st.image(image)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
p_spectrogram = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
output_format = "mp3"
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=p_spectrogram,
device=device,
output_format=output_format,
)
st.audio(audio_bytes)
if output_path:
prompt_slug = entry["prompt"].replace(" ", "_")
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
image.save(image_path, format="JPEG")
entry["image_path"] = str(image_path)
audio_path = (
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
)
audio_path.write_bytes(audio_bytes.getbuffer())
entry["audio_path"] = str(audio_path)
if output_path:
output_json_path = output_path / "index.json"
output_json_path.write_text(json.dumps(data, indent=4))
st.info(f"Output written to {str(output_path)}")
else:
st.info("Enter output directory in sidebar to save to disk")
if __name__ == "__main__":
render_text_to_audio_batch()


@@ -0,0 +1,37 @@
import streamlit as st
def render_main():
st.set_page_config(layout="wide", page_icon="🎸")
st.header(":guitar: Riffusion Playground")
st.write("Interactive app for common riffusion tasks.")
left, right = st.columns(2)
with left:
create_link(":performing_arts: Interpolation", "/interpolation")
st.write("Interpolate between prompts in the latent space.")
create_link(":pencil2: Text to Audio", "/text_to_audio")
st.write("Generate audio from text prompts.")
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
st.write("Generate audio in batch from a JSON file of text prompts.")
with right:
create_link(":scissors: Sample Clips", "/sample_clips")
st.write("Export short clips from an audio file.")
create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
st.write("Reconstruct audio from spectrogram images.")
def create_link(name: str, url: str) -> None:
st.markdown(
f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
unsafe_allow_html=True,
)
if __name__ == "__main__":
render_main()

riffusion/streamlit/util.py Normal file

@@ -0,0 +1,131 @@
"""
Streamlit utilities (mostly cached wrappers around riffusion code).
"""
import io
import typing as T
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
# TODO(hayk): Add URL params
@st.experimental_singleton
def load_riffusion_checkpoint(
checkpoint: str = "riffusion/riffusion-model-v1",
no_traced_unet: bool = False,
device: str = "cuda",
) -> RiffusionPipeline:
"""
Load the riffusion pipeline.
"""
return RiffusionPipeline.load_checkpoint(
checkpoint=checkpoint,
use_traced_unet=not no_traced_unet,
device=device,
)
@st.experimental_singleton
def load_stable_diffusion_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
) -> StableDiffusionPipeline:
"""
Load the riffusion pipeline.
TODO(hayk): Merge this into RiffusionPipeline to just load one model.
"""
if device == "cpu" or device.lower().startswith("mps"):
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
return StableDiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
@st.experimental_memo
def run_txt2img(
prompt: str,
num_inference_steps: int,
guidance: float,
negative_prompt: str,
seed: int,
width: int,
height: int,
device: str = "cuda",
) -> Image.Image:
"""
Run the text to image pipeline with caching.
"""
pipeline = load_stable_diffusion_pipeline(device=device)
generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)
output = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance,
negative_prompt=negative_prompt or None,
generator=generator,
width=width,
height=height,
)
return output["images"][0]
@st.experimental_singleton
def spectrogram_image_converter(
params: SpectrogramParams,
device: str = "cuda",
) -> SpectrogramImageConverter:
return SpectrogramImageConverter(params=params, device=device)
@st.experimental_memo
def audio_bytes_from_spectrogram_image(
image: Image.Image,
params: SpectrogramParams,
device: str = "cuda",
output_format: str = "mp3",
) -> io.BytesIO:
converter = spectrogram_image_converter(params=params, device=device)
segment = converter.audio_from_spectrogram_image(image)
audio_bytes = io.BytesIO()
segment.export(audio_bytes, format=output_format)
audio_bytes.seek(0)
return audio_bytes
def select_device(container: T.Any = st.sidebar) -> str:
"""
Dropdown to select a torch device, with an intelligent default.
"""
default_device = "cpu"
if torch.cuda.is_available():
default_device = "cuda"
elif torch.backends.mps.is_available():
default_device = "mps"
device_options = ["cuda", "cpu", "mps"]
device = st.sidebar.selectbox(
"Device", options=device_options, index=device_options.index(default_device)
)
assert device is not None
return device