Improve text to audio

Topic: streamlit_app
Hayk Martiros 2022-12-26 21:03:30 -08:00
parent 39dc247a1d
commit 420674148a
3 changed files with 109 additions and 100 deletions

View File

@@ -1,11 +1,7 @@
-import io
 import streamlit as st
 from PIL import Image
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
-from riffusion.spectrogram_params import SpectrogramParams
-from riffusion.streamlit import util as streamlit_util
 from riffusion.util.image_util import exif_from_image
@@ -26,13 +22,13 @@ def render_image_to_audio() -> None:
     st.write("Exif data:")
     st.write(exif)

-    device = "cuda"
+    # device = "cuda"

-    try:
-        params = SpectrogramParams.from_exif(exif=image.getexif())
-    except KeyError:
-        st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
-        params = SpectrogramParams()
+    # try:
+    #     params = SpectrogramParams.from_exif(exif=image.getexif())
+    # except KeyError:
+    #     st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
+    #     params = SpectrogramParams()

     # segment = streamlit_util.audio_from_spectrogram_image(
     #     image=image,
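For context, the block that this hunk disables follows the EXIF round-trip pattern used throughout the app. A minimal sketch of that pattern, assuming a local spectrogram image at a hypothetical path:

    from PIL import Image
    from riffusion.spectrogram_params import SpectrogramParams

    # Hypothetical input path, for illustration only.
    image = Image.open("spectrogram.png")

    # Riffusion spectrogram images embed their generation parameters in EXIF;
    # fall back to defaults when the tags are missing.
    try:
        params = SpectrogramParams.from_exif(exif=image.getexif())
    except KeyError:
        params = SpectrogramParams()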

View File

@@ -1,101 +1,32 @@
-import io
-from pathlib import Path
 import typing as T

-import dacite
-from diffusers import StableDiffusionPipeline
 import streamlit as st
-import torch
-from PIL import Image

-from riffusion.datatypes import InferenceInput
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util

-
-@st.experimental_singleton
-def load_stable_diffusion_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
-    device: str = "cuda",
-    dtype: torch.dtype = torch.float16,
-) -> StableDiffusionPipeline:
-    """
-    Load the riffusion pipeline.
-    """
-    if device == "cpu" or device.lower().startswith("mps"):
-        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
-        dtype = torch.float32
-
-    return StableDiffusionPipeline.from_pretrained(
-        checkpoint,
-        revision="main",
-        torch_dtype=dtype,
-        safety_checker=lambda images, **kwargs: (images, False),
-    ).to(device)
-
-
-@st.experimental_memo
-def run_txt2img(
-    prompt: str,
-    num_inference_steps: int,
-    guidance: float,
-    negative_prompt: str,
-    seed: int,
-    width: int,
-    height: int,
-    device: str = "cuda",
-) -> Image.Image:
-    """
-    Run the text to image pipeline with caching.
-    """
-    pipeline = load_stable_diffusion_pipeline(device=device)
-
-    generator = torch.Generator(device="cpu").manual_seed(seed)
-
-    output = pipeline(
-        prompt=prompt,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance,
-        negative_prompt=negative_prompt or None,
-        generator=generator,
-        width=width,
-        height=height,
-    )
-
-    return output["images"][0]
-
-
 def render_text_to_audio() -> None:
     """
     Render audio from text.
     """
     prompt = st.text_input("Prompt")
+    if not prompt:
+        st.info("Enter a prompt")
+        return
     negative_prompt = st.text_input("Negative prompt")

-    seed = st.sidebar.number_input("Seed", value=42)
-    num_inference_steps = st.sidebar.number_input("Inference steps", value=20)
-    width = st.sidebar.number_input("Width", value=512)
-    height = st.sidebar.number_input("Height", value=512)
+    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
+    num_inference_steps = T.cast(int, st.sidebar.number_input("Inference steps", value=50))
+    width = T.cast(int, st.sidebar.number_input("Width", value=512))
+    height = T.cast(int, st.sidebar.number_input("Height", value=512))
     guidance = st.sidebar.number_input(
         "Guidance", value=7.0, help="How much the model listens to the text prompt"
     )

-    default_device = "cpu"
-    if torch.cuda.is_available():
-        default_device = "cuda"
-    elif torch.backends.mps.is_available():
-        default_device = "mps"
-
-    if not prompt:
-        st.info("Enter a prompt")
-        return
-
-    device_options = ["cuda", "cpu", "mps"]
-    device = st.sidebar.selectbox(
-        "Device", options=device_options, index=device_options.index(default_device)
-    )
-    assert device is not None
+    device = streamlit_util.select_device(st.sidebar)

-    image = run_txt2img(
+    image = streamlit_util.run_txt2img(
         prompt=prompt,
         num_inference_steps=num_inference_steps,
         guidance=guidance,
@@ -105,7 +36,6 @@ def render_text_to_audio() -> None:
         height=height,
         device=device,
     )
-
     st.image(image)

     # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
@@ -114,16 +44,13 @@ def render_text_to_audio() -> None:
         max_frequency=10000,
     )

-    segment = streamlit_util.audio_from_spectrogram_image(
+    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
         image=image,
         params=params,
         device=device,
+        output_format="mp3",
     )

-    mp3_bytes = io.BytesIO()
-    segment.export(mp3_bytes, format="mp3")
-    mp3_bytes.seek(0)
-    st.audio(mp3_bytes)
+    st.audio(audio_bytes)


 if __name__ == "__main__":
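After this change, the page body reduces to two cached helpers in riffusion.streamlit.util. A minimal sketch of the same flow in straight-line form; the prompt text and output path are illustrative only, and a CUDA device is assumed:

    from riffusion.spectrogram_params import SpectrogramParams
    from riffusion.streamlit import util as streamlit_util

    image = streamlit_util.run_txt2img(
        prompt="acoustic folk guitar",  # illustrative prompt
        num_inference_steps=50,
        guidance=7.0,
        negative_prompt="",
        seed=42,
        width=512,
        height=512,
        device="cuda",
    )

    params = SpectrogramParams(max_frequency=10000)
    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image,
        params=params,
        device="cuda",
        output_format="mp3",
    )
    with open("output.mp3", "wb") as f:
        f.write(audio_bytes.read())

Note that the relocated run_txt2img (third file below) now creates its torch.Generator on the CPU when running on MPS, since torch.Generator did not support the MPS backend at the time of this commit.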

View File

@@ -1,9 +1,12 @@
 """
 Streamlit utilities (mostly cached wrappers around riffusion code).
 """
+import io
+import typing as T

 import pydub
 import streamlit as st
 import torch
+from diffusers import StableDiffusionPipeline
 from PIL import Image

 from riffusion.riffusion_pipeline import RiffusionPipeline
@@ -26,6 +29,63 @@ def load_riffusion_checkpoint(
         device=device,
     )

+
+@st.experimental_singleton
+def load_stable_diffusion_pipeline(
+    checkpoint: str = "riffusion/riffusion-model-v1",
+    device: str = "cuda",
+    dtype: torch.dtype = torch.float16,
+) -> StableDiffusionPipeline:
+    """
+    Load the riffusion pipeline.
+
+    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
+    """
+    if device == "cpu" or device.lower().startswith("mps"):
+        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
+        dtype = torch.float32
+
+    return StableDiffusionPipeline.from_pretrained(
+        checkpoint,
+        revision="main",
+        torch_dtype=dtype,
+        safety_checker=lambda images, **kwargs: (images, False),
+    ).to(device)
+
+
+@st.experimental_memo
+def run_txt2img(
+    prompt: str,
+    num_inference_steps: int,
+    guidance: float,
+    negative_prompt: str,
+    seed: int,
+    width: int,
+    height: int,
+    device: str = "cuda",
+) -> Image.Image:
+    """
+    Run the text to image pipeline with caching.
+    """
+    pipeline = load_stable_diffusion_pipeline(device=device)
+
+    generator_device = "cpu" if device.lower().startswith("mps") else device
+    generator = torch.Generator(device=generator_device).manual_seed(seed)
+
+    output = pipeline(
+        prompt=prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance,
+        negative_prompt=negative_prompt or None,
+        generator=generator,
+        width=width,
+        height=height,
+    )
+
+    return output["images"][0]

 # class CachedSpectrogramImageConverter:
 #     def __init__(self, params: SpectrogramParams, device: str = "cuda"):
@@ -54,13 +114,39 @@ def spectrogram_image_converter(
 @st.experimental_memo
-def audio_from_spectrogram_image(
+def audio_bytes_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
     device: str = "cuda",
-) -> pydub.AudioSegment:
+    output_format: str = "mp3",
+) -> io.BytesIO:
     converter = spectrogram_image_converter(params=params, device=device)
-    return converter.audio_from_spectrogram_image(image)
+    segment = converter.audio_from_spectrogram_image(image)
+
+    audio_bytes = io.BytesIO()
+    segment.export(audio_bytes, format=output_format)
+    audio_bytes.seek(0)
+
+    return audio_bytes
+
+
+def select_device(container: T.Any = st.sidebar) -> str:
+    """
+    Dropdown to select a torch device, with an intelligent default.
+    """
+    default_device = "cpu"
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+
+    device_options = ["cuda", "cpu", "mps"]
+    device = container.selectbox(
+        "Device", options=device_options, index=device_options.index(default_device)
+    )
+    assert device is not None
+
+    return device

 # @st.experimental_memo
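A short usage sketch for the new select_device helper. The container argument defaults to st.sidebar; passing another Streamlit container renders the dropdown there instead:

    import streamlit as st
    from riffusion.streamlit import util as streamlit_util

    # Defaults to the best available backend: cuda, then mps, then cpu.
    device = streamlit_util.select_device(st.sidebar)
    st.write(f"Running on {device}")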