From 39dc247a1d758ea0c5688bfb7800e1ece920be16 Mon Sep 17 00:00:00 2001
From: Hayk Martiros
Date: Mon, 26 Dec 2022 20:01:27 -0800
Subject: [PATCH] Streamlit app for interactive use of the model

Topic: streamlit_app
---
 riffusion/streamlit/README.md | 3 +
 riffusion/streamlit/__init__.py | 0
 riffusion/streamlit/main.py | 25 ++++
 riffusion/streamlit/pages/image_to_audio.py | 51 +++++++
 .../streamlit/pages/interpolation_demo.py | 97 +++++++++++++
 riffusion/streamlit/pages/text_to_audio.py | 130 ++++++++++++++++++
 riffusion/streamlit/util.py | 73 ++++++++++
 7 files changed, 379 insertions(+)
 create mode 100644 riffusion/streamlit/README.md
 create mode 100644 riffusion/streamlit/__init__.py
 create mode 100644 riffusion/streamlit/main.py
 create mode 100644 riffusion/streamlit/pages/image_to_audio.py
 create mode 100644 riffusion/streamlit/pages/interpolation_demo.py
 create mode 100644 riffusion/streamlit/pages/text_to_audio.py
 create mode 100644 riffusion/streamlit/util.py

diff --git a/riffusion/streamlit/README.md b/riffusion/streamlit/README.md
new file mode 100644
index 0000000..8bc102f
--- /dev/null
+++ b/riffusion/streamlit/README.md
@@ -0,0 +1,3 @@
+# streamlit
+
+This package is an interactive streamlit app for riffusion.
diff --git a/riffusion/streamlit/__init__.py b/riffusion/streamlit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/riffusion/streamlit/main.py b/riffusion/streamlit/main.py
new file mode 100644
index 0000000..c196417
--- /dev/null
+++ b/riffusion/streamlit/main.py
@@ -0,0 +1,25 @@
+import pydub
+import streamlit as st
+
+
+def run():
+    st.set_page_config(layout="wide", page_icon="🎸")
+
+    audio_file = st.file_uploader("Upload a file", type=["wav", "mp3", "ogg"])
+    if not audio_file:
+        st.info("Upload an audio file to get started")
+        return
+
+    st.audio(audio_file)
+
+    segment = pydub.AudioSegment.from_file(audio_file)
+    st.write("  \n".join([
+        f"**Duration**: {segment.duration_seconds:.3f} seconds",
+        f"**Channels**: {segment.channels}",
+        f"**Sample rate**: {segment.frame_rate} Hz",
+        f"**Sample width**: {segment.sample_width} bytes",
+    ]))
+
+
+if __name__ == "__main__":
+    run()
diff --git a/riffusion/streamlit/pages/image_to_audio.py b/riffusion/streamlit/pages/image_to_audio.py
new file mode 100644
index 0000000..f890c5e
--- /dev/null
+++ b/riffusion/streamlit/pages/image_to_audio.py
@@ -0,0 +1,51 @@
+import io
+
+import streamlit as st
+from PIL import Image
+
+from riffusion.spectrogram_image_converter import SpectrogramImageConverter
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.streamlit import util as streamlit_util
+from riffusion.util.image_util import exif_from_image
+
+
+def render_image_to_audio() -> None:
+    image_file = st.sidebar.file_uploader(
+        "Upload a file",
+        type=["png", "jpg", "jpeg"],
+        label_visibility="collapsed",
+    )
+    if not image_file:
+        st.info("Upload an image file to get started")
+        return
+
+    image = Image.open(image_file)
+    st.image(image)
+
+    exif = exif_from_image(image)
+    st.write("Exif data:")
+    st.write(exif)
+
+    device = "cuda"
+
+    try:
+        params = SpectrogramParams.from_exif(exif=image.getexif())
+    except KeyError:
+        st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
+        params = SpectrogramParams()
+
+    # segment = streamlit_util.audio_from_spectrogram_image(
+    #     image=image,
+    #     params=params,
+    #     device=device,
+    # )
+
+    # mp3_bytes = io.BytesIO()
+    # segment.export(mp3_bytes, format="mp3")
+    # mp3_bytes.seek(0)
+
+    # st.audio(mp3_bytes)
+
+
+if __name__ == "__main__":
+    render_image_to_audio()
diff --git a/riffusion/streamlit/pages/interpolation_demo.py b/riffusion/streamlit/pages/interpolation_demo.py
new file mode 100644
index 0000000..1c0720e
--- /dev/null
+++ b/riffusion/streamlit/pages/interpolation_demo.py
@@ -0,0 +1,97 @@
+import io
+from pathlib import Path
+
+import dacite
+import streamlit as st
+import torch
+from PIL import Image
+
+from riffusion.datatypes import InferenceInput
+from riffusion.spectrogram_image_converter import SpectrogramImageConverter
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.streamlit import util as streamlit_util
+
+
+def render_interpolation_demo() -> None:
+    """
+    Render a demo of the riffusion interpolation pipeline from a text prompt.
+    """
+    prompt = st.text_input("Prompt", label_visibility="collapsed")
+    if not prompt:
+        st.info("Enter a prompt")
+        return
+
+    seed = st.sidebar.number_input("Seed", value=42)
+    denoising = st.sidebar.number_input("Denoising", value=0.01)
+    guidance = st.sidebar.number_input("Guidance", value=7.0)
+    num_inference_steps = st.sidebar.number_input("Inference steps", value=50)
+
+    default_device = "cpu"
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+
+    device_options = ["cuda", "cpu", "mps"]
+    device = st.sidebar.selectbox(
+        "Device", options=device_options, index=device_options.index(default_device)
+    )
+    assert device is not None
+
+    pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
+
+    input_dict = {
+        "alpha": 0.75,
+        "num_inference_steps": num_inference_steps,
+        "seed_image_id": "og_beat",
+        "start": {
+            "prompt": prompt,
+            "seed": seed,
+            "denoising": denoising,
+            "guidance": guidance,
+        },
+        "end": {
+            "prompt": prompt,
+            "seed": seed,
+            "denoising": denoising,
+            "guidance": guidance,
+        },
+    }
+    st.json(input_dict)
+
+    inputs = dacite.from_dict(InferenceInput, input_dict)
+
+    # TODO fix
+    init_image_path = Path(__file__).parent.parent.parent.parent / "seed_images" / "og_beat.png"
+    init_image = Image.open(str(init_image_path)).convert("RGB")
+
+    # Execute the model to get the spectrogram image
+    image = pipeline.riffuse(
+        inputs,
+        init_image=init_image,
+        mask_image=None,
+    )
+    st.image(image)
+
+    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
+    params = SpectrogramParams(
+        min_frequency=0,
+        max_frequency=10000,
+    )
+
+    # Reconstruct audio from the image
+    # TODO(hayk): It may help performance to cache this object
+    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
+    segment = converter.audio_from_spectrogram_image(
+        image,
+        apply_filters=True,
+    )
+
+    mp3_bytes = io.BytesIO()
+    segment.export(mp3_bytes, format="mp3")
+    mp3_bytes.seek(0)
+    st.audio(mp3_bytes)
+
+
+if __name__ == "__main__":
+    render_interpolation_demo()
diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py
new file mode 100644
index 0000000..9e2c129
--- /dev/null
+++ b/riffusion/streamlit/pages/text_to_audio.py
@@ -0,0 +1,130 @@
+import io
+from pathlib import Path
+
+import dacite
+from diffusers import StableDiffusionPipeline
+import streamlit as st
+import torch
+from PIL import Image
+
+from riffusion.datatypes import InferenceInput
+from riffusion.spectrogram_image_converter import SpectrogramImageConverter
+from riffusion.spectrogram_params import SpectrogramParams
+from riffusion.streamlit import util as streamlit_util
+
+
+@st.experimental_singleton
+def load_stable_diffusion_pipeline(
+    checkpoint: str = "riffusion/riffusion-model-v1",
+    device: str = "cuda",
+    dtype: torch.dtype = torch.float16,
+) -> StableDiffusionPipeline:
+    """
+    Load the riffusion model as a StableDiffusionPipeline for text to image.
+    """
+    if device == "cpu" or device.lower().startswith("mps"):
+        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
+        dtype = torch.float32
+
+    return StableDiffusionPipeline.from_pretrained(
+        checkpoint,
+        revision="main",
+        torch_dtype=dtype,
+        safety_checker=lambda images, **kwargs: (images, False),
+    ).to(device)
+
+
+@st.experimental_memo
+def run_txt2img(
+    prompt: str,
+    num_inference_steps: int,
+    guidance: float,
+    negative_prompt: str,
+    seed: int,
+    width: int,
+    height: int,
+    device: str = "cuda",
+) -> Image.Image:
+    """
+    Run the text to image pipeline with caching.
+    """
+    pipeline = load_stable_diffusion_pipeline(device=device)
+
+    generator = torch.Generator(device="cpu").manual_seed(seed)
+
+    output = pipeline(
+        prompt=prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance,
+        negative_prompt=negative_prompt or None,
+        generator=generator,
+        width=width,
+        height=height,
+    )
+
+    return output["images"][0]
+
+
+def render_text_to_audio() -> None:
+    """
+    Render audio from text.
+    """
+    prompt = st.text_input("Prompt")
+    if not prompt:
+        st.info("Enter a prompt")
+        return
+
+    negative_prompt = st.text_input("Negative prompt")
+    seed = st.sidebar.number_input("Seed", value=42)
+    num_inference_steps = st.sidebar.number_input("Inference steps", value=20)
+    width = st.sidebar.number_input("Width", value=512)
+    height = st.sidebar.number_input("Height", value=512)
+    guidance = st.sidebar.number_input(
+        "Guidance", value=7.0, help="How much the model listens to the text prompt"
+    )
+
+    default_device = "cpu"
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+
+    device_options = ["cuda", "cpu", "mps"]
+    device = st.sidebar.selectbox(
+        "Device", options=device_options, index=device_options.index(default_device)
+    )
+    assert device is not None
+
+    image = run_txt2img(
+        prompt=prompt,
+        num_inference_steps=num_inference_steps,
+        guidance=guidance,
+        negative_prompt=negative_prompt,
+        seed=seed,
+        width=width,
+        height=height,
+        device=device,
+    )
+
+    st.image(image)
+
+    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
+    params = SpectrogramParams(
+        min_frequency=0,
+        max_frequency=10000,
+    )
+
+    segment = streamlit_util.audio_from_spectrogram_image(
+        image=image,
+        params=params,
+        device=device,
+    )
+
+    mp3_bytes = io.BytesIO()
+    segment.export(mp3_bytes, format="mp3")
+    mp3_bytes.seek(0)
+    st.audio(mp3_bytes)
+
+
+if __name__ == "__main__":
+    render_text_to_audio()
diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py
new file mode 100644
index 0000000..e0ad5a2
--- /dev/null
+++ b/riffusion/streamlit/util.py
@@ -0,0 +1,73 @@
+"""
+Streamlit utilities (mostly cached wrappers around riffusion code).
+"""
+
+import pydub
+import streamlit as st
+from PIL import Image
+
+from riffusion.riffusion_pipeline import RiffusionPipeline
+from riffusion.spectrogram_image_converter import SpectrogramImageConverter
+from riffusion.spectrogram_params import SpectrogramParams
+
+
+@st.experimental_singleton
+def load_riffusion_checkpoint(
+    checkpoint: str = "riffusion/riffusion-model-v1",
+    no_traced_unet: bool = False,
+    device: str = "cuda",
+) -> RiffusionPipeline:
+    """
+    Load the riffusion pipeline.
+    """
+    return RiffusionPipeline.load_checkpoint(
+        checkpoint=checkpoint,
+        use_traced_unet=not no_traced_unet,
+        device=device,
+    )
+
+# class CachedSpectrogramImageConverter:
+
+#     def __init__(self, params: SpectrogramParams, device: str = "cuda"):
+#         self.p = params
+#         self.device = device
+#         self.converter = self._converter(params, device)
+
+#     @staticmethod
+#     @st.experimental_singleton
+#     def _converter(params: SpectrogramParams, device: str) -> SpectrogramImageConverter:
+#         return SpectrogramImageConverter(params=params, device=device)
+
+#     def audio_from_spectrogram_image(
+#         self,
+#         image: Image.Image
+#     ) -> pydub.AudioSegment:
+#         return self._converter.audio_from_spectrogram_image(image)
+
+
+@st.experimental_singleton
+def spectrogram_image_converter(
+    params: SpectrogramParams,
+    device: str = "cuda",
+) -> SpectrogramImageConverter:
+    return SpectrogramImageConverter(params=params, device=device)
+
+
+@st.experimental_memo
+def audio_from_spectrogram_image(
+    image: Image.Image,
+    params: SpectrogramParams,
+    device: str = "cuda",
+) -> pydub.AudioSegment:
+    converter = spectrogram_image_converter(params=params, device=device)
+    return converter.audio_from_spectrogram_image(image)
+
+
+# @st.experimental_memo
+# def spectrogram_image_from_audio(
+#     segment: pydub.AudioSegment,
+#     params: SpectrogramParams,
+#     device: str = "cuda",
+# ) -> Image.Image:
+#     converter = spectrogram_image_converter(params=params, device=device)
+#     return converter.spectrogram_image_from_audio(segment)
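Note on usage (not part of the patch): the app is presumably launched through Streamlit's own runner, e.g. streamlit run riffusion/streamlit/main.py from the repository root, which should also pick up the pages/ directory as sidebar pages. The snippet below is a minimal sketch of the audio -> spectrogram image -> audio round trip that the cached helpers in riffusion/streamlit/util.py wrap; the input clip.mp3, the output path, and the device="cuda" choice are placeholder assumptions.

# Illustrative round-trip sketch: audio -> spectrogram image -> audio,
# using the same SpectrogramImageConverter API the patch relies on.
import pydub

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

# Same frequency range the pages use until the model is retrained for [20, 20k]
params = SpectrogramParams(min_frequency=0, max_frequency=10000)
converter = SpectrogramImageConverter(params=params, device="cuda")  # assumes a CUDA device

segment = pydub.AudioSegment.from_file("clip.mp3")  # placeholder input clip
image = converter.spectrogram_image_from_audio(segment)
roundtrip = converter.audio_from_spectrogram_image(image, apply_filters=True)
roundtrip.export("clip_roundtrip.mp3", format="mp3")  # placeholder output path

This mirrors what the commented-out spectrogram_image_from_audio wrapper in util.py would provide once it is enabled under Streamlit caching.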