Upgrade playground app to Streamlit 1.18+
The first change was adopting the new non-experimental cache decorators, but I then decided to refactor away from Streamlit's built-in pages feature in favor of a sidebar dropdown of my own. That gives more control and fixes a page layout issue that popped up with this version.
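For context, Streamlit 1.18 stabilized the caching API, so the decorator change is a one-for-one swap: @st.experimental_memo becomes @st.cache_data and @st.experimental_singleton becomes @st.cache_resource. A minimal sketch of the distinction (the functions are illustrative, not from this repo):

    import threading

    import streamlit as st


    @st.cache_data  # replaces @st.experimental_memo
    def square(n: int) -> int:
        # Caches by argument value; each caller gets a copy of the result,
        # so mutating the returned object cannot corrupt the cache.
        return n * n


    @st.cache_resource  # replaces @st.experimental_singleton
    def global_lock() -> threading.Lock:
        # Returns the same object to every session and rerun, which is what
        # shared resources like model pipelines and locks need.
        return threading.Lock()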
parent 5a989fff9c · commit a0f12d80e2
@@ -13,7 +13,7 @@ pysoundfile
 scipy
 soundfile
 sox
-streamlit>=1.10.0
+streamlit>=1.18.0
 torch
 torchaudio
 torchvision
@@ -1,43 +1,29 @@
 import streamlit as st

-
-def render_main():
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.title(":guitar: Riffusion Playground")
-
-    left, right = st.columns(2)
-
-    with left:
-        create_link(":pencil2: Text to Audio", "/text_to_audio")
-        st.write("Generate audio clips from text prompts.")
-
-        create_link(":wave: Audio to Audio", "/audio_to_audio")
-        st.write("Upload audio and modify with text prompt (interpolation supported).")
-
-        create_link(":performing_arts: Interpolation", "/interpolation")
-        st.write("Interpolate between prompts in the latent space.")
-
-        create_link(":scissors: Audio Splitter", "/split_audio")
-        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
-
-    with right:
-        create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
-        st.write("Generate audio in batch from a JSON file of text prompts.")
-
-        create_link(":paperclip: Sample Clips", "/sample_clips")
-        st.write("Export short clips from an audio file.")
-
-        create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
-        st.write("Reconstruct audio from spectrogram images.")
-
-
-def create_link(name: str, url: str) -> None:
-    st.markdown(
-        f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
-        unsafe_allow_html=True,
-    )
+PAGES = {
+    "🎛️ Home": "tasks.home",
+    "🌊 Text to Audio": "tasks.text_to_audio",
+    "✨ Audio to Audio": "tasks.audio_to_audio",
+    "🎭 Interpolation": "tasks.interpolation",
+    "✂️ Audio Splitter": "tasks.split_audio",
+    "📜 Text to Audio Batch": "tasks.text_to_audio_batch",
+    "📎 Sample Clips": "tasks.sample_clips",
+    "⏈ Spectrogram to Audio": "tasks.image_to_audio",
+}
+
+
+def main() -> None:
+    st.set_page_config(
+        page_title="Riffusion Playground",
+        page_icon="🎸",
+        layout="wide",
+    )
+
+    page = st.sidebar.selectbox("Page", list(PAGES.keys()))
+    assert page is not None
+    module = __import__(PAGES[page], fromlist=["render"])
+    module.render()


 if __name__ == "__main__":
-    render_main()
+    main()
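Two notes on the new entrypoint: st.set_page_config may only be called once, before any other Streamlit command, which is why it now lives solely in main() instead of in every task page. And __import__ needs a non-empty fromlist to return the tasks.home submodule itself rather than the top-level tasks package; importlib is the more explicit spelling of the same lookup (a minimal sketch, assuming the tasks package is importable):

    import importlib

    # Equivalent to __import__("tasks.home", fromlist=["render"]), but
    # always returns the leaf submodule and reads more clearly:
    module = importlib.import_module("tasks.home")
    module.render()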
@@ -10,14 +10,12 @@ from PIL import Image
 from riffusion.datatypes import InferenceInput, PromptInput
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
-from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
+from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation
 from riffusion.util import audio_util


-def render_audio_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":wave: Audio to Audio")
+def render() -> None:
+    st.subheader("✨ Audio to Audio")
     st.write(
         """
         Modify existing audio from a text prompt or interpolate between two.
@@ -408,7 +406,3 @@ def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
     closest_width = int(np.ceil(image.width / 32) * 32)
     closest_height = int(np.ceil(image.height / 32) * 32)
     return image.resize((closest_width, closest_height), Image.BICUBIC)
-
-
-if __name__ == "__main__":
-    render_audio_to_audio()
@@ -0,0 +1,32 @@
+import streamlit as st
+
+
+def render():
+    st.title("✨🎸 Riffusion Playground 🎸✨")
+
+    st.write("Select a task from the sidebar to get started!")
+
+    left, right = st.columns(2)
+
+    with left:
+        st.subheader("🌊 Text to Audio")
+        st.write("Generate audio clips from text prompts.")
+
+        st.subheader("✨ Audio to Audio")
+        st.write("Upload audio and modify with text prompt (interpolation supported).")
+
+        st.subheader("🎭 Interpolation")
+        st.write("Interpolate between prompts in the latent space.")
+
+        st.subheader("✂️ Audio Splitter")
+        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
+
+    with right:
+        st.subheader("📜 Text to Audio Batch")
+        st.write("Generate audio in batch from a JSON file of text prompts.")
+
+        st.subheader("📎 Sample Clips")
+        st.write("Export short clips from an audio file.")
+
+        st.subheader("⏈ Spectrogram to Audio")
+        st.write("Reconstruct audio from spectrogram images.")
@@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
 from riffusion.util.image_util import exif_from_image


-def render_image_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":musical_keyboard: Image to Audio")
+def render() -> None:
+    st.subheader("⏈ Image to Audio")
     st.write(
         """
         Reconstruct audio from spectrogram images.
@@ -77,7 +75,3 @@ def render_image_to_audio() -> None:
             name=Path(image_file.name).stem,
             extension=extension,
         )
-
-
-if __name__ == "__main__":
-    render_image_to_audio()
@@ -13,10 +13,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util


-def render_interpolation() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":performing_arts: Interpolation")
+def render() -> None:
+    st.subheader("🎭 Interpolation")
     st.write(
         """
         Interpolate between prompts in the latent space.
@@ -241,7 +239,7 @@ def get_prompt_inputs(
     return p


-@st.experimental_memo
+@st.cache_data
 def run_interpolation(
     inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
 ) -> T.Tuple[Image.Image, io.BytesIO]:
@@ -275,7 +273,3 @@ def run_interpolation(
     )

     return image, audio_bytes
-
-
-if __name__ == "__main__":
-    render_interpolation()
@@ -10,10 +10,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util


-def render_sample_clips() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":paperclip: Sample Clips")
+def render() -> None:
+    st.subheader("📎 Sample Clips")
     st.write(
         """
         Export short clips from an audio file.
@@ -125,7 +123,3 @@ def render_sample_clips() -> None:

     if save_to_disk:
         st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
-
-
-if __name__ == "__main__":
-    render_sample_clips()
@@ -9,10 +9,8 @@ from riffusion.streamlit import util as streamlit_util
 from riffusion.util import audio_util


-def render_split_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":scissors: Audio Splitter")
+def render() -> None:
+    st.subheader("✂️ Audio Splitter")
     st.write(
         """
         Split audio into individual instrument stems.
@@ -99,7 +97,3 @@ def split_audio_cached(
     segment: pydub.AudioSegment, device: str = "cuda"
 ) -> T.Dict[str, pydub.AudioSegment]:
     return split_audio(segment, device=device)
-
-
-if __name__ == "__main__":
-    render_split_audio()
@@ -6,10 +6,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util


-def render_text_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":pencil2: Text to Audio")
+def render() -> None:
+    st.subheader("🌊 Text to Audio")
     st.write(
         """
         Generate audio from text prompts.
@@ -119,7 +117,3 @@ def render_text_to_audio() -> None:
        )

        seed += 1
-
-
-if __name__ == "__main__":
-    render_text_to_audio()
@ -14,7 +14,7 @@ EXAMPLE_INPUT = """
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"num_inference_steps": 50,
|
"num_inference_steps": 50,
|
||||||
"guidance": 7.0,
|
"guidance": 7.0,
|
||||||
"width": 512,
|
"width": 512
|
||||||
},
|
},
|
||||||
"entries": [
|
"entries": [
|
||||||
{
|
{
|
||||||
|
@ -32,10 +32,8 @@ EXAMPLE_INPUT = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def render_text_to_audio_batch() -> None:
|
def render() -> None:
|
||||||
st.set_page_config(layout="wide", page_icon="🎸")
|
st.subheader("📜 Text to Audio Batch")
|
||||||
|
|
||||||
st.subheader(":scroll: Text to Audio Batch")
|
|
||||||
st.write(
|
st.write(
|
||||||
"""
|
"""
|
||||||
Generate audio in batch from a JSON file of text prompts.
|
Generate audio in batch from a JSON file of text prompts.
|
||||||
|
@ -141,7 +139,3 @@ def render_text_to_audio_batch() -> None:
|
||||||
st.info(f"Output written to {str(output_path)}")
|
st.info(f"Output written to {str(output_path)}")
|
||||||
else:
|
else:
|
||||||
st.info("Enter output directory in sidebar to save to disk")
|
st.info("Enter output directory in sidebar to save to disk")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
render_text_to_audio_batch()
|
|
|
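The trailing comma dropped from EXAMPLE_INPUT above is a correctness fix, not just style: the example is strict JSON, and Python's json parser rejects trailing commas (a quick check, assuming the string is ultimately parsed with json.loads):

    import json

    json.loads('{"width": 512}')   # ok: {'width': 512}
    json.loads('{"width": 512,}')  # raises json.JSONDecodeError: Expecting property name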
@@ -33,7 +33,7 @@ SCHEDULER_OPTIONS = [
 ]


-@st.experimental_singleton
+@st.cache_resource
 def load_riffusion_checkpoint(
     checkpoint: str = DEFAULT_CHECKPOINT,
     no_traced_unet: bool = False,
@@ -49,7 +49,7 @@ def load_riffusion_checkpoint(
     )


-@st.experimental_singleton
+@st.cache_resource
 def load_stable_diffusion_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
@@ -109,7 +109,7 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
     raise ValueError(f"Unknown scheduler {scheduler}")


-@st.experimental_singleton
+@st.cache_resource
 def pipeline_lock() -> threading.Lock:
     """
     Singleton lock used to prevent concurrent access to any model pipeline.
@@ -117,7 +117,7 @@ def pipeline_lock() -> threading.Lock:
     return threading.Lock()


-@st.experimental_singleton
+@st.cache_resource
 def load_stable_diffusion_img2img_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
@@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
     return pipeline


-@st.experimental_memo
+@st.cache_data
 def run_txt2img(
     prompt: str,
     num_inference_steps: int,
@@ -184,7 +184,7 @@ def run_txt2img(
     return output["images"][0]


-@st.experimental_singleton
+@st.cache_resource
 def spectrogram_image_converter(
     params: SpectrogramParams,
     device: str = "cuda",
@@ -202,7 +202,7 @@ def spectrogram_image_from_audio(
     return converter.spectrogram_image_from_audio(segment)


-@st.experimental_memo
+@st.cache_data
 def audio_segment_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
@@ -212,7 +212,7 @@ def audio_segment_from_spectrogram_image(
     return converter.audio_from_spectrogram_image(image)


-@st.experimental_memo
+@st.cache_data
 def audio_bytes_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
@@ -289,17 +289,17 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
     return custom_checkpoint or DEFAULT_CHECKPOINT


-@st.experimental_memo
+@st.cache_data
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)


-@st.experimental_singleton
+@st.cache_resource
 def get_audio_splitter(device: str = "cuda"):
     return AudioSplitter(device=device)


-@st.experimental_singleton
+@st.cache_resource
 def load_magic_mix_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",