From 152192006ee0dcf0804672b37216af36eb257bbd Mon Sep 17 00:00:00 2001 From: Hayk Martiros Date: Tue, 27 Dec 2022 00:25:19 -0800 Subject: [PATCH] Add several streamlit demo pages Topic: streamlit_app --- riffusion/cli.py | 3 - riffusion/streamlit/main.py | 25 --- riffusion/streamlit/pages/image_to_audio.py | 53 +++-- riffusion/streamlit/pages/interpolation.py | 197 ++++++++++++++++++ .../streamlit/pages/interpolation_demo.py | 97 --------- riffusion/streamlit/pages/sample_clips.py | 85 ++++++++ riffusion/streamlit/pages/text_to_audio.py | 19 +- .../streamlit/pages/text_to_audio_batch.py | 77 +++++-- riffusion/streamlit/playground.py | 37 ++++ riffusion/streamlit/util.py | 32 +-- 10 files changed, 425 insertions(+), 200 deletions(-) delete mode 100644 riffusion/streamlit/main.py create mode 100644 riffusion/streamlit/pages/interpolation.py delete mode 100644 riffusion/streamlit/pages/interpolation_demo.py create mode 100644 riffusion/streamlit/pages/sample_clips.py create mode 100644 riffusion/streamlit/playground.py diff --git a/riffusion/cli.py b/riffusion/cli.py index 221848d..8fec114 100644 --- a/riffusion/cli.py +++ b/riffusion/cli.py @@ -116,9 +116,6 @@ def sample_clips( if not output_dir_path.exists(): output_dir_path.mkdir(parents=True) - # TODO(hayk): Might be a lot easier with pydub - # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentfrom_file - segment_duration_ms = int(segment.duration_seconds * 1000) for i in range(num_clips): clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms) diff --git a/riffusion/streamlit/main.py b/riffusion/streamlit/main.py deleted file mode 100644 index c196417..0000000 --- a/riffusion/streamlit/main.py +++ /dev/null @@ -1,25 +0,0 @@ -import pydub -import streamlit as st - - -def run(): - st.set_page_config(layout="wide", page_icon="🎸") - - audio_file = st.file_uploader("Upload a file", type=["wav", "mp3", "ogg"]) - if not audio_file: - st.info("Upload an audio file to get started") - return - - st.audio(audio_file) - - segment = pydub.AudioSegment.from_file(audio_file) - st.write(" \n".join([ - f"**Duration**: {segment.duration_seconds:.3f} seconds", - f"**Channels**: {segment.channels}", - f"**Sample rate**: {segment.frame_rate} Hz", - f"**Sample width**: {segment.sample_width} bytes", - ])) - - -if __name__ == "__main__": - run() diff --git a/riffusion/streamlit/pages/image_to_audio.py b/riffusion/streamlit/pages/image_to_audio.py index eb57211..cfa0814 100644 --- a/riffusion/streamlit/pages/image_to_audio.py +++ b/riffusion/streamlit/pages/image_to_audio.py @@ -1,12 +1,26 @@ +import dataclasses import streamlit as st from PIL import Image +from riffusion.spectrogram_params import SpectrogramParams +from riffusion.streamlit import util as streamlit_util from riffusion.util.image_util import exif_from_image def render_image_to_audio() -> None: - image_file = st.sidebar.file_uploader( + st.set_page_config(layout="wide", page_icon="🎸") + + st.subheader(":musical_keyboard: Image to Audio") + st.write( + """ + Reconstruct audio from spectrogram images. 
+ """ + ) + + device = streamlit_util.select_device(st.sidebar) + + image_file = st.file_uploader( "Upload a file", type=["png", "jpg", "jpeg"], label_visibility="collapsed", @@ -18,29 +32,26 @@ def render_image_to_audio() -> None: image = Image.open(image_file) st.image(image) - exif = exif_from_image(image) - st.write("Exif data:") - st.write(exif) + with st.expander("Image metadata", expanded=False): + exif = exif_from_image(image) + st.json(exif) - # device = "cuda" + try: + params = SpectrogramParams.from_exif(exif=image.getexif()) + except KeyError: + st.info("Could not find spectrogram parameters in exif data. Using defaults.") + params = SpectrogramParams() - # try: - # params = SpectrogramParams.from_exif(exif=image.getexif()) - # except KeyError: - # st.warning("Could not find spectrogram parameters in exif data. Using defaults.") - # params = SpectrogramParams() + with st.expander("Spectrogram Parameters", expanded=False): + st.json(dataclasses.asdict(params)) - # segment = streamlit_util.audio_from_spectrogram_image( - # image=image, - # params=params, - # device=device, - # ) - - # mp3_bytes = io.BytesIO() - # segment.export(mp3_bytes, format="mp3") - # mp3_bytes.seek(0) - - # st.audio(mp3_bytes) + audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image( + image=image.copy(), + params=params, + device=device, + output_format="mp3", + ) + st.audio(audio_bytes) if __name__ == "__main__": diff --git a/riffusion/streamlit/pages/interpolation.py b/riffusion/streamlit/pages/interpolation.py new file mode 100644 index 0000000..800c7ba --- /dev/null +++ b/riffusion/streamlit/pages/interpolation.py @@ -0,0 +1,197 @@ +import dataclasses +import io +import typing as T +from pathlib import Path + +import numpy as np +import pydub +import streamlit as st +from PIL import Image + +from riffusion.datatypes import InferenceInput, PromptInput +from riffusion.spectrogram_params import SpectrogramParams +from riffusion.streamlit import util as streamlit_util + + +def render_interpolation_demo() -> None: + st.set_page_config(layout="wide", page_icon="🎸") + + st.subheader(":performing_arts: Interpolation") + st.write( + """ + Interpolate between prompts in the latent space. + """ + ) + + # Sidebar params + + device = streamlit_util.select_device(st.sidebar) + + num_interpolation_steps = T.cast( + int, + st.sidebar.number_input( + "Interpolation steps", + value=4, + min_value=1, + max_value=20, + help="Number of model generations between the two prompts. 
Controls the duration.", + ), + ) + + num_inference_steps = T.cast( + int, + st.sidebar.number_input( + "Steps per sample", value=50, help="Number of denoising steps per model run" + ), + ) + + init_image_name = st.sidebar.selectbox( + "Seed image", + # TODO(hayk): Read from directory + options=["og_beat", "agile", "marim", "motorway", "vibes"], + index=0, + help="Which seed image to use for img2img", + ) + assert init_image_name is not None + + show_individual_outputs = st.sidebar.checkbox( + "Show individual outputs", + value=False, + help="Show each model output", + ) + show_images = st.sidebar.checkbox( + "Show individual images", + value=False, + help="Show each generated image", + ) + + # Prompt inputs A and B in two columns + + left, right = st.columns(2) + + with left.expander("Input A", expanded=True): + prompt_input_a = get_prompt_inputs(key="a") + + with right.expander("Input B", expanded=True): + prompt_input_b = get_prompt_inputs(key="b") + + if not prompt_input_a.prompt or not prompt_input_b.prompt: + st.info("Enter both prompts to interpolate between them") + return + + alphas = list(np.linspace(0, 1, num_interpolation_steps)) + alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas]) + st.write(f"**Alphas** : [{alphas_str}]") + + # TODO(hayk): Upload your own seed image. + + init_image_path = ( + Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png" + ) + init_image = Image.open(str(init_image_path)).convert("RGB") + + # TODO(hayk): Move this code into a shared place and add to riffusion.cli + image_list: T.List[Image.Image] = [] + audio_bytes_list: T.List[io.BytesIO] = [] + for i, alpha in enumerate(alphas): + inputs = InferenceInput( + alpha=float(alpha), + num_inference_steps=num_inference_steps, + seed_image_id="og_beat", + start=prompt_input_a, + end=prompt_input_b, + ) + + if i == 0: + with st.expander("Example input JSON", expanded=False): + st.json(dataclasses.asdict(inputs)) + + image, audio_bytes = run_interpolation( + inputs=inputs, + init_image=init_image, + device=device, + ) + + if show_individual_outputs: + st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}") + if show_images: + st.image(image) + st.audio(audio_bytes) + + image_list.append(image) + audio_bytes_list.append(audio_bytes) + + st.write("#### Final Output") + + # TODO(hayk): Concatenate with better blending + audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list] + concat_segment = audio_segments[0] + for segment in audio_segments[1:]: + concat_segment = concat_segment.append(segment, crossfade=0) + + audio_bytes = io.BytesIO() + concat_segment.export(audio_bytes, format="mp3") + audio_bytes.seek(0) + + st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds") + st.audio(audio_bytes) + + +def get_prompt_inputs(key: str) -> PromptInput: + """ + Compute prompt inputs from widgets. 
+ """ + prompt = st.text_input("Prompt", label_visibility="collapsed", key=f"prompt_{key}") + seed = T.cast(int, st.number_input("Seed", value=42, key=f"seed_{key}")) + denoising = st.number_input( + "Denoising", value=0.75, key=f"denoising_{key}", help="How much to modify the seed image" + ) + guidance = st.number_input( + "Guidance", + value=7.0, + key=f"guidance_{key}", + help="How much the model listens to the text prompt", + ) + + return PromptInput( + prompt=prompt, + seed=seed, + denoising=denoising, + guidance=guidance, + ) + + +@st.experimental_memo +def run_interpolation( + inputs: InferenceInput, init_image: Image.Image, device: str = "cuda" +) -> T.Tuple[Image.Image, io.BytesIO]: + """ + Cached function for riffusion interpolation. + """ + pipeline = streamlit_util.load_riffusion_checkpoint(device=device) + + image = pipeline.riffuse( + inputs, + init_image=init_image, + mask_image=None, + ) + + # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained + params = SpectrogramParams( + min_frequency=0, + max_frequency=10000, + ) + + # Reconstruct from image to audio + audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image( + image=image, + params=params, + device=device, + output_format="mp3", + ) + + return image, audio_bytes + + +if __name__ == "__main__": + render_interpolation_demo() diff --git a/riffusion/streamlit/pages/interpolation_demo.py b/riffusion/streamlit/pages/interpolation_demo.py deleted file mode 100644 index 1c0720e..0000000 --- a/riffusion/streamlit/pages/interpolation_demo.py +++ /dev/null @@ -1,97 +0,0 @@ -import io -from pathlib import Path - -import dacite -import streamlit as st -import torch -from PIL import Image - -from riffusion.datatypes import InferenceInput -from riffusion.spectrogram_image_converter import SpectrogramImageConverter -from riffusion.spectrogram_params import SpectrogramParams -from riffusion.streamlit import util as streamlit_util - - -def render_interpolation_demo() -> None: - """ - Render audio from text. 
- """ - prompt = st.text_input("Prompt", label_visibility="collapsed") - if not prompt: - st.info("Enter a prompt") - return - - seed = st.sidebar.number_input("Seed", value=42) - denoising = st.sidebar.number_input("Denoising", value=0.01) - guidance = st.sidebar.number_input("Guidance", value=7.0) - num_inference_steps = st.sidebar.number_input("Inference steps", value=50) - - default_device = "cpu" - if torch.cuda.is_available(): - default_device = "cuda" - elif torch.backends.mps.is_available(): - default_device = "mps" - - device_options = ["cuda", "cpu", "mps"] - device = st.sidebar.selectbox( - "Device", options=device_options, index=device_options.index(default_device) - ) - assert device is not None - - pipeline = streamlit_util.load_riffusion_checkpoint(device=device) - - input_dict = { - "alpha": 0.75, - "num_inference_steps": num_inference_steps, - "seed_image_id": "og_beat", - "start": { - "prompt": prompt, - "seed": seed, - "denoising": denoising, - "guidance": guidance, - }, - "end": { - "prompt": prompt, - "seed": seed, - "denoising": denoising, - "guidance": guidance, - }, - } - st.json(input_dict) - - inputs = dacite.from_dict(InferenceInput, input_dict) - - # TODO fix - init_image_path = Path(__file__).parent.parent.parent.parent / "seed_images" / "og_beat.png" - init_image = Image.open(str(init_image_path)).convert("RGB") - - # Execute the model to get the spectrogram image - image = pipeline.riffuse( - inputs, - init_image=init_image, - mask_image=None, - ) - st.image(image) - - # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained - params = SpectrogramParams( - min_frequency=0, - max_frequency=10000, - ) - - # Reconstruct audio from the image - # TODO(hayk): It may help performance to cache this object - converter = SpectrogramImageConverter(params=params, device=str(pipeline.device)) - segment = converter.audio_from_spectrogram_image( - image, - apply_filters=True, - ) - - mp3_bytes = io.BytesIO() - segment.export(mp3_bytes, format="mp3") - mp3_bytes.seek(0) - st.audio(mp3_bytes) - - -if __name__ == "__main__": - render_interpolation_demo() diff --git a/riffusion/streamlit/pages/sample_clips.py b/riffusion/streamlit/pages/sample_clips.py new file mode 100644 index 0000000..374d3f0 --- /dev/null +++ b/riffusion/streamlit/pages/sample_clips.py @@ -0,0 +1,85 @@ +import tempfile +import typing as T +from pathlib import Path + +import numpy as np +import pydub +import streamlit as st + + +def render_sample_clips() -> None: + st.set_page_config(layout="wide", page_icon="🎸") + + st.subheader(":scissors: Sample Clips") + st.write( + """ + Export short clips from an audio file. 
+ """ + ) + + audio_file = st.file_uploader( + "Upload a file", + type=["wav", "mp3", "ogg"], + label_visibility="collapsed", + ) + if not audio_file: + st.info("Upload an audio file to get started") + return + + st.audio(audio_file) + + segment = pydub.AudioSegment.from_file(audio_file) + st.write( + " \n".join( + [ + f"**Duration**: {segment.duration_seconds:.3f} seconds", + f"**Channels**: {segment.channels}", + f"**Sample rate**: {segment.frame_rate} Hz", + f"**Sample width**: {segment.sample_width} bytes", + ] + ) + ) + + seed = T.cast(int, st.sidebar.number_input("Seed", value=42)) + duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000)) + export_as_mono = st.sidebar.checkbox("Export as Mono", False) + num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3)) + extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"]) + assert extension is not None + + # Optionally specify an output directory + output_dir = st.text_input("Output Directory") + if not output_dir: + tmp_dir = tempfile.mkdtemp(prefix="sample_clips_") + st.info(f"Specify an output directory. Suggested: `{tmp_dir}`") + return + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + if seed >= 0: + np.random.seed(seed) + + if export_as_mono and segment.channels > 1: + segment = segment.set_channels(1) + + # TODO(hayk): Share code with riffusion.cli.sample_clips. + segment_duration_ms = int(segment.duration_seconds * 1000) + for i in range(num_clips): + clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms) + clip = segment[clip_start_ms : clip_start_ms + duration_ms] + + clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}" + + st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`") + + clip_path = output_path / clip_name + clip.export(clip_path, format=extension) + + st.audio(str(clip_path)) + + st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`") + + +if __name__ == "__main__": + render_sample_clips() diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py index 7d328fb..16cfa71 100644 --- a/riffusion/streamlit/pages/text_to_audio.py +++ b/riffusion/streamlit/pages/text_to_audio.py @@ -7,19 +7,25 @@ from riffusion.streamlit import util as streamlit_util def render_text_to_audio() -> None: + st.set_page_config(layout="wide", page_icon="🎸") + + st.subheader(":pencil2: Text to Audio") + st.write( + """ + Generate audio from text prompts. \nRuns the model directly without a seed image or + interpolation. """ - Render audio from text. 
- """ - prompt = st.text_input("Prompt") - negative_prompt = st.text_input("Negative prompt") + ) device = streamlit_util.select_device(st.sidebar) + prompt = st.text_input("Prompt") + negative_prompt = st.text_input("Negative prompt") + with st.sidebar.expander("Text to Audio Params", expanded=True): seed = T.cast(int, st.number_input("Seed", value=42)) num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50)) width = T.cast(int, st.number_input("Width", value=512)) - height = T.cast(int, st.number_input("Height", value=512)) guidance = st.number_input( "Guidance", value=7.0, help="How much the model listens to the text prompt" ) @@ -35,9 +41,10 @@ def render_text_to_audio() -> None: negative_prompt=negative_prompt, seed=seed, width=width, - height=height, + height=512, device=device, ) + st.image(image) # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained diff --git a/riffusion/streamlit/pages/text_to_audio_batch.py b/riffusion/streamlit/pages/text_to_audio_batch.py index 139df62..ee2da65 100644 --- a/riffusion/streamlit/pages/text_to_audio_batch.py +++ b/riffusion/streamlit/pages/text_to_audio_batch.py @@ -7,27 +7,67 @@ import streamlit as st from riffusion.spectrogram_params import SpectrogramParams from riffusion.streamlit import util as streamlit_util +# Example input json file to process in batch +EXAMPLE_INPUT = """ +{ + "params": { + "seed": 42, + "num_inference_steps": 50, + "guidance": 7.0, + "width": 512, + }, + "entries": [ + { + "prompt": "Church bells" + }, + { + "prompt": "electronic beats", + "negative_prompt": "drums" + }, + { + "prompt": "classical violin concerto" + } + ] +} +""" + def render_text_to_audio_batch() -> None: + st.set_page_config(layout="wide", page_icon="🎸") + + st.subheader(":scroll: Text to Audio Batch") + st.write( + """ + Generate audio in batch from a JSON file of text prompts. \nThe input + file contains a global params block and a list of entries with positive and negative + prompts. """ - Render audio from text in batches, reading from a text file. 
- """ - json_file = st.file_uploader("JSON file", type=["json"]) + ) + device = streamlit_util.select_device(st.sidebar) + + # Upload a JSON file + json_file = st.file_uploader( + "JSON file", + type=["json"], + label_visibility="collapsed", + ) + + # Handle the null case if json_file is None: - st.info("Upload a JSON file of prompts") + st.info("Upload a JSON file containing params and prompts") + with st.expander("Example inputs.json", expanded=False): + st.code(EXAMPLE_INPUT) return + # Read in and print it data = json.loads(json_file.read()) - - with st.expander("Data", expanded=False): + with st.expander("Input Data", expanded=False): st.json(data) params = data["params"] entries = data["entries"] - device = streamlit_util.select_device(st.sidebar) - - show_images = st.sidebar.checkbox("Show Images", True) + show_images = st.sidebar.checkbox("Show Images", False) # Optionally specify an output directory output_dir = st.sidebar.text_input("Output Directory", "") @@ -37,21 +77,20 @@ def render_text_to_audio_batch() -> None: output_path.mkdir(parents=True, exist_ok=True) for i, entry in enumerate(entries): - st.write(f"### Entry {i + 1} / {len(entries)}") - - st.write(f"Prompt: {entry['prompt']}") + st.write(f"#### Entry {i + 1} / {len(entries)}") negative_prompt = entry.get("negative_prompt", None) - st.write(f"Negative prompt: {negative_prompt}") + + st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}") image = streamlit_util.run_txt2img( prompt=entry["prompt"], negative_prompt=negative_prompt, - seed=params["seed"], - num_inference_steps=params["num_inference_steps"], - guidance=params["guidance"], - width=params["width"], - height=params["height"], + seed=params.get("seed", 42), + num_inference_steps=params.get("num_inference_steps", 50), + guidance=params.get("guidance", 7.0), + width=params.get("width", 512), + height=512, device=device, ) @@ -91,6 +130,8 @@ def render_text_to_audio_batch() -> None: output_json_path = output_path / "index.json" output_json_path.write_text(json.dumps(data, indent=4)) st.info(f"Output written to {str(output_path)}") + else: + st.info("Enter output directory in sidebar to save to disk") if __name__ == "__main__": diff --git a/riffusion/streamlit/playground.py b/riffusion/streamlit/playground.py new file mode 100644 index 0000000..d728e5a --- /dev/null +++ b/riffusion/streamlit/playground.py @@ -0,0 +1,37 @@ +import streamlit as st + + +def render_main(): + st.set_page_config(layout="wide", page_icon="🎸") + st.header(":guitar: Riffusion Playground") + st.write("Interactive app for common riffusion tasks.") + + left, right = st.columns(2) + + with left: + create_link(":performing_arts: Interpolation", "/interpolation") + st.write("Interpolate between prompts in the latent space.") + + create_link(":pencil2: Text to Audio", "/text_to_audio") + st.write("Generate audio from text prompts.") + + create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch") + st.write("Generate audio in batch from a JSON file of text prompts.") + + with right: + create_link(":scissors: Sample Clips", "/sample_clips") + st.write("Export short clips from an audio file.") + + create_link(":musical_keyboard: Image to Audio", "/image_to_audio") + st.write("Reconstruct audio from spectrogram images.") + + +def create_link(name: str, url: str) -> None: + st.markdown( + f"### {name}", + unsafe_allow_html=True, + ) + + +if __name__ == "__main__": + render_main() diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py index 
2d2bb0f..7ebb97c 100644 --- a/riffusion/streamlit/util.py +++ b/riffusion/streamlit/util.py @@ -13,6 +13,8 @@ from riffusion.riffusion_pipeline import RiffusionPipeline from riffusion.spectrogram_image_converter import SpectrogramImageConverter from riffusion.spectrogram_params import SpectrogramParams +# TODO(hayk): Add URL params + @st.experimental_singleton def load_riffusion_checkpoint( @@ -53,7 +55,6 @@ def load_stable_diffusion_pipeline( ).to(device) - @st.experimental_memo def run_txt2img( prompt: str, @@ -86,25 +87,6 @@ def run_txt2img( return output["images"][0] -# class CachedSpectrogramImageConverter: - -# def __init__(self, params: SpectrogramParams, device: str = "cuda"): -# self.p = params -# self.device = device -# self.converter = self._converter(params, device) - -# @staticmethod -# @st.experimental_singleton -# def _converter(params: SpectrogramParams, device: str) -> SpectrogramImageConverter: -# return SpectrogramImageConverter(params=params, device=device) - -# def audio_from_spectrogram_image( -# self, -# image: Image.Image -# ) -> pydub.AudioSegment: -# return self._converter.audio_from_spectrogram_image(image) - - @st.experimental_singleton def spectrogram_image_converter( params: SpectrogramParams, @@ -147,13 +129,3 @@ def select_device(container: T.Any = st.sidebar) -> str: assert device is not None return device - - -# @st.experimental_memo -# def spectrogram_image_from_audio( -# segment: pydub.AudioSegment, -# params: SpectrogramParams, -# device: str = "cuda", -# ) -> pydub.AudioSegment: -# converter = spectrogram_image_converter(params=params, device=device) -# return converter.spectrogram_image_from_audio(segment)
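
Note on the reconstruction step used by the pages above: image_to_audio.py and interpolation.py call streamlit_util.audio_bytes_from_spectrogram_image(image=..., params=..., device=..., output_format="mp3"), but that helper's definition is not part of these hunks. Below is a minimal sketch of what it presumably does, assembled from the SpectrogramImageConverter calls that the deleted interpolation_demo.py made; the actual cached helper in riffusion/streamlit/util.py may differ.

    import io

    from PIL import Image

    from riffusion.spectrogram_image_converter import SpectrogramImageConverter
    from riffusion.spectrogram_params import SpectrogramParams


    def audio_bytes_from_spectrogram_image_sketch(
        image: Image.Image,
        params: SpectrogramParams,
        device: str = "cuda",
        output_format: str = "mp3",
    ) -> io.BytesIO:
        # Rebuild the waveform from the spectrogram image, mirroring the calls the
        # deleted interpolation_demo.py made, then encode it in-memory for st.audio.
        converter = SpectrogramImageConverter(params=params, device=device)
        segment = converter.audio_from_spectrogram_image(image, apply_filters=True)

        audio_bytes = io.BytesIO()
        segment.export(audio_bytes, format=output_format)
        audio_bytes.seek(0)
        return audio_bytes

The multipage app itself is served from the new entrypoint, e.g. streamlit run riffusion/streamlit/playground.py, and Streamlit should discover the scripts in riffusion/streamlit/pages/ automatically.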