Add several streamlit demo pages

Topic: streamlit_app
Hayk Martiros 2022-12-27 00:25:19 -08:00
parent 1335afb72f
commit 152192006e
10 changed files with 425 additions and 200 deletions

View File

@@ -116,9 +116,6 @@ def sample_clips(
if not output_dir_path.exists():
output_dir_path.mkdir(parents=True)
# TODO(hayk): Might be a lot easier with pydub
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)

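Note on the pydub TODO above: AudioSegment.from_file (the API linked in the comment) loads wav/mp3/ogg alike and slices by milliseconds, so clip extraction reduces to a slice plus an export. A minimal sketch, not code from this commit:

import pydub

# Sketch only: load, slice by milliseconds, export (paths are placeholders)
segment = pydub.AudioSegment.from_file("input.wav")
clip = segment[5_000:10_000]  # 5s..10s
clip.export("clip_0.mp3", format="mp3")
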
View File

@@ -1,25 +0,0 @@
import pydub
import streamlit as st
def run():
st.set_page_config(layout="wide", page_icon="🎸")
audio_file = st.file_uploader("Upload a file", type=["wav", "mp3", "ogg"])
if not audio_file:
st.info("Upload an audio file to get started")
return
st.audio(audio_file)
segment = pydub.AudioSegment.from_file(audio_file)
st.write(" \n".join([
f"**Duration**: {segment.duration_seconds:.3f} seconds",
f"**Channels**: {segment.channels}",
f"**Sample rate**: {segment.frame_rate} Hz",
f"**Sample width**: {segment.sample_width} bytes",
]))
if __name__ == "__main__":
run()

View File

@@ -1,12 +1,26 @@
import dataclasses
import streamlit as st
from PIL import Image
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.util.image_util import exif_from_image
def render_image_to_audio() -> None:
image_file = st.sidebar.file_uploader(
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":musical_keyboard: Image to Audio")
st.write(
"""
Reconstruct audio from spectrogram images.
"""
)
device = streamlit_util.select_device(st.sidebar)
image_file = st.file_uploader(
"Upload a file",
type=["png", "jpg", "jpeg"],
label_visibility="collapsed",
@@ -18,29 +32,26 @@ def render_image_to_audio() -> None:
image = Image.open(image_file)
st.image(image)
with st.expander("Image metadata", expanded=False):
exif = exif_from_image(image)
st.write("Exif data:")
st.write(exif)
st.json(exif)
# device = "cuda"
try:
params = SpectrogramParams.from_exif(exif=image.getexif())
except KeyError:
st.info("Could not find spectrogram parameters in exif data. Using defaults.")
params = SpectrogramParams()
# try:
# params = SpectrogramParams.from_exif(exif=image.getexif())
# except KeyError:
# st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
# params = SpectrogramParams()
with st.expander("Spectrogram Parameters", expanded=False):
st.json(dataclasses.asdict(params))
# segment = streamlit_util.audio_from_spectrogram_image(
# image=image,
# params=params,
# device=device,
# )
# mp3_bytes = io.BytesIO()
# segment.export(mp3_bytes, format="mp3")
# mp3_bytes.seek(0)
# st.audio(mp3_bytes)
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image.copy(),
params=params,
device=device,
output_format="mp3",
)
st.audio(audio_bytes)
if __name__ == "__main__":

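For context, a hedged sketch of what audio_bytes_from_spectrogram_image likely does before caching is added; it mirrors the converter pattern visible in the removed text_to_audio page further down. The helper name here is illustrative, not from the repo:

import io

from PIL import Image

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams


def audio_bytes_from_image_sketch(
    image: Image.Image, params: SpectrogramParams, device: str = "cuda"
) -> io.BytesIO:
    # Reconstruct a pydub segment from the spectrogram image, then encode to mp3 bytes
    converter = SpectrogramImageConverter(params=params, device=device)
    segment = converter.audio_from_spectrogram_image(image, apply_filters=True)
    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)
    return mp3_bytes
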
View File

@@ -0,0 +1,197 @@
import dataclasses
import io
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
from PIL import Image
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_interpolation_demo() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":performing_arts: Interpolation")
st.write(
"""
Interpolate between prompts in the latent space.
"""
)
# Sidebar params
device = streamlit_util.select_device(st.sidebar)
num_interpolation_steps = T.cast(
int,
st.sidebar.number_input(
"Interpolation steps",
value=4,
min_value=1,
max_value=20,
help="Number of model generations between the two prompts. Controls the duration.",
),
)
num_inference_steps = T.cast(
int,
st.sidebar.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
init_image_name = st.sidebar.selectbox(
"Seed image",
# TODO(hayk): Read from directory
options=["og_beat", "agile", "marim", "motorway", "vibes"],
index=0,
help="Which seed image to use for img2img",
)
assert init_image_name is not None
show_individual_outputs = st.sidebar.checkbox(
"Show individual outputs",
value=False,
help="Show each model output",
)
show_images = st.sidebar.checkbox(
"Show individual images",
value=False,
help="Show each generated image",
)
# Prompt inputs A and B in two columns
left, right = st.columns(2)
with left.expander("Input A", expanded=True):
prompt_input_a = get_prompt_inputs(key="a")
with right.expander("Input B", expanded=True):
prompt_input_b = get_prompt_inputs(key="b")
if not prompt_input_a.prompt or not prompt_input_b.prompt:
st.info("Enter both prompts to interpolate between them")
return
alphas = list(np.linspace(0, 1, num_interpolation_steps))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]")
# TODO(hayk): Upload your own seed image.
init_image_path = (
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
)
init_image = Image.open(str(init_image_path)).convert("RGB")
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
image_list: T.List[Image.Image] = []
audio_bytes_list: T.List[io.BytesIO] = []
for i, alpha in enumerate(alphas):
inputs = InferenceInput(
alpha=float(alpha),
num_inference_steps=num_inference_steps,
seed_image_id="og_beat",
start=prompt_input_a,
end=prompt_input_b,
)
if i == 0:
with st.expander("Example input JSON", expanded=False):
st.json(dataclasses.asdict(inputs))
image, audio_bytes = run_interpolation(
inputs=inputs,
init_image=init_image,
device=device,
)
if show_individual_outputs:
st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
if show_images:
st.image(image)
st.audio(audio_bytes)
image_list.append(image)
audio_bytes_list.append(audio_bytes)
st.write("#### Final Output")
# TODO(hayk): Concatenate with better blending
audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
concat_segment = audio_segments[0]
for segment in audio_segments[1:]:
concat_segment = concat_segment.append(segment, crossfade=0)
audio_bytes = io.BytesIO()
concat_segment.export(audio_bytes, format="mp3")
audio_bytes.seek(0)
st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
st.audio(audio_bytes)
def get_prompt_inputs(key: str) -> PromptInput:
"""
Compute prompt inputs from widgets.
"""
prompt = st.text_input("Prompt", label_visibility="collapsed", key=f"prompt_{key}")
seed = T.cast(int, st.number_input("Seed", value=42, key=f"seed_{key}"))
denoising = st.number_input(
"Denoising", value=0.75, key=f"denoising_{key}", help="How much to modify the seed image"
)
guidance = st.number_input(
"Guidance",
value=7.0,
key=f"guidance_{key}",
help="How much the model listens to the text prompt",
)
return PromptInput(
prompt=prompt,
seed=seed,
denoising=denoising,
guidance=guidance,
)
@st.experimental_memo
def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
) -> T.Tuple[Image.Image, io.BytesIO]:
"""
Cached function for riffusion interpolation.
"""
pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
image = pipeline.riffuse(
inputs,
init_image=init_image,
mask_image=None,
)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
# Reconstruct from image to audio
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=params,
device=device,
output_format="mp3",
)
return image, audio_bytes
if __name__ == "__main__":
render_interpolation_demo()

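On the "concatenate with better blending" TODO above: pydub's append already accepts a nonzero crossfade, so a smoother join could look like the sketch below. The 150 ms overlap is an arbitrary illustration, not a value from this commit:

import io
import typing as T

import pydub


def concat_with_crossfade(
    segments: T.Sequence[pydub.AudioSegment], crossfade_ms: int = 150
) -> io.BytesIO:
    # Blend the tail of each clip into the head of the next instead of hard-cutting
    combined = segments[0]
    for segment in segments[1:]:
        combined = combined.append(segment, crossfade=crossfade_ms)
    out = io.BytesIO()
    combined.export(out, format="mp3")
    out.seek(0)
    return out
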
View File

@@ -1,97 +0,0 @@
import io
from pathlib import Path
import dacite
import streamlit as st
import torch
from PIL import Image
from riffusion.datatypes import InferenceInput
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
def render_interpolation_demo() -> None:
"""
Render audio from text.
"""
prompt = st.text_input("Prompt", label_visibility="collapsed")
if not prompt:
st.info("Enter a prompt")
return
seed = st.sidebar.number_input("Seed", value=42)
denoising = st.sidebar.number_input("Denoising", value=0.01)
guidance = st.sidebar.number_input("Guidance", value=7.0)
num_inference_steps = st.sidebar.number_input("Inference steps", value=50)
default_device = "cpu"
if torch.cuda.is_available():
default_device = "cuda"
elif torch.backends.mps.is_available():
default_device = "mps"
device_options = ["cuda", "cpu", "mps"]
device = st.sidebar.selectbox(
"Device", options=device_options, index=device_options.index(default_device)
)
assert device is not None
pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
input_dict = {
"alpha": 0.75,
"num_inference_steps": num_inference_steps,
"seed_image_id": "og_beat",
"start": {
"prompt": prompt,
"seed": seed,
"denoising": denoising,
"guidance": guidance,
},
"end": {
"prompt": prompt,
"seed": seed,
"denoising": denoising,
"guidance": guidance,
},
}
st.json(input_dict)
inputs = dacite.from_dict(InferenceInput, input_dict)
# TODO fix
init_image_path = Path(__file__).parent.parent.parent.parent / "seed_images" / "og_beat.png"
init_image = Image.open(str(init_image_path)).convert("RGB")
# Execute the model to get the spectrogram image
image = pipeline.riffuse(
inputs,
init_image=init_image,
mask_image=None,
)
st.image(image)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
# Reconstruct audio from the image
# TODO(hayk): It may help performance to cache this object
converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
segment = converter.audio_from_spectrogram_image(
image,
apply_filters=True,
)
mp3_bytes = io.BytesIO()
segment.export(mp3_bytes, format="mp3")
mp3_bytes.seek(0)
st.audio(mp3_bytes)
if __name__ == "__main__":
render_interpolation_demo()

View File

@@ -0,0 +1,85 @@
import tempfile
import typing as T
from pathlib import Path
import numpy as np
import pydub
import streamlit as st
def render_sample_clips() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":scissors: Sample Clips")
st.write(
"""
Export short clips from an audio file.
"""
)
audio_file = st.file_uploader(
"Upload a file",
type=["wav", "mp3", "ogg"],
label_visibility="collapsed",
)
if not audio_file:
st.info("Upload an audio file to get started")
return
st.audio(audio_file)
segment = pydub.AudioSegment.from_file(audio_file)
st.write(
" \n".join(
[
f"**Duration**: {segment.duration_seconds:.3f} seconds",
f"**Channels**: {segment.channels}",
f"**Sample rate**: {segment.frame_rate} Hz",
f"**Sample width**: {segment.sample_width} bytes",
]
)
)
seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
export_as_mono = st.sidebar.checkbox("Export as Mono", False)
num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
assert extension is not None
# Optionally specify an output directory
output_dir = st.text_input("Output Directory")
if not output_dir:
tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
return
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
if seed >= 0:
np.random.seed(seed)
if export_as_mono and segment.channels > 1:
segment = segment.set_channels(1)
# TODO(hayk): Share code with riffusion.cli.sample_clips.
segment_duration_ms = int(segment.duration_seconds * 1000)
for i in range(num_clips):
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
clip_path = output_path / clip_name
clip.export(clip_path, format=extension)
st.audio(str(clip_path))
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
if __name__ == "__main__":
render_sample_clips()

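On the TODO about sharing code with riffusion.cli.sample_clips: one option is a small helper that both the CLI and this page call. A hedged sketch; the name and signature are hypothetical:

import typing as T

import numpy as np
import pydub


def sample_random_clips(
    segment: pydub.AudioSegment, num_clips: int, duration_ms: int
) -> T.List[T.Tuple[int, pydub.AudioSegment]]:
    # Return (start_ms, clip) pairs sampled uniformly at random from the segment
    segment_duration_ms = int(segment.duration_seconds * 1000)
    clips = []
    for _ in range(num_clips):
        start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
        clips.append((start_ms, segment[start_ms : start_ms + duration_ms]))
    return clips
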
View File

@@ -7,19 +7,25 @@ from riffusion.streamlit import util as streamlit_util
def render_text_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":pencil2: Text to Audio")
st.write(
"""
Render audio from text.
Generate audio from text prompts. \nRuns the model directly without a seed image or
interpolation.
"""
prompt = st.text_input("Prompt")
negative_prompt = st.text_input("Negative prompt")
)
device = streamlit_util.select_device(st.sidebar)
prompt = st.text_input("Prompt")
negative_prompt = st.text_input("Negative prompt")
with st.sidebar.expander("Text to Audio Params", expanded=True):
seed = T.cast(int, st.number_input("Seed", value=42))
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
width = T.cast(int, st.number_input("Width", value=512))
height = T.cast(int, st.number_input("Height", value=512))
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
)
@@ -35,9 +41,10 @@ def render_text_to_audio() -> None:
negative_prompt=negative_prompt,
seed=seed,
width=width,
height=height,
height=512,
device=device,
)
st.image(image)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained

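For reference, a hedged guess at the cached txt2img helper this page leans on; the diff only shows the tail of run_txt2img, so the pipeline call and the load_stable_diffusion_pipeline arguments below are assumptions:

import streamlit as st
import torch
from PIL import Image

from riffusion.streamlit import util as streamlit_util


@st.experimental_memo
def txt2img_sketch(
    prompt: str,
    negative_prompt: str,
    seed: int,
    num_inference_steps: int,
    guidance: float,
    width: int,
    height: int,
    device: str = "cuda",
) -> Image.Image:
    # Assumes load_stable_diffusion_pipeline takes a device kwarg and returns a
    # diffusers StableDiffusionPipeline; a seeded generator keeps outputs reproducible.
    pipeline = streamlit_util.load_stable_diffusion_pipeline(device=device)
    generator = torch.Generator(device=device).manual_seed(seed)
    output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance,
        width=width,
        height=height,
        generator=generator,
    )
    return output["images"][0]
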
View File

@@ -7,27 +7,67 @@ import streamlit as st
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
# Example input json file to process in batch
EXAMPLE_INPUT = """
{
"params": {
"seed": 42,
"num_inference_steps": 50,
"guidance": 7.0,
"width": 512,
},
"entries": [
{
"prompt": "Church bells"
},
{
"prompt": "electronic beats",
"negative_prompt": "drums"
},
{
"prompt": "classical violin concerto"
}
]
}
"""
def render_text_to_audio_batch() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
st.subheader(":scroll: Text to Audio Batch")
st.write(
"""
Render audio from text in batches, reading from a text file.
Generate audio in batch from a JSON file of text prompts. \nThe input
file contains a global params block and a list of entries with positive and negative
prompts.
"""
json_file = st.file_uploader("JSON file", type=["json"])
)
device = streamlit_util.select_device(st.sidebar)
# Upload a JSON file
json_file = st.file_uploader(
"JSON file",
type=["json"],
label_visibility="collapsed",
)
# Handle the null case
if json_file is None:
st.info("Upload a JSON file of prompts")
st.info("Upload a JSON file containing params and prompts")
with st.expander("Example inputs.json", expanded=False):
st.code(EXAMPLE_INPUT)
return
# Read in and print it
data = json.loads(json_file.read())
with st.expander("Data", expanded=False):
with st.expander("Input Data", expanded=False):
st.json(data)
params = data["params"]
entries = data["entries"]
device = streamlit_util.select_device(st.sidebar)
show_images = st.sidebar.checkbox("Show Images", True)
show_images = st.sidebar.checkbox("Show Images", False)
# Optionally specify an output directory
output_dir = st.sidebar.text_input("Output Directory", "")
@@ -37,21 +77,20 @@ def render_text_to_audio_batch() -> None:
output_path.mkdir(parents=True, exist_ok=True)
for i, entry in enumerate(entries):
st.write(f"### Entry {i + 1} / {len(entries)}")
st.write(f"Prompt: {entry['prompt']}")
st.write(f"#### Entry {i + 1} / {len(entries)}")
negative_prompt = entry.get("negative_prompt", None)
st.write(f"Negative prompt: {negative_prompt}")
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
image = streamlit_util.run_txt2img(
prompt=entry["prompt"],
negative_prompt=negative_prompt,
seed=params["seed"],
num_inference_steps=params["num_inference_steps"],
guidance=params["guidance"],
width=params["width"],
height=params["height"],
seed=params.get("seed", 42),
num_inference_steps=params.get("num_inference_steps", 50),
guidance=params.get("guidance", 7.0),
width=params.get("width", 512),
height=512,
device=device,
)
@@ -91,6 +130,8 @@ def render_text_to_audio_batch() -> None:
output_json_path = output_path / "index.json"
output_json_path.write_text(json.dumps(data, indent=4))
st.info(f"Output written to {str(output_path)}")
else:
st.info("Enter output directory in sidebar to save to disk")
if __name__ == "__main__":

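For a quick start, a minimal script that writes a valid inputs.json matching the schema in EXAMPLE_INPUT above (purely illustrative):

import json

data = {
    "params": {"seed": 42, "num_inference_steps": 50, "guidance": 7.0, "width": 512},
    "entries": [
        {"prompt": "Church bells"},
        {"prompt": "electronic beats", "negative_prompt": "drums"},
        {"prompt": "classical violin concerto"},
    ],
}

with open("inputs.json", "w") as f:
    json.dump(data, f, indent=4)
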
View File

@@ -0,0 +1,37 @@
import streamlit as st
def render_main():
st.set_page_config(layout="wide", page_icon="🎸")
st.header(":guitar: Riffusion Playground")
st.write("Interactive app for common riffusion tasks.")
left, right = st.columns(2)
with left:
create_link(":performing_arts: Interpolation", "/interpolation")
st.write("Interpolate between prompts in the latent space.")
create_link(":pencil2: Text to Audio", "/text_to_audio")
st.write("Generate audio from text prompts.")
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
st.write("Generate audio in batch from a JSON file of text prompts.")
with right:
create_link(":scissors: Sample Clips", "/sample_clips")
st.write("Export short clips from an audio file.")
create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
st.write("Reconstruct audio from spectrogram images.")
def create_link(name: str, url: str) -> None:
st.markdown(
f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
unsafe_allow_html=True,
)
if __name__ == "__main__":
render_main()

View File

@@ -13,6 +13,8 @@ from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
# TODO(hayk): Add URL params
@st.experimental_singleton
def load_riffusion_checkpoint(
@@ -53,7 +55,6 @@ def load_stable_diffusion_pipeline(
).to(device)
@st.experimental_memo
def run_txt2img(
prompt: str,
@@ -86,25 +87,6 @@ def run_txt2img(
return output["images"][0]
# class CachedSpectrogramImageConverter:
# def __init__(self, params: SpectrogramParams, device: str = "cuda"):
# self.p = params
# self.device = device
# self.converter = self._converter(params, device)
# @staticmethod
# @st.experimental_singleton
# def _converter(params: SpectrogramParams, device: str) -> SpectrogramImageConverter:
# return SpectrogramImageConverter(params=params, device=device)
# def audio_from_spectrogram_image(
# self,
# image: Image.Image
# ) -> pydub.AudioSegment:
# return self._converter.audio_from_spectrogram_image(image)
@st.experimental_singleton
def spectrogram_image_converter(
params: SpectrogramParams,
@@ -147,13 +129,3 @@ def select_device(container: T.Any = st.sidebar) -> str:
assert device is not None
return device
# @st.experimental_memo
# def spectrogram_image_from_audio(
# segment: pydub.AudioSegment,
# params: SpectrogramParams,
# device: str = "cuda",
# ) -> pydub.AudioSegment:
# converter = spectrogram_image_converter(params=params, device=device)
# return converter.spectrogram_image_from_audio(segment)
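For completeness, a hedged sketch of how audio_bytes_from_spectrogram_image (called by the pages above) could sit on top of the cached spectrogram_image_converter singleton in this module; the actual implementation may differ:

import io

import streamlit as st
from PIL import Image

from riffusion.spectrogram_params import SpectrogramParams


@st.experimental_memo
def audio_bytes_from_spectrogram_image_sketch(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
    output_format: str = "mp3",
) -> io.BytesIO:
    # Assumes it lives in this util module next to the spectrogram_image_converter singleton
    converter = spectrogram_image_converter(params=params, device=device)
    segment = converter.audio_from_spectrogram_image(image)
    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=output_format)
    audio_bytes.seek(0)
    return audio_bytes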