From a0f12d80e2115c45c6810e4c13673bc8b367be3a Mon Sep 17 00:00:00 2001
From: Hayk Martiros
Date: Sun, 26 Mar 2023 22:54:28 +0000
Subject: [PATCH] Upgrade playground app to Streamlit 1.18+

The first change was switching to the new non-experimental cache
decorators. I then refactored to drop the Streamlit pages feature in
favor of my own dropdown, which gives more control and fixes a page
layout issue that popped up with this version.
---
 VERSION                                             |  2 +-
 requirements.txt                                    |  2 +-
 riffusion/streamlit/playground.py                   | 56 +++++++------------
 riffusion/streamlit/tasks/__init__.py               |  0
 .../{pages => tasks}/audio_to_audio.py              | 12 +---
 riffusion/streamlit/tasks/home.py                   | 32 +++++++++++
 .../{pages => tasks}/image_to_audio.py              | 10 +---
 .../{pages => tasks}/interpolation.py               | 12 +---
 .../{pages => tasks}/sample_clips.py                | 10 +---
 riffusion/streamlit/{pages => tasks}/split_audio.py | 10 +---
 .../{pages => tasks}/text_to_audio.py               | 10 +---
 .../{pages => tasks}/text_to_audio_batch.py         | 12 +---
 riffusion/streamlit/util.py                         | 22 ++++----
 13 files changed, 83 insertions(+), 107 deletions(-)
 create mode 100644 riffusion/streamlit/tasks/__init__.py
 rename riffusion/streamlit/{pages => tasks}/audio_to_audio.py (98%)
 create mode 100644 riffusion/streamlit/tasks/home.py
 rename riffusion/streamlit/{pages => tasks}/image_to_audio.py (91%)
 rename riffusion/streamlit/{pages => tasks}/interpolation.py (97%)
 rename riffusion/streamlit/{pages => tasks}/sample_clips.py (95%)
 rename riffusion/streamlit/{pages => tasks}/split_audio.py (94%)
 rename riffusion/streamlit/{pages => tasks}/text_to_audio.py (94%)
 rename riffusion/streamlit/{pages => tasks}/text_to_audio_batch.py (94%)

diff --git a/VERSION b/VERSION
index a2268e2..9e11b32 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.3.1
\ No newline at end of file
+0.3.1
diff --git a/requirements.txt b/requirements.txt
index 38e7a7b..4221960 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ pysoundfile
 scipy
 soundfile
 sox
-streamlit>=1.10.0
+streamlit>=1.18.0
 torch
 torchaudio
 torchvision
diff --git a/riffusion/streamlit/playground.py b/riffusion/streamlit/playground.py
index f705803..1104159 100644
--- a/riffusion/streamlit/playground.py
+++ b/riffusion/streamlit/playground.py
@@ -1,43 +1,29 @@
 import streamlit as st
 
-
-def render_main():
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.title(":guitar: Riffusion Playground")
-
-    left, right = st.columns(2)
-
-    with left:
-        create_link(":pencil2: Text to Audio", "/text_to_audio")
-        st.write("Generate audio clips from text prompts.")
-
-        create_link(":wave: Audio to Audio", "/audio_to_audio")
-        st.write("Upload audio and modify with text prompt (interpolation supported).")
-
-        create_link(":performing_arts: Interpolation", "/interpolation")
-        st.write("Interpolate between prompts in the latent space.")
-
-        create_link(":scissors: Audio Splitter", "/split_audio")
-        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
-
-    with right:
-        create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
-        st.write("Generate audio in batch from a JSON file of text prompts.")
-
-        create_link(":paperclip: Sample Clips", "/sample_clips")
-        st.write("Export short clips from an audio file.")
-
-        create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
-        st.write("Reconstruct audio from spectrogram images.")
+PAGES = {
+    "🎛️ Home": "tasks.home",
+    "🌊 Text to Audio": "tasks.text_to_audio",
+    "✨ Audio to Audio": "tasks.audio_to_audio",
+    "🎭 Interpolation": "tasks.interpolation",
"tasks.interpolation", + "βœ‚οΈ Audio Splitter": "tasks.split_audio", + "πŸ“œ Text to Audio Batch": "tasks.text_to_audio_batch", + "πŸ“Ž Sample Clips": "tasks.sample_clips", + "⏈ Spectrogram to Audio": "tasks.image_to_audio", +} -def create_link(name: str, url: str) -> None: - st.markdown( - f"### {name}", - unsafe_allow_html=True, +def main() -> None: + st.set_page_config( + page_title="Riffusion Playground", + page_icon="🎸", + layout="wide", ) + page = st.sidebar.selectbox("Page", list(PAGES.keys())) + assert page is not None + module = __import__(PAGES[page], fromlist=["render"]) + module.render() + if __name__ == "__main__": - render_main() + main() diff --git a/riffusion/streamlit/tasks/__init__.py b/riffusion/streamlit/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/riffusion/streamlit/pages/audio_to_audio.py b/riffusion/streamlit/tasks/audio_to_audio.py similarity index 98% rename from riffusion/streamlit/pages/audio_to_audio.py rename to riffusion/streamlit/tasks/audio_to_audio.py index 9240c0e..ee866c0 100644 --- a/riffusion/streamlit/pages/audio_to_audio.py +++ b/riffusion/streamlit/tasks/audio_to_audio.py @@ -10,14 +10,12 @@ from PIL import Image from riffusion.datatypes import InferenceInput, PromptInput from riffusion.spectrogram_params import SpectrogramParams from riffusion.streamlit import util as streamlit_util -from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation +from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation from riffusion.util import audio_util -def render_audio_to_audio() -> None: - st.set_page_config(layout="wide", page_icon="🎸") - - st.subheader(":wave: Audio to Audio") +def render() -> None: + st.subheader("✨ Audio to Audio") st.write( """ Modify existing audio from a text prompt or interpolate between two. 
@@ -408,7 +406,3 @@ def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
     closest_width = int(np.ceil(image.width / 32) * 32)
     closest_height = int(np.ceil(image.height / 32) * 32)
     return image.resize((closest_width, closest_height), Image.BICUBIC)
-
-
-if __name__ == "__main__":
-    render_audio_to_audio()
diff --git a/riffusion/streamlit/tasks/home.py b/riffusion/streamlit/tasks/home.py
new file mode 100644
index 0000000..539e9e4
--- /dev/null
+++ b/riffusion/streamlit/tasks/home.py
@@ -0,0 +1,32 @@
+import streamlit as st
+
+
+def render():
+    st.title("✨🎸 Riffusion Playground 🎸✨")
+
+    st.write("Select a task from the sidebar to get started!")
+
+    left, right = st.columns(2)
+
+    with left:
+        st.subheader("🌊 Text to Audio")
+        st.write("Generate audio clips from text prompts.")
+
+        st.subheader("✨ Audio to Audio")
+        st.write("Upload audio and modify with text prompt (interpolation supported).")
+
+        st.subheader("🎭 Interpolation")
+        st.write("Interpolate between prompts in the latent space.")
+
+        st.subheader("✂️ Audio Splitter")
+        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
+
+    with right:
+        st.subheader("📜 Text to Audio Batch")
+        st.write("Generate audio in batch from a JSON file of text prompts.")
+
+        st.subheader("📎 Sample Clips")
+        st.write("Export short clips from an audio file.")
+
+        st.subheader("⏈ Spectrogram to Audio")
+        st.write("Reconstruct audio from spectrogram images.")
diff --git a/riffusion/streamlit/pages/image_to_audio.py b/riffusion/streamlit/tasks/image_to_audio.py
similarity index 91%
rename from riffusion/streamlit/pages/image_to_audio.py
rename to riffusion/streamlit/tasks/image_to_audio.py
index f330cc6..d4e25d6 100644
--- a/riffusion/streamlit/pages/image_to_audio.py
+++ b/riffusion/streamlit/tasks/image_to_audio.py
@@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
 from riffusion.util.image_util import exif_from_image
 
 
-def render_image_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":musical_keyboard: Image to Audio")
+def render() -> None:
+    st.subheader("⏈ Image to Audio")
     st.write(
         """
     Reconstruct audio from spectrogram images.
@@ -77,7 +75,3 @@ def render_image_to_audio() -> None:
         name=Path(image_file.name).stem,
         extension=extension,
     )
-
-
-if __name__ == "__main__":
-    render_image_to_audio()
diff --git a/riffusion/streamlit/pages/interpolation.py b/riffusion/streamlit/tasks/interpolation.py
similarity index 97%
rename from riffusion/streamlit/pages/interpolation.py
rename to riffusion/streamlit/tasks/interpolation.py
index d7f47c4..2bb708b 100644
--- a/riffusion/streamlit/pages/interpolation.py
+++ b/riffusion/streamlit/tasks/interpolation.py
@@ -13,10 +13,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
 
 
-def render_interpolation() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":performing_arts: Interpolation")
+def render() -> None:
+    st.subheader("🎭 Interpolation")
     st.write(
         """
     Interpolate between prompts in the latent space.
@@ -241,7 +239,7 @@ def get_prompt_inputs(
     return p
 
 
-@st.experimental_memo
+@st.cache_data
 def run_interpolation(
     inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
 ) -> T.Tuple[Image.Image, io.BytesIO]:
@@ -275,7 +273,3 @@ def run_interpolation(
     )
 
     return image, audio_bytes
-
-
-if __name__ == "__main__":
-    render_interpolation()
diff --git a/riffusion/streamlit/pages/sample_clips.py b/riffusion/streamlit/tasks/sample_clips.py
similarity index 95%
rename from riffusion/streamlit/pages/sample_clips.py
rename to riffusion/streamlit/tasks/sample_clips.py
index 9acf3c1..53494d4 100644
--- a/riffusion/streamlit/pages/sample_clips.py
+++ b/riffusion/streamlit/tasks/sample_clips.py
@@ -10,10 +10,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
 
 
-def render_sample_clips() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":paperclip: Sample Clips")
+def render() -> None:
+    st.subheader("📎 Sample Clips")
     st.write(
         """
     Export short clips from an audio file.
@@ -125,7 +123,3 @@ def render_sample_clips() -> None:
 
     if save_to_disk:
         st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
-
-
-if __name__ == "__main__":
-    render_sample_clips()
diff --git a/riffusion/streamlit/pages/split_audio.py b/riffusion/streamlit/tasks/split_audio.py
similarity index 94%
rename from riffusion/streamlit/pages/split_audio.py
rename to riffusion/streamlit/tasks/split_audio.py
index 69720a5..c98802f 100644
--- a/riffusion/streamlit/pages/split_audio.py
+++ b/riffusion/streamlit/tasks/split_audio.py
@@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
 from riffusion.util import audio_util
 
 
-def render_split_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":scissors: Audio Splitter")
+def render() -> None:
+    st.subheader("✂️ Audio Splitter")
     st.write(
         """
     Split audio into individual instrument stems.
@@ -99,7 +97,3 @@ def split_audio_cached(
     segment: pydub.AudioSegment, device: str = "cuda"
 ) -> T.Dict[str, pydub.AudioSegment]:
     return split_audio(segment, device=device)
-
-
-if __name__ == "__main__":
-    render_split_audio()
diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/tasks/text_to_audio.py
similarity index 94%
rename from riffusion/streamlit/pages/text_to_audio.py
rename to riffusion/streamlit/tasks/text_to_audio.py
index 9630ee8..c70e431 100644
--- a/riffusion/streamlit/pages/text_to_audio.py
+++ b/riffusion/streamlit/tasks/text_to_audio.py
@@ -6,10 +6,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
 
 
-def render_text_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":pencil2: Text to Audio")
+def render() -> None:
+    st.subheader("🌊 Text to Audio")
     st.write(
         """
     Generate audio from text prompts.
@@ -119,7 +117,3 @@ def render_text_to_audio() -> None:
             )
 
             seed += 1
-
-
-if __name__ == "__main__":
-    render_text_to_audio()
diff --git a/riffusion/streamlit/pages/text_to_audio_batch.py b/riffusion/streamlit/tasks/text_to_audio_batch.py
similarity index 94%
rename from riffusion/streamlit/pages/text_to_audio_batch.py
rename to riffusion/streamlit/tasks/text_to_audio_batch.py
index 8763011..fa2fa42 100644
--- a/riffusion/streamlit/pages/text_to_audio_batch.py
+++ b/riffusion/streamlit/tasks/text_to_audio_batch.py
@@ -14,7 +14,7 @@ EXAMPLE_INPUT = """
         "seed": 42,
         "num_inference_steps": 50,
         "guidance": 7.0,
-        "width": 512,
+        "width": 512
     },
     "entries": [
         {
@@ -32,10 +32,8 @@ EXAMPLE_INPUT = """
 """
 
 
-def render_text_to_audio_batch() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":scroll: Text to Audio Batch")
+def render() -> None:
+    st.subheader("📜 Text to Audio Batch")
     st.write(
         """
     Generate audio in batch from a JSON file of text prompts.
@@ -141,7 +139,3 @@ def render_text_to_audio_batch() -> None:
         st.info(f"Output written to {str(output_path)}")
     else:
         st.info("Enter output directory in sidebar to save to disk")
-
-
-if __name__ == "__main__":
-    render_text_to_audio_batch()
diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py
index bb51035..dfccc68 100644
--- a/riffusion/streamlit/util.py
+++ b/riffusion/streamlit/util.py
@@ -33,7 +33,7 @@ SCHEDULER_OPTIONS = [
 ]
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_riffusion_checkpoint(
     checkpoint: str = DEFAULT_CHECKPOINT,
     no_traced_unet: bool = False,
@@ -49,7 +49,7 @@ def load_riffusion_checkpoint(
     )
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_stable_diffusion_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
@@ -109,7 +109,7 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
     raise ValueError(f"Unknown scheduler {scheduler}")
 
 
-@st.experimental_singleton
+@st.cache_resource
 def pipeline_lock() -> threading.Lock:
     """
     Singleton lock used to prevent concurrent access to any model pipeline.
@@ -117,7 +117,7 @@ def pipeline_lock() -> threading.Lock:
     return threading.Lock()
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_stable_diffusion_img2img_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
@@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
     return pipeline
 
 
-@st.experimental_memo
+@st.cache_data
 def run_txt2img(
     prompt: str,
     num_inference_steps: int,
@@ -184,7 +184,7 @@ def run_txt2img(
     return output["images"][0]
 
 
-@st.experimental_singleton
+@st.cache_resource
 def spectrogram_image_converter(
     params: SpectrogramParams,
     device: str = "cuda",
@@ -202,7 +202,7 @@ def spectrogram_image_from_audio(
     return converter.spectrogram_image_from_audio(segment)
 
 
-@st.experimental_memo
+@st.cache_data
 def audio_segment_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
@@ -212,7 +212,7 @@ def audio_segment_from_spectrogram_image(
     return converter.audio_from_spectrogram_image(image)
 
 
-@st.experimental_memo
+@st.cache_data
 def audio_bytes_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
@@ -289,17 +289,17 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
     return custom_checkpoint or DEFAULT_CHECKPOINT
 
 
-@st.experimental_memo
+@st.cache_data
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)
 
 
-@st.experimental_singleton
+@st.cache_resource
 def get_audio_splitter(device: str = "cuda"):
     return AudioSplitter(device=device)
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_magic_mix_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
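
Note: under this layout, the dropdown dispatch only requires that each
module in riffusion/streamlit/tasks/ expose a module-level render()
function and be registered in PAGES. A minimal sketch of what a new task
would look like (hypothetical tasks/my_task.py, not part of this patch):

    import streamlit as st

    def render() -> None:
        # Page config is now set once in playground.main(), so a task
        # only draws its own widgets.
        st.subheader("🔊 My Task")
        st.write("Widgets for the task go here.")

After adding an entry like "🔊 My Task": "tasks.my_task" to PAGES,
playground.py loads it via __import__("tasks.my_task", fromlist=["render"])
and calls render() when the page is selected in the sidebar.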