Upgrade playground app to Streamlit 1.18+
The first change was using the new non-experimental cache decorators, but then I decided to refactor to get rid of using the streamlit pages feature and instead have my own dropdown. This allows for more control to fix a page layout issue that popped up with this version.
This commit is contained in:
parent
5a989fff9c
commit
a0f12d80e2
|
@ -13,7 +13,7 @@ pysoundfile
|
|||
scipy
|
||||
soundfile
|
||||
sox
|
||||
streamlit>=1.10.0
|
||||
streamlit>=1.18.0
|
||||
torch
|
||||
torchaudio
|
||||
torchvision
|
||||
|
|
|
@ -1,43 +1,29 @@
|
|||
import streamlit as st
|
||||
|
||||
|
||||
def render_main():
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.title(":guitar: Riffusion Playground")
|
||||
|
||||
left, right = st.columns(2)
|
||||
|
||||
with left:
|
||||
create_link(":pencil2: Text to Audio", "/text_to_audio")
|
||||
st.write("Generate audio clips from text prompts.")
|
||||
|
||||
create_link(":wave: Audio to Audio", "/audio_to_audio")
|
||||
st.write("Upload audio and modify with text prompt (interpolation supported).")
|
||||
|
||||
create_link(":performing_arts: Interpolation", "/interpolation")
|
||||
st.write("Interpolate between prompts in the latent space.")
|
||||
|
||||
create_link(":scissors: Audio Splitter", "/split_audio")
|
||||
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
|
||||
|
||||
with right:
|
||||
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
|
||||
st.write("Generate audio in batch from a JSON file of text prompts.")
|
||||
|
||||
create_link(":paperclip: Sample Clips", "/sample_clips")
|
||||
st.write("Export short clips from an audio file.")
|
||||
|
||||
create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
|
||||
st.write("Reconstruct audio from spectrogram images.")
|
||||
PAGES = {
|
||||
"🎛️ Home": "tasks.home",
|
||||
"🌊 Text to Audio": "tasks.text_to_audio",
|
||||
"✨ Audio to Audio": "tasks.audio_to_audio",
|
||||
"🎭 Interpolation": "tasks.interpolation",
|
||||
"✂️ Audio Splitter": "tasks.split_audio",
|
||||
"📜 Text to Audio Batch": "tasks.text_to_audio_batch",
|
||||
"📎 Sample Clips": "tasks.sample_clips",
|
||||
"⏈ Spectrogram to Audio": "tasks.image_to_audio",
|
||||
}
|
||||
|
||||
|
||||
def create_link(name: str, url: str) -> None:
|
||||
st.markdown(
|
||||
f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
|
||||
unsafe_allow_html=True,
|
||||
def main() -> None:
|
||||
st.set_page_config(
|
||||
page_title="Riffusion Playground",
|
||||
page_icon="🎸",
|
||||
layout="wide",
|
||||
)
|
||||
|
||||
page = st.sidebar.selectbox("Page", list(PAGES.keys()))
|
||||
assert page is not None
|
||||
module = __import__(PAGES[page], fromlist=["render"])
|
||||
module.render()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_main()
|
||||
main()
|
||||
|
|
|
@ -10,14 +10,12 @@ from PIL import Image
|
|||
from riffusion.datatypes import InferenceInput, PromptInput
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
|
||||
from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation
|
||||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def render_audio_to_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":wave: Audio to Audio")
|
||||
def render() -> None:
|
||||
st.subheader("✨ Audio to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Modify existing audio from a text prompt or interpolate between two.
|
||||
|
@ -408,7 +406,3 @@ def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
|
|||
closest_width = int(np.ceil(image.width / 32) * 32)
|
||||
closest_height = int(np.ceil(image.height / 32) * 32)
|
||||
return image.resize((closest_width, closest_height), Image.BICUBIC)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_audio_to_audio()
|
|
@ -0,0 +1,32 @@
|
|||
import streamlit as st
|
||||
|
||||
|
||||
def render():
|
||||
st.title("✨🎸 Riffusion Playground 🎸✨")
|
||||
|
||||
st.write("Select a task from the sidebar to get started!")
|
||||
|
||||
left, right = st.columns(2)
|
||||
|
||||
with left:
|
||||
st.subheader("🌊 Text to Audio")
|
||||
st.write("Generate audio clips from text prompts.")
|
||||
|
||||
st.subheader("✨ Audio to Audio")
|
||||
st.write("Upload audio and modify with text prompt (interpolation supported).")
|
||||
|
||||
st.subheader("🎭 Interpolation")
|
||||
st.write("Interpolate between prompts in the latent space.")
|
||||
|
||||
st.subheader("✂️ Audio Splitter")
|
||||
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
|
||||
|
||||
with right:
|
||||
st.subheader("📜 Text to Audio Batch")
|
||||
st.write("Generate audio in batch from a JSON file of text prompts.")
|
||||
|
||||
st.subheader("📎 Sample Clips")
|
||||
st.write("Export short clips from an audio file.")
|
||||
|
||||
st.subheader("⏈ Spectrogram to Audio")
|
||||
st.write("Reconstruct audio from spectrogram images.")
|
|
@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
|
|||
from riffusion.util.image_util import exif_from_image
|
||||
|
||||
|
||||
def render_image_to_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":musical_keyboard: Image to Audio")
|
||||
def render() -> None:
|
||||
st.subheader("⏈ Image to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Reconstruct audio from spectrogram images.
|
||||
|
@ -77,7 +75,3 @@ def render_image_to_audio() -> None:
|
|||
name=Path(image_file.name).stem,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_image_to_audio()
|
|
@ -13,10 +13,8 @@ from riffusion.spectrogram_params import SpectrogramParams
|
|||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_interpolation() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":performing_arts: Interpolation")
|
||||
def render() -> None:
|
||||
st.subheader("🎭 Interpolation")
|
||||
st.write(
|
||||
"""
|
||||
Interpolate between prompts in the latent space.
|
||||
|
@ -241,7 +239,7 @@ def get_prompt_inputs(
|
|||
return p
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
@st.cache_data
|
||||
def run_interpolation(
|
||||
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
|
||||
) -> T.Tuple[Image.Image, io.BytesIO]:
|
||||
|
@ -275,7 +273,3 @@ def run_interpolation(
|
|||
)
|
||||
|
||||
return image, audio_bytes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_interpolation()
|
|
@ -10,10 +10,8 @@ from riffusion.spectrogram_params import SpectrogramParams
|
|||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_sample_clips() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":paperclip: Sample Clips")
|
||||
def render() -> None:
|
||||
st.subheader("📎 Sample Clips")
|
||||
st.write(
|
||||
"""
|
||||
Export short clips from an audio file.
|
||||
|
@ -125,7 +123,3 @@ def render_sample_clips() -> None:
|
|||
|
||||
if save_to_disk:
|
||||
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_sample_clips()
|
|
@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
|
|||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def render_split_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":scissors: Audio Splitter")
|
||||
def render() -> None:
|
||||
st.subheader("✂️ Audio Splitter")
|
||||
st.write(
|
||||
"""
|
||||
Split audio into individual instrument stems.
|
||||
|
@ -99,7 +97,3 @@ def split_audio_cached(
|
|||
segment: pydub.AudioSegment, device: str = "cuda"
|
||||
) -> T.Dict[str, pydub.AudioSegment]:
|
||||
return split_audio(segment, device=device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_split_audio()
|
|
@ -6,10 +6,8 @@ from riffusion.spectrogram_params import SpectrogramParams
|
|||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_text_to_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":pencil2: Text to Audio")
|
||||
def render() -> None:
|
||||
st.subheader("🌊 Text to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Generate audio from text prompts.
|
||||
|
@ -119,7 +117,3 @@ def render_text_to_audio() -> None:
|
|||
)
|
||||
|
||||
seed += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_text_to_audio()
|
|
@ -14,7 +14,7 @@ EXAMPLE_INPUT = """
|
|||
"seed": 42,
|
||||
"num_inference_steps": 50,
|
||||
"guidance": 7.0,
|
||||
"width": 512,
|
||||
"width": 512
|
||||
},
|
||||
"entries": [
|
||||
{
|
||||
|
@ -32,10 +32,8 @@ EXAMPLE_INPUT = """
|
|||
"""
|
||||
|
||||
|
||||
def render_text_to_audio_batch() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":scroll: Text to Audio Batch")
|
||||
def render() -> None:
|
||||
st.subheader("📜 Text to Audio Batch")
|
||||
st.write(
|
||||
"""
|
||||
Generate audio in batch from a JSON file of text prompts.
|
||||
|
@ -141,7 +139,3 @@ def render_text_to_audio_batch() -> None:
|
|||
st.info(f"Output written to {str(output_path)}")
|
||||
else:
|
||||
st.info("Enter output directory in sidebar to save to disk")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_text_to_audio_batch()
|
|
@ -33,7 +33,7 @@ SCHEDULER_OPTIONS = [
|
|||
]
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def load_riffusion_checkpoint(
|
||||
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||
no_traced_unet: bool = False,
|
||||
|
@ -49,7 +49,7 @@ def load_riffusion_checkpoint(
|
|||
)
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def load_stable_diffusion_pipeline(
|
||||
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||
device: str = "cuda",
|
||||
|
@ -109,7 +109,7 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
|
|||
raise ValueError(f"Unknown scheduler {scheduler}")
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def pipeline_lock() -> threading.Lock:
|
||||
"""
|
||||
Singleton lock used to prevent concurrent access to any model pipeline.
|
||||
|
@ -117,7 +117,7 @@ def pipeline_lock() -> threading.Lock:
|
|||
return threading.Lock()
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def load_stable_diffusion_img2img_pipeline(
|
||||
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||
device: str = "cuda",
|
||||
|
@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
|
|||
return pipeline
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
@st.cache_data
|
||||
def run_txt2img(
|
||||
prompt: str,
|
||||
num_inference_steps: int,
|
||||
|
@ -184,7 +184,7 @@ def run_txt2img(
|
|||
return output["images"][0]
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def spectrogram_image_converter(
|
||||
params: SpectrogramParams,
|
||||
device: str = "cuda",
|
||||
|
@ -202,7 +202,7 @@ def spectrogram_image_from_audio(
|
|||
return converter.spectrogram_image_from_audio(segment)
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
@st.cache_data
|
||||
def audio_segment_from_spectrogram_image(
|
||||
image: Image.Image,
|
||||
params: SpectrogramParams,
|
||||
|
@ -212,7 +212,7 @@ def audio_segment_from_spectrogram_image(
|
|||
return converter.audio_from_spectrogram_image(image)
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
@st.cache_data
|
||||
def audio_bytes_from_spectrogram_image(
|
||||
image: Image.Image,
|
||||
params: SpectrogramParams,
|
||||
|
@ -289,17 +289,17 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
|
|||
return custom_checkpoint or DEFAULT_CHECKPOINT
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
@st.cache_data
|
||||
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
|
||||
return pydub.AudioSegment.from_file(audio_file)
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def get_audio_splitter(device: str = "cuda"):
|
||||
return AudioSplitter(device=device)
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
@st.cache_resource
|
||||
def load_magic_mix_pipeline(
|
||||
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||
device: str = "cuda",
|
||||
|
|
Loading…
Reference in New Issue