Upgrade playground app to Streamlit 1.18+

The first change was using the new non-experimental cache decorators,
but then I decided to refactor to get rid of using the streamlit pages
feature and instead have my own dropdown. This allows for more control
to fix a page layout issue that popped up with this version.
This commit is contained in:
Hayk Martiros 2023-03-26 22:54:28 +00:00
parent 5a989fff9c
commit a0f12d80e2
13 changed files with 83 additions and 107 deletions

View File

@ -13,7 +13,7 @@ pysoundfile
scipy scipy
soundfile soundfile
sox sox
streamlit>=1.10.0 streamlit>=1.18.0
torch torch
torchaudio torchaudio
torchvision torchvision

View File

@ -1,43 +1,29 @@
import streamlit as st import streamlit as st
PAGES = {
def render_main(): "🎛️ Home": "tasks.home",
st.set_page_config(layout="wide", page_icon="🎸") "🌊 Text to Audio": "tasks.text_to_audio",
"✨ Audio to Audio": "tasks.audio_to_audio",
st.title(":guitar: Riffusion Playground") "🎭 Interpolation": "tasks.interpolation",
"✂️ Audio Splitter": "tasks.split_audio",
left, right = st.columns(2) "📜 Text to Audio Batch": "tasks.text_to_audio_batch",
"📎 Sample Clips": "tasks.sample_clips",
with left: "⏈ Spectrogram to Audio": "tasks.image_to_audio",
create_link(":pencil2: Text to Audio", "/text_to_audio") }
st.write("Generate audio clips from text prompts.")
create_link(":wave: Audio to Audio", "/audio_to_audio")
st.write("Upload audio and modify with text prompt (interpolation supported).")
create_link(":performing_arts: Interpolation", "/interpolation")
st.write("Interpolate between prompts in the latent space.")
create_link(":scissors: Audio Splitter", "/split_audio")
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
with right:
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
st.write("Generate audio in batch from a JSON file of text prompts.")
create_link(":paperclip: Sample Clips", "/sample_clips")
st.write("Export short clips from an audio file.")
create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
st.write("Reconstruct audio from spectrogram images.")
def create_link(name: str, url: str) -> None: def main() -> None:
st.markdown( st.set_page_config(
f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>", page_title="Riffusion Playground",
unsafe_allow_html=True, page_icon="🎸",
layout="wide",
) )
page = st.sidebar.selectbox("Page", list(PAGES.keys()))
assert page is not None
module = __import__(PAGES[page], fromlist=["render"])
module.render()
if __name__ == "__main__": if __name__ == "__main__":
render_main() main()

View File

View File

@ -10,14 +10,12 @@ from PIL import Image
from riffusion.datatypes import InferenceInput, PromptInput from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util from riffusion.streamlit import util as streamlit_util
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation
from riffusion.util import audio_util from riffusion.util import audio_util
def render_audio_to_audio() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("✨ Audio to Audio")
st.subheader(":wave: Audio to Audio")
st.write( st.write(
""" """
Modify existing audio from a text prompt or interpolate between two. Modify existing audio from a text prompt or interpolate between two.
@ -408,7 +406,3 @@ def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
closest_width = int(np.ceil(image.width / 32) * 32) closest_width = int(np.ceil(image.width / 32) * 32)
closest_height = int(np.ceil(image.height / 32) * 32) closest_height = int(np.ceil(image.height / 32) * 32)
return image.resize((closest_width, closest_height), Image.BICUBIC) return image.resize((closest_width, closest_height), Image.BICUBIC)
if __name__ == "__main__":
render_audio_to_audio()

View File

@ -0,0 +1,32 @@
import streamlit as st
def render():
st.title("✨🎸 Riffusion Playground 🎸✨")
st.write("Select a task from the sidebar to get started!")
left, right = st.columns(2)
with left:
st.subheader("🌊 Text to Audio")
st.write("Generate audio clips from text prompts.")
st.subheader("✨ Audio to Audio")
st.write("Upload audio and modify with text prompt (interpolation supported).")
st.subheader("🎭 Interpolation")
st.write("Interpolate between prompts in the latent space.")
st.subheader("✂️ Audio Splitter")
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
with right:
st.subheader("📜 Text to Audio Batch")
st.write("Generate audio in batch from a JSON file of text prompts.")
st.subheader("📎 Sample Clips")
st.write("Export short clips from an audio file.")
st.subheader("⏈ Spectrogram to Audio")
st.write("Reconstruct audio from spectrogram images.")

View File

@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
from riffusion.util.image_util import exif_from_image from riffusion.util.image_util import exif_from_image
def render_image_to_audio() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("⏈ Image to Audio")
st.subheader(":musical_keyboard: Image to Audio")
st.write( st.write(
""" """
Reconstruct audio from spectrogram images. Reconstruct audio from spectrogram images.
@ -77,7 +75,3 @@ def render_image_to_audio() -> None:
name=Path(image_file.name).stem, name=Path(image_file.name).stem,
extension=extension, extension=extension,
) )
if __name__ == "__main__":
render_image_to_audio()

View File

@ -13,10 +13,8 @@ from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util from riffusion.streamlit import util as streamlit_util
def render_interpolation() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("🎭 Interpolation")
st.subheader(":performing_arts: Interpolation")
st.write( st.write(
""" """
Interpolate between prompts in the latent space. Interpolate between prompts in the latent space.
@ -241,7 +239,7 @@ def get_prompt_inputs(
return p return p
@st.experimental_memo @st.cache_data
def run_interpolation( def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3" inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
) -> T.Tuple[Image.Image, io.BytesIO]: ) -> T.Tuple[Image.Image, io.BytesIO]:
@ -275,7 +273,3 @@ def run_interpolation(
) )
return image, audio_bytes return image, audio_bytes
if __name__ == "__main__":
render_interpolation()

View File

@ -10,10 +10,8 @@ from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util from riffusion.streamlit import util as streamlit_util
def render_sample_clips() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("📎 Sample Clips")
st.subheader(":paperclip: Sample Clips")
st.write( st.write(
""" """
Export short clips from an audio file. Export short clips from an audio file.
@ -125,7 +123,3 @@ def render_sample_clips() -> None:
if save_to_disk: if save_to_disk:
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`") st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
if __name__ == "__main__":
render_sample_clips()

View File

@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
from riffusion.util import audio_util from riffusion.util import audio_util
def render_split_audio() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("✂️ Audio Splitter")
st.subheader(":scissors: Audio Splitter")
st.write( st.write(
""" """
Split audio into individual instrument stems. Split audio into individual instrument stems.
@ -99,7 +97,3 @@ def split_audio_cached(
segment: pydub.AudioSegment, device: str = "cuda" segment: pydub.AudioSegment, device: str = "cuda"
) -> T.Dict[str, pydub.AudioSegment]: ) -> T.Dict[str, pydub.AudioSegment]:
return split_audio(segment, device=device) return split_audio(segment, device=device)
if __name__ == "__main__":
render_split_audio()

View File

@ -6,10 +6,8 @@ from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util from riffusion.streamlit import util as streamlit_util
def render_text_to_audio() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("🌊 Text to Audio")
st.subheader(":pencil2: Text to Audio")
st.write( st.write(
""" """
Generate audio from text prompts. Generate audio from text prompts.
@ -119,7 +117,3 @@ def render_text_to_audio() -> None:
) )
seed += 1 seed += 1
if __name__ == "__main__":
render_text_to_audio()

View File

@ -14,7 +14,7 @@ EXAMPLE_INPUT = """
"seed": 42, "seed": 42,
"num_inference_steps": 50, "num_inference_steps": 50,
"guidance": 7.0, "guidance": 7.0,
"width": 512, "width": 512
}, },
"entries": [ "entries": [
{ {
@ -32,10 +32,8 @@ EXAMPLE_INPUT = """
""" """
def render_text_to_audio_batch() -> None: def render() -> None:
st.set_page_config(layout="wide", page_icon="🎸") st.subheader("📜 Text to Audio Batch")
st.subheader(":scroll: Text to Audio Batch")
st.write( st.write(
""" """
Generate audio in batch from a JSON file of text prompts. Generate audio in batch from a JSON file of text prompts.
@ -141,7 +139,3 @@ def render_text_to_audio_batch() -> None:
st.info(f"Output written to {str(output_path)}") st.info(f"Output written to {str(output_path)}")
else: else:
st.info("Enter output directory in sidebar to save to disk") st.info("Enter output directory in sidebar to save to disk")
if __name__ == "__main__":
render_text_to_audio_batch()

View File

@ -33,7 +33,7 @@ SCHEDULER_OPTIONS = [
] ]
@st.experimental_singleton @st.cache_resource
def load_riffusion_checkpoint( def load_riffusion_checkpoint(
checkpoint: str = DEFAULT_CHECKPOINT, checkpoint: str = DEFAULT_CHECKPOINT,
no_traced_unet: bool = False, no_traced_unet: bool = False,
@ -49,7 +49,7 @@ def load_riffusion_checkpoint(
) )
@st.experimental_singleton @st.cache_resource
def load_stable_diffusion_pipeline( def load_stable_diffusion_pipeline(
checkpoint: str = DEFAULT_CHECKPOINT, checkpoint: str = DEFAULT_CHECKPOINT,
device: str = "cuda", device: str = "cuda",
@ -109,7 +109,7 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
raise ValueError(f"Unknown scheduler {scheduler}") raise ValueError(f"Unknown scheduler {scheduler}")
@st.experimental_singleton @st.cache_resource
def pipeline_lock() -> threading.Lock: def pipeline_lock() -> threading.Lock:
""" """
Singleton lock used to prevent concurrent access to any model pipeline. Singleton lock used to prevent concurrent access to any model pipeline.
@ -117,7 +117,7 @@ def pipeline_lock() -> threading.Lock:
return threading.Lock() return threading.Lock()
@st.experimental_singleton @st.cache_resource
def load_stable_diffusion_img2img_pipeline( def load_stable_diffusion_img2img_pipeline(
checkpoint: str = DEFAULT_CHECKPOINT, checkpoint: str = DEFAULT_CHECKPOINT,
device: str = "cuda", device: str = "cuda",
@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
return pipeline return pipeline
@st.experimental_memo @st.cache_data
def run_txt2img( def run_txt2img(
prompt: str, prompt: str,
num_inference_steps: int, num_inference_steps: int,
@ -184,7 +184,7 @@ def run_txt2img(
return output["images"][0] return output["images"][0]
@st.experimental_singleton @st.cache_resource
def spectrogram_image_converter( def spectrogram_image_converter(
params: SpectrogramParams, params: SpectrogramParams,
device: str = "cuda", device: str = "cuda",
@ -202,7 +202,7 @@ def spectrogram_image_from_audio(
return converter.spectrogram_image_from_audio(segment) return converter.spectrogram_image_from_audio(segment)
@st.experimental_memo @st.cache_data
def audio_segment_from_spectrogram_image( def audio_segment_from_spectrogram_image(
image: Image.Image, image: Image.Image,
params: SpectrogramParams, params: SpectrogramParams,
@ -212,7 +212,7 @@ def audio_segment_from_spectrogram_image(
return converter.audio_from_spectrogram_image(image) return converter.audio_from_spectrogram_image(image)
@st.experimental_memo @st.cache_data
def audio_bytes_from_spectrogram_image( def audio_bytes_from_spectrogram_image(
image: Image.Image, image: Image.Image,
params: SpectrogramParams, params: SpectrogramParams,
@ -289,17 +289,17 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
return custom_checkpoint or DEFAULT_CHECKPOINT return custom_checkpoint or DEFAULT_CHECKPOINT
@st.experimental_memo @st.cache_data
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment: def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
return pydub.AudioSegment.from_file(audio_file) return pydub.AudioSegment.from_file(audio_file)
@st.experimental_singleton @st.cache_resource
def get_audio_splitter(device: str = "cuda"): def get_audio_splitter(device: str = "cuda"):
return AudioSplitter(device=device) return AudioSplitter(device=device)
@st.experimental_singleton @st.cache_resource
def load_magic_mix_pipeline( def load_magic_mix_pipeline(
checkpoint: str = DEFAULT_CHECKPOINT, checkpoint: str = DEFAULT_CHECKPOINT,
device: str = "cuda", device: str = "cuda",