diff --git a/riffusion/streamlit/pages/image_to_audio.py b/riffusion/streamlit/pages/image_to_audio.py
index f890c5e..eb57211 100644
--- a/riffusion/streamlit/pages/image_to_audio.py
+++ b/riffusion/streamlit/pages/image_to_audio.py
@@ -1,11 +1,7 @@
-import io
 
 import streamlit as st
 from PIL import Image
 
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
-from riffusion.spectrogram_params import SpectrogramParams
-from riffusion.streamlit import util as streamlit_util
 from riffusion.util.image_util import exif_from_image
 
 
@@ -26,13 +22,13 @@ def render_image_to_audio() -> None:
     st.write("Exif data:")
     st.write(exif)
 
-    device = "cuda"
+    # device = "cuda"
 
-    try:
-        params = SpectrogramParams.from_exif(exif=image.getexif())
-    except KeyError:
-        st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
-        params = SpectrogramParams()
+    # try:
+    #     params = SpectrogramParams.from_exif(exif=image.getexif())
+    # except KeyError:
+    #     st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
+    #     params = SpectrogramParams()
 
     # segment = streamlit_util.audio_from_spectrogram_image(
     #     image=image,
diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py
index 9e2c129..c696085 100644
--- a/riffusion/streamlit/pages/text_to_audio.py
+++ b/riffusion/streamlit/pages/text_to_audio.py
@@ -1,101 +1,32 @@
-import io
-from pathlib import Path
+import typing as T
 
-import dacite
-from diffusers import StableDiffusionPipeline
 import streamlit as st
-import torch
-from PIL import Image
 
-from riffusion.datatypes import InferenceInput
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
 
 
-@st.experimental_singleton
-def load_stable_diffusion_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
-    device: str = "cuda",
-    dtype: torch.dtype = torch.float16,
-) -> StableDiffusionPipeline:
-    """
-    Load the riffusion pipeline.
-    """
-    if device == "cpu" or device.lower().startswith("mps"):
-        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
-        dtype = torch.float32
-
-    return StableDiffusionPipeline.from_pretrained(
-        checkpoint,
-        revision="main",
-        torch_dtype=dtype,
-        safety_checker=lambda images, **kwargs: (images, False),
-    ).to(device)
-
-
-@st.experimental_memo
-def run_txt2img(
-    prompt: str,
-    num_inference_steps: int,
-    guidance: float,
-    negative_prompt: str,
-    seed: int,
-    width: int,
-    height: int,
-    device: str = "cuda",
-) -> Image.Image:
-    """
-    Run the text to image pipeline with caching.
-    """
-    pipeline = load_stable_diffusion_pipeline(device=device)
-
-    generator = torch.Generator(device="cpu").manual_seed(seed)
-
-    output = pipeline(
-        prompt=prompt,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance,
-        negative_prompt=negative_prompt or None,
-        generator=generator,
-        width=width,
-        height=height,
-    )
-
-    return output["images"][0]
-
-
 def render_text_to_audio() -> None:
     """
     Render audio from text.
""" prompt = st.text_input("Prompt") - if not prompt: - st.info("Enter a prompt") - return - negative_prompt = st.text_input("Negative prompt") - seed = st.sidebar.number_input("Seed", value=42) - num_inference_steps = st.sidebar.number_input("Inference steps", value=20) - width = st.sidebar.number_input("Width", value=512) - height = st.sidebar.number_input("Height", value=512) + seed = T.cast(int, st.sidebar.number_input("Seed", value=42)) + num_inference_steps = T.cast(int, st.sidebar.number_input("Inference steps", value=50)) + width = T.cast(int, st.sidebar.number_input("Width", value=512)) + height = T.cast(int, st.sidebar.number_input("Height", value=512)) guidance = st.sidebar.number_input( "Guidance", value=7.0, help="How much the model listens to the text prompt" ) - default_device = "cpu" - if torch.cuda.is_available(): - default_device = "cuda" - elif torch.backends.mps.is_available(): - default_device = "mps" + if not prompt: + st.info("Enter a prompt") + return - device_options = ["cuda", "cpu", "mps"] - device = st.sidebar.selectbox( - "Device", options=device_options, index=device_options.index(default_device) - ) - assert device is not None + device = streamlit_util.select_device(st.sidebar) - image = run_txt2img( + image = streamlit_util.run_txt2img( prompt=prompt, num_inference_steps=num_inference_steps, guidance=guidance, @@ -105,7 +36,6 @@ def render_text_to_audio() -> None: height=height, device=device, ) - st.image(image) # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained @@ -114,16 +44,13 @@ def render_text_to_audio() -> None: max_frequency=10000, ) - segment = streamlit_util.audio_from_spectrogram_image( + audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image( image=image, params=params, device=device, + output_format="mp3", ) - - mp3_bytes = io.BytesIO() - segment.export(mp3_bytes, format="mp3") - mp3_bytes.seek(0) - st.audio(mp3_bytes) + st.audio(audio_bytes) if __name__ == "__main__": diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py index e0ad5a2..2d2bb0f 100644 --- a/riffusion/streamlit/util.py +++ b/riffusion/streamlit/util.py @@ -1,9 +1,12 @@ """ Streamlit utilities (mostly cached wrappers around riffusion code). """ +import io +import typing as T -import pydub import streamlit as st +import torch +from diffusers import StableDiffusionPipeline from PIL import Image from riffusion.riffusion_pipeline import RiffusionPipeline @@ -26,6 +29,63 @@ def load_riffusion_checkpoint( device=device, ) + +@st.experimental_singleton +def load_stable_diffusion_pipeline( + checkpoint: str = "riffusion/riffusion-model-v1", + device: str = "cuda", + dtype: torch.dtype = torch.float16, +) -> StableDiffusionPipeline: + """ + Load the riffusion pipeline. + + TODO(hayk): Merge this into RiffusionPipeline to just load one model. + """ + if device == "cpu" or device.lower().startswith("mps"): + print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported") + dtype = torch.float32 + + return StableDiffusionPipeline.from_pretrained( + checkpoint, + revision="main", + torch_dtype=dtype, + safety_checker=lambda images, **kwargs: (images, False), + ).to(device) + + + +@st.experimental_memo +def run_txt2img( + prompt: str, + num_inference_steps: int, + guidance: float, + negative_prompt: str, + seed: int, + width: int, + height: int, + device: str = "cuda", +) -> Image.Image: + """ + Run the text to image pipeline with caching. 
+    """
+    pipeline = load_stable_diffusion_pipeline(device=device)
+
+    generator_device = "cpu" if device.lower().startswith("mps") else device
+    generator = torch.Generator(device=generator_device).manual_seed(seed)
+
+    output = pipeline(
+        prompt=prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance,
+        negative_prompt=negative_prompt or None,
+        generator=generator,
+        width=width,
+        height=height,
+    )
+
+    return output["images"][0]
+
+
 
 # class CachedSpectrogramImageConverter:
 #     def __init__(self, params: SpectrogramParams, device: str = "cuda"):
@@ -54,13 +114,39 @@ def spectrogram_image_converter(
 
 
 @st.experimental_memo
-def audio_from_spectrogram_image(
+def audio_bytes_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
     device: str = "cuda",
-) -> pydub.AudioSegment:
+    output_format: str = "mp3",
+) -> io.BytesIO:
     converter = spectrogram_image_converter(params=params, device=device)
-    return converter.audio_from_spectrogram_image(image)
+    segment = converter.audio_from_spectrogram_image(image)
+
+    audio_bytes = io.BytesIO()
+    segment.export(audio_bytes, format=output_format)
+    audio_bytes.seek(0)
+
+    return audio_bytes
+
+
+def select_device(container: T.Any = st.sidebar) -> str:
+    """
+    Dropdown to select a torch device, with an intelligent default.
+    """
+    default_device = "cpu"
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+
+    device_options = ["cuda", "cpu", "mps"]
+    device = st.sidebar.selectbox(
+        "Device", options=device_options, index=device_options.index(default_device)
+    )
+    assert device is not None
+
+    return device
 
 
 # @st.experimental_memo