Streamlit app for interactive use of the model
Topic: streamlit_app
parent e8b99fabf9
commit 39dc247a1d
@@ -0,0 +1,3 @@
# streamlit

This package is an interactive streamlit app for riffusion.
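Usage note: the diff below only reveals the `riffusion.streamlit` package path (via its imports), not the demo file names, so the exact invocation is an assumption. Each demo is a standalone Streamlit entrypoint and would be launched along the lines of `streamlit run riffusion/streamlit/<demo>.py`.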
@@ -0,0 +1,25 @@
import pydub
import streamlit as st


def run():
    st.set_page_config(layout="wide", page_icon="🎸")

    audio_file = st.file_uploader("Upload a file", type=["wav", "mp3", "ogg"])
    if not audio_file:
        st.info("Upload an audio file to get started")
        return

    st.audio(audio_file)

    segment = pydub.AudioSegment.from_file(audio_file)
    st.write(" \n".join([
        f"**Duration**: {segment.duration_seconds:.3f} seconds",
        f"**Channels**: {segment.channels}",
        f"**Sample rate**: {segment.frame_rate} Hz",
        f"**Sample width**: {segment.sample_width} bytes",
    ]))


if __name__ == "__main__":
    run()
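For context, the fields displayed above come straight from pydub. A minimal headless sketch of the same inspection (the file path here is hypothetical):

import pydub

# Decode any ffmpeg-supported file into an AudioSegment (hypothetical path).
segment = pydub.AudioSegment.from_file("example.mp3")

# duration_seconds is frame count / frame rate; sample_width is bytes per sample.
print(
    f"{segment.duration_seconds:.3f} s, {segment.channels} ch, "
    f"{segment.frame_rate} Hz, {segment.sample_width} bytes/sample"
)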
@@ -0,0 +1,51 @@
import io

import streamlit as st
from PIL import Image

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.util.image_util import exif_from_image


def render_image_to_audio() -> None:
    image_file = st.sidebar.file_uploader(
        "Upload a file",
        type=["png", "jpg", "jpeg"],
        label_visibility="collapsed",
    )
    if not image_file:
        st.info("Upload an image file to get started")
        return

    image = Image.open(image_file)
    st.image(image)

    exif = exif_from_image(image)
    st.write("Exif data:")
    st.write(exif)

    device = "cuda"

    try:
        params = SpectrogramParams.from_exif(exif=image.getexif())
    except KeyError:
        st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
        params = SpectrogramParams()

    # segment = streamlit_util.audio_from_spectrogram_image(
    #     image=image,
    #     params=params,
    #     device=device,
    # )

    # mp3_bytes = io.BytesIO()
    # segment.export(mp3_bytes, format="mp3")
    # mp3_bytes.seek(0)

    # st.audio(mp3_bytes)


if __name__ == "__main__":
    render_image_to_audio()
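The commented-out tail of that function sketches the intended round trip: reconstruct audio from the spectrogram image, encode it to mp3 in memory, and play it. Once enabled it would look roughly like this, expressed as a self-contained helper (the function name `play_spectrogram_image` is hypothetical; `streamlit_util.audio_from_spectrogram_image` is defined later in this commit):

import io

import streamlit as st
from PIL import Image

from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def play_spectrogram_image(image: Image.Image, params: SpectrogramParams, device: str) -> None:
    # Reconstruct a pydub.AudioSegment from the spectrogram image (cached wrapper).
    segment = streamlit_util.audio_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
    )
    # Encode to mp3 in memory and hand the buffer to the Streamlit audio player.
    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)
    st.audio(mp3_bytes)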
@@ -0,0 +1,97 @@
import io
from pathlib import Path

import dacite
import streamlit as st
import torch
from PIL import Image

from riffusion.datatypes import InferenceInput
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def render_interpolation_demo() -> None:
    """
    Render audio from text using the riffusion interpolation pipeline.
    """
    prompt = st.text_input("Prompt", label_visibility="collapsed")
    if not prompt:
        st.info("Enter a prompt")
        return

    seed = st.sidebar.number_input("Seed", value=42)
    denoising = st.sidebar.number_input("Denoising", value=0.01)
    guidance = st.sidebar.number_input("Guidance", value=7.0)
    num_inference_steps = st.sidebar.number_input("Inference steps", value=50)

    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    elif torch.backends.mps.is_available():
        default_device = "mps"

    device_options = ["cuda", "cpu", "mps"]
    device = st.sidebar.selectbox(
        "Device", options=device_options, index=device_options.index(default_device)
    )
    assert device is not None

    pipeline = streamlit_util.load_riffusion_checkpoint(device=device)

    input_dict = {
        "alpha": 0.75,
        "num_inference_steps": num_inference_steps,
        "seed_image_id": "og_beat",
        "start": {
            "prompt": prompt,
            "seed": seed,
            "denoising": denoising,
            "guidance": guidance,
        },
        "end": {
            "prompt": prompt,
            "seed": seed,
            "denoising": denoising,
            "guidance": guidance,
        },
    }
    st.json(input_dict)

    inputs = dacite.from_dict(InferenceInput, input_dict)

    # TODO fix
    init_image_path = Path(__file__).parent.parent.parent.parent / "seed_images" / "og_beat.png"
    init_image = Image.open(str(init_image_path)).convert("RGB")

    # Execute the model to get the spectrogram image
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=None,
    )
    st.image(image)

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    # TODO(hayk): It may help performance to cache this object
    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
    segment = converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )

    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)
    st.audio(mp3_bytes)


if __name__ == "__main__":
    render_interpolation_demo()
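The `dacite.from_dict` call above is what turns the nested `input_dict` into typed dataclasses. A minimal self-contained sketch of the same pattern; the dataclass names below are illustrative stand-ins, not riffusion's actual `InferenceInput` definition:

from dataclasses import dataclass

import dacite


@dataclass(frozen=True)
class PromptParams:  # illustrative stand-in
    prompt: str
    seed: int
    denoising: float
    guidance: float


@dataclass(frozen=True)
class InterpolationInput:  # illustrative stand-in
    alpha: float
    num_inference_steps: int
    start: PromptParams
    end: PromptParams


# Nested dicts are recursively converted into nested dataclasses, with type checking.
inputs = dacite.from_dict(
    InterpolationInput,
    {
        "alpha": 0.75,
        "num_inference_steps": 50,
        "start": {"prompt": "jazz", "seed": 42, "denoising": 0.01, "guidance": 7.0},
        "end": {"prompt": "jazz", "seed": 42, "denoising": 0.01, "guidance": 7.0},
    },
)
assert inputs.start.prompt == "jazz"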
@@ -0,0 +1,130 @@
import io
from pathlib import Path

import dacite
from diffusers import StableDiffusionPipeline
import streamlit as st
import torch
from PIL import Image

from riffusion.datatypes import InferenceInput
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


@st.experimental_singleton
def load_stable_diffusion_pipeline(
    checkpoint: str = "riffusion/riffusion-model-v1",
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
) -> StableDiffusionPipeline:
    """
    Load the riffusion pipeline.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    return StableDiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)


@st.experimental_memo
def run_txt2img(
    prompt: str,
    num_inference_steps: int,
    guidance: float,
    negative_prompt: str,
    seed: int,
    width: int,
    height: int,
    device: str = "cuda",
) -> Image.Image:
    """
    Run the text to image pipeline with caching.
    """
    pipeline = load_stable_diffusion_pipeline(device=device)

    generator = torch.Generator(device="cpu").manual_seed(seed)

    output = pipeline(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance,
        negative_prompt=negative_prompt or None,
        generator=generator,
        width=width,
        height=height,
    )

    return output["images"][0]


def render_text_to_audio() -> None:
    """
    Render audio from text.
    """
    prompt = st.text_input("Prompt")
    if not prompt:
        st.info("Enter a prompt")
        return

    negative_prompt = st.text_input("Negative prompt")
    seed = st.sidebar.number_input("Seed", value=42)
    num_inference_steps = st.sidebar.number_input("Inference steps", value=20)
    width = st.sidebar.number_input("Width", value=512)
    height = st.sidebar.number_input("Height", value=512)
    guidance = st.sidebar.number_input(
        "Guidance", value=7.0, help="How much the model listens to the text prompt"
    )

    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    elif torch.backends.mps.is_available():
        default_device = "mps"

    device_options = ["cuda", "cpu", "mps"]
    device = st.sidebar.selectbox(
        "Device", options=device_options, index=device_options.index(default_device)
    )
    assert device is not None

    image = run_txt2img(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance=guidance,
        negative_prompt=negative_prompt,
        seed=seed,
        width=width,
        height=height,
        device=device,
    )

    st.image(image)

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    segment = streamlit_util.audio_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
    )

    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)
    st.audio(mp3_bytes)


if __name__ == "__main__":
    render_text_to_audio()
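A note on the seeding in `run_txt2img`: the `torch.Generator` is created on the CPU so the same seed produces the same initial latents, which also keeps the memoized function reproducible across reruns. A minimal sketch of that behavior:

import torch

# Two CPU generators with the same seed yield identical noise tensors.
a = torch.randn(4, generator=torch.Generator(device="cpu").manual_seed(42))
b = torch.randn(4, generator=torch.Generator(device="cpu").manual_seed(42))
assert torch.equal(a, b)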
@@ -0,0 +1,73 @@
"""
Streamlit utilities (mostly cached wrappers around riffusion code).
"""

import pydub
import streamlit as st
from PIL import Image

from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams


@st.experimental_singleton
def load_riffusion_checkpoint(
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> RiffusionPipeline:
    """
    Load the riffusion pipeline.
    """
    return RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )


# class CachedSpectrogramImageConverter:

#     def __init__(self, params: SpectrogramParams, device: str = "cuda"):
#         self.p = params
#         self.device = device
#         self.converter = self._converter(params, device)

#     @staticmethod
#     @st.experimental_singleton
#     def _converter(params: SpectrogramParams, device: str) -> SpectrogramImageConverter:
#         return SpectrogramImageConverter(params=params, device=device)

#     def audio_from_spectrogram_image(
#         self,
#         image: Image.Image
#     ) -> pydub.AudioSegment:
#         return self._converter.audio_from_spectrogram_image(image)


@st.experimental_singleton
def spectrogram_image_converter(
    params: SpectrogramParams,
    device: str = "cuda",
) -> SpectrogramImageConverter:
    return SpectrogramImageConverter(params=params, device=device)


@st.experimental_memo
def audio_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
) -> pydub.AudioSegment:
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.audio_from_spectrogram_image(image)


# @st.experimental_memo
# def spectrogram_image_from_audio(
#     segment: pydub.AudioSegment,
#     params: SpectrogramParams,
#     device: str = "cuda",
# ) -> pydub.AudioSegment:
#     converter = spectrogram_image_converter(params=params, device=device)
#     return converter.spectrogram_image_from_audio(segment)
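The caching split this module relies on: `st.experimental_singleton` keeps one shared instance of a heavy object per process (the pipeline, the converter), while `st.experimental_memo` caches a function's return value per distinct argument set (the decoded audio). A minimal generic sketch of the same idea, with placeholder names rather than riffusion code:

import time

import streamlit as st


@st.experimental_singleton
def load_heavy_object() -> dict:
    # Placeholder for an expensive construction (e.g. loading a model checkpoint).
    # Runs once per process; all sessions and reruns share the same instance.
    time.sleep(2.0)
    return {"name": "placeholder"}


@st.experimental_memo
def expensive_pure_function(x: int) -> int:
    # Placeholder pure computation; the result is cached keyed on the arguments.
    return x * x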