From a861b48232ffb2e5a1cb358a68141d8af2a2d93f Mon Sep 17 00:00:00 2001
From: Hayk Martiros
Date: Tue, 11 Apr 2023 04:26:30 +0000
Subject: [PATCH] Streamlit fixes and batch text to audio with multiple models

Topic: batch_text_with_multiple_models
---
 riffusion/streamlit/tasks/audio_to_audio.py  |  19 ++-
 riffusion/streamlit/tasks/interpolation.py   |   7 +-
 riffusion/streamlit/tasks/text_to_audio.py   |   2 +-
 .../streamlit/tasks/text_to_audio_batch.py   | 136 ++++++++++++------
 riffusion/streamlit/util.py                  |   7 +-
 5 files changed, 118 insertions(+), 53 deletions(-)

diff --git a/riffusion/streamlit/tasks/audio_to_audio.py b/riffusion/streamlit/tasks/audio_to_audio.py
index ee866c0..3ae8e68 100644
--- a/riffusion/streamlit/tasks/audio_to_audio.py
+++ b/riffusion/streamlit/tasks/audio_to_audio.py
@@ -43,7 +43,9 @@ def render() -> None:
 
     device = streamlit_util.select_device(st.sidebar)
     extension = streamlit_util.select_audio_extension(st.sidebar)
+    checkpoint = streamlit_util.select_checkpoint(st.sidebar)
 
+    use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
     use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
 
     with st.sidebar:
@@ -183,7 +185,19 @@ def render() -> None:
 
     st.write(f"## Counter: {counter.value}")
 
-    params = SpectrogramParams()
+    if use_20k:
+        params = SpectrogramParams(
+            min_frequency=10,
+            max_frequency=20000,
+            sample_rate=44100,
+            stereo=True,
+        )
+    else:
+        params = SpectrogramParams(
+            min_frequency=0,
+            max_frequency=10000,
+            stereo=False,
+        )
 
     if interpolate:
         # TODO(hayk): Make not linspace
@@ -237,6 +251,7 @@ def render() -> None:
                 inputs=inputs,
                 init_image=init_image_resized,
                 device=device,
+                checkpoint=checkpoint,
             )
         elif use_magic_mix:
             assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
@@ -251,6 +266,7 @@ def render() -> None:
                 mix_factor=magic_mix_mix_factor,
                 device=device,
                 scheduler=scheduler,
+                checkpoint=checkpoint,
             )
         else:
             image = streamlit_util.run_img2img(
@@ -264,6 +280,7 @@ def render() -> None:
                 progress_callback=progress_callback,
                 device=device,
                 scheduler=scheduler,
+                checkpoint=checkpoint,
             )
 
         # Resize back to original size
diff --git a/riffusion/streamlit/tasks/interpolation.py b/riffusion/streamlit/tasks/interpolation.py
index 2bb708b..ea55db1 100644
--- a/riffusion/streamlit/tasks/interpolation.py
+++ b/riffusion/streamlit/tasks/interpolation.py
@@ -241,13 +241,18 @@ def get_prompt_inputs(
 
 @st.cache_data
 def run_interpolation(
-    inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
+    inputs: InferenceInput,
+    init_image: Image.Image,
+    checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
+    device: str = "cuda",
+    extension: str = "mp3",
 ) -> T.Tuple[Image.Image, io.BytesIO]:
     """
     Cached function for riffusion interpolation.
     """
     pipeline = streamlit_util.load_riffusion_checkpoint(
         device=device,
+        checkpoint=checkpoint,
         # No trace so we can have variable width
         no_traced_unet=True,
     )
diff --git a/riffusion/streamlit/tasks/text_to_audio.py b/riffusion/streamlit/tasks/text_to_audio.py
index c70e431..2549f31 100644
--- a/riffusion/streamlit/tasks/text_to_audio.py
+++ b/riffusion/streamlit/tasks/text_to_audio.py
@@ -55,7 +55,7 @@ def render() -> None:
         st.form_submit_button("Riff", type="primary")
 
     with st.sidebar:
-        num_inference_steps = T.cast(int, st.number_input("Inference steps", value=25))
+        num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30))
         width = T.cast(int, st.number_input("Width", value=512))
         guidance = st.number_input(
             "Guidance", value=7.0, help="How much the model listens to the text prompt"
         )
diff --git a/riffusion/streamlit/tasks/text_to_audio_batch.py b/riffusion/streamlit/tasks/text_to_audio_batch.py
index fa2fa42..abaff50 100644
--- a/riffusion/streamlit/tasks/text_to_audio_batch.py
+++ b/riffusion/streamlit/tasks/text_to_audio_batch.py
@@ -11,21 +11,25 @@ from riffusion.streamlit import util as streamlit_util
 EXAMPLE_INPUT = """
 {
     "params": {
-        "seed": 42,
+        "checkpoint": "riffusion/riffusion-model-v1",
+        "scheduler": "DPMSolverMultistepScheduler",
         "num_inference_steps": 50,
         "guidance": 7.0,
-        "width": 512
+        "width": 512,
     },
     "entries": [
         {
-            "prompt": "Church bells"
+            "prompt": "Church bells",
+            "seed": 42
         },
         {
             "prompt": "electronic beats",
-            "negative_prompt": "drums"
+            "negative_prompt": "drums",
+            "seed": 100
         },
         {
-            "prompt": "classical violin concerto"
+            "prompt": "classical violin concerto",
+            "seed": 4
         }
     ]
 }
@@ -71,10 +75,22 @@ def render() -> None:
     with st.expander("Input Data", expanded=False):
         st.json(data)
 
-    params = data["params"]
+    # Params can either be a list or a single entry
+    if isinstance(data["params"], list):
+        param_sets = data["params"]
+    else:
+        param_sets = [data["params"]]
+
     entries = data["entries"]
 
-    show_images = st.sidebar.checkbox("Show Images", False)
+    show_images = st.sidebar.checkbox("Show Images", True)
+    num_seeds = st.sidebar.number_input(
+        "Num Seeds",
+        value=1,
+        min_value=1,
+        max_value=10,
+        help="When > 1, increments the seed and runs multiple for each entry",
+    )
 
     # Optionally specify an output directory
     output_dir = st.sidebar.text_input("Output Directory", "")
@@ -83,55 +99,83 @@ def render() -> None:
         output_path = Path(output_dir)
         output_path.mkdir(parents=True, exist_ok=True)
 
-    for i, entry in enumerate(entries):
-        st.write(f"#### Entry {i + 1} / {len(entries)}")
+    # Write title cards for each param set
+    title_cols = st.columns(len(param_sets))
+    for i, params in enumerate(param_sets):
+        col = title_cols[i]
+        if "name" not in params:
+            params["name"] = f"params[{i}]"
+
+        col.write(f"## Param Set {i}")
+        col.json(params)
+
+    for entry_i, entry in enumerate(entries):
+        st.write("---")
+        print(entry)
 
+        prompt = entry["prompt"]
         negative_prompt = entry.get("negative_prompt", None)
-        st.write(f"**Prompt**: {entry['prompt']}  \n" + f"**Negative prompt**: {negative_prompt}")
+        base_seed = entry.get("seed", 42)
 
-        image = streamlit_util.run_txt2img(
-            prompt=entry["prompt"],
-            negative_prompt=negative_prompt,
-            seed=params.get("seed", 42),
-            num_inference_steps=params.get("num_inference_steps", 50),
-            guidance=params.get("guidance", 7.0),
-            width=params.get("width", 512),
-            height=512,
-            device=device,
-        )
+        text = f"##### ({base_seed}) {prompt}"
+        if negative_prompt:
+            text += f"  \n**Negative**: {negative_prompt}"
+        st.write(text)
 
-        if show_images:
-            st.image(image)
+        for seed in range(base_seed, base_seed + num_seeds):
+            cols = st.columns(len(param_sets))
+            for i, params in enumerate(param_sets):
+                col = cols[i]
+                col.write(params["name"])
 
-        # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
-        p_spectrogram = SpectrogramParams(
-            min_frequency=0,
-            max_frequency=10000,
-        )
+                image = streamlit_util.run_txt2img(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    seed=seed,
+                    num_inference_steps=params.get("num_inference_steps", 50),
+                    guidance=params.get("guidance", 7.0),
+                    width=params.get("width", 512),
+                    checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT),
+                    scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]),
+                    height=512,
+                    device=device,
+                )
 
-        output_format = "mp3"
-        audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
-            image=image,
-            params=p_spectrogram,
-            device=device,
-            output_format=output_format,
-        )
-        st.audio(audio_bytes)
+                if show_images:
+                    col.image(image)
 
-        if output_path:
-            prompt_slug = entry["prompt"].replace(" ", "_")
-            negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
+                # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
+                p_spectrogram = SpectrogramParams(
+                    min_frequency=0,
+                    max_frequency=10000,
+                )
 
-            image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
-            image.save(image_path, format="JPEG")
-            entry["image_path"] = str(image_path)
+                output_format = "mp3"
+                audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
+                    image=image,
+                    params=p_spectrogram,
+                    device=device,
+                    output_format=output_format,
+                )
+                col.audio(audio_bytes)
 
-            audio_path = (
-                output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
-            )
-            audio_path.write_bytes(audio_bytes.getbuffer())
-            entry["audio_path"] = str(audio_path)
+                if output_path:
+                    prompt_slug = entry["prompt"].replace(" ", "_")
+                    negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
+
+                    image_path = (
+                        output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
+                    )
+                    image.save(image_path, format="JPEG")
+                    entry["image_path"] = str(image_path)
+
+                    audio_path = (
+                        output_path
+                        / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
+                    )
+                    audio_path.write_bytes(audio_bytes.getbuffer())
+                    entry["audio_path"] = str(audio_path)
 
     if output_path:
         output_json_path = output_path / "index.json"
diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py
index dfccc68..a332dc1 100644
--- a/riffusion/streamlit/util.py
+++ b/riffusion/streamlit/util.py
@@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
     return pipeline
 
 
-@st.cache_data
+@st.cache_data(persist=True)
 def run_txt2img(
     prompt: str,
     num_inference_steps: int,
@@ -281,12 +281,11 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
     """
     Provide a custom model checkpoint.
    """
-    custom_checkpoint = container.text_input(
+    return container.text_input(
        "Custom Checkpoint",
-        value="",
+        value=DEFAULT_CHECKPOINT,
        help="Provide a custom model checkpoint",
    )
-    return custom_checkpoint or DEFAULT_CHECKPOINT
 
 
 @st.cache_data
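
With the new "params" handling above (the isinstance list check), the batch task's input JSON can carry either a single param set or a list of named param sets, rendered one column per set, so different models can be compared side by side. Below is an illustrative sketch of such an input, not part of this patch; it only uses fields that render() reads (name, checkpoint, scheduler, num_inference_steps, prompt, negative_prompt, seed), and the second checkpoint path is a placeholder for any fine-tuned model, not a real repository:

    {
        "params": [
            {
                "name": "base",
                "checkpoint": "riffusion/riffusion-model-v1",
                "scheduler": "DPMSolverMultistepScheduler",
                "num_inference_steps": 50
            },
            {
                "name": "custom",
                "checkpoint": "path/to/your-finetuned-checkpoint",
                "scheduler": "DPMSolverMultistepScheduler",
                "num_inference_steps": 50
            }
        ],
        "entries": [
            {"prompt": "Church bells", "seed": 42},
            {"prompt": "electronic beats", "negative_prompt": "drums", "seed": 100}
        ]
    }

With the sidebar "Num Seeds" set above 1, each entry is additionally run at seed, seed + 1, ... for every param set.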