# riffusion-inference/riffusion/streamlit/tasks/interpolation.py
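"""
Streamlit task for interpolating between two text prompts in Riffusion's latent space
to generate long-form audio.
"""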

import dataclasses
import io
import typing as T
from pathlib import Path

import numpy as np
import pydub
import streamlit as st
from PIL import Image

from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


def render() -> None:
    st.subheader("🎭 Interpolation")
    st.write(
        """
        Interpolate between prompts in the latent space.
        """
    )

    with st.expander("Help", False):
        st.write(
            """
            This tool allows specifying two endpoints and generating a long-form interpolation
            between them that traverses the latent space. The interpolation is generated by
            the method described at https://www.riffusion.com/about. A seed image is used to
            set the beat and tempo of the generated audio, and can be set in the sidebar.

            Usually either the seed or the prompt is changed, but not both at once. You can
            browse infinite variations of the same prompt by changing the seed.

            For example, try going from "church bells" to "jazz" with 10 steps and 0.75
            denoising. This will generate a 50 second clip at 5 seconds per step. Then play
            with the seeds or denoising to get different variations.
            """
        )

    # Sidebar params

    device = streamlit_util.select_device(st.sidebar)
    extension = streamlit_util.select_audio_extension(st.sidebar)

    num_interpolation_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Interpolation steps",
            value=12,
            min_value=1,
            max_value=20,
            help="Number of model generations between the two prompts. Controls the duration.",
        ),
    )
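    # Per the help text above, each model run yields roughly a 5 second clip, so the total
    # output duration is about num_interpolation_steps * 5 seconds.
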
    num_inference_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Steps per sample", value=50, help="Number of denoising steps per model run"
        ),
    )

    guidance = st.sidebar.number_input(
        "Guidance",
        value=7.0,
        help="How much the model listens to the text prompt",
    )

    init_image_name = st.sidebar.selectbox(
        "Seed image",
        # TODO(hayk): Read from directory
        options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
        index=0,
        help="Which seed image to use for img2img. Custom allows uploading your own.",
    )
    assert init_image_name is not None

    if init_image_name == "custom":
        init_image_file = st.sidebar.file_uploader(
            "Upload a custom seed image",
            type=streamlit_util.IMAGE_EXTENSIONS,
            label_visibility="collapsed",
        )
        if init_image_file:
            st.sidebar.image(init_image_file)

    alpha_power = st.sidebar.number_input("Alpha Power", value=1.0)

    show_individual_outputs = st.sidebar.checkbox(
        "Show individual outputs",
        value=False,
        help="Show each model output",
    )
    show_images = st.sidebar.checkbox(
        "Show individual images",
        value=False,
        help="Show each generated image",
    )

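    # Alpha Power reshapes the spacing of the interpolation steps: values above 1 cluster the
    # alphas around the midpoint (0.5), while values below 1 push them toward the endpoints.
    # For example, alpha_power=2 maps [0.0, 0.25, 0.5, 0.75, 1.0]
    # to [0.0, 0.375, 0.5, 0.625, 1.0].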
    alphas = np.linspace(0, 1, num_interpolation_steps)

    # Apply power scaling to alphas to customize the interpolation curve
    alphas_shifted = alphas * 2 - 1
    alphas_shifted = (np.abs(alphas_shifted) ** alpha_power * np.sign(alphas_shifted) + 1) / 2
    alphas = alphas_shifted

    alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
    st.write(f"**Alphas** : [{alphas_str}]")

    # Prompt inputs A and B in two columns
    with st.form(key="interpolation_form"):
        left, right = st.columns(2)

        with left:
            st.write("##### Prompt A")
            prompt_input_a = PromptInput(
                guidance=guidance, **get_prompt_inputs(key="a", denoising_default=0.75)
            )

        with right:
            st.write("##### Prompt B")
            prompt_input_b = PromptInput(
                guidance=guidance, **get_prompt_inputs(key="b", denoising_default=0.75)
            )

        st.form_submit_button("Generate", type="primary")
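    # Note: widgets inside st.form only deliver their values on submit, so editing a prompt
    # does not trigger a rerun until "Generate" is pressed.
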
    if not prompt_input_a.prompt or not prompt_input_b.prompt:
        st.info("Enter both prompts to interpolate between them")
        return

    if init_image_name == "custom":
        if not init_image_file:
            st.info("Upload a custom seed image")
            return
        init_image = Image.open(init_image_file).convert("RGB")
    else:
        init_image_path = (
            Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
        )
        init_image = Image.open(str(init_image_path)).convert("RGB")

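    # The bundled seed images are expected in the repository's seed_images/ directory,
    # four levels up from this file.
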
    # TODO(hayk): Move this code into a shared place and add to riffusion.cli
    image_list: T.List[Image.Image] = []
    audio_bytes_list: T.List[io.BytesIO] = []

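    # Each alpha produces one clip: alpha=0.0 corresponds entirely to prompt A, alpha=1.0 to
    # prompt B, and intermediate values blend the two in latent space. The seed image itself is
    # passed to run_interpolation directly, so the hard-coded seed_image_id below appears to be
    # informational only.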
    for i, alpha in enumerate(alphas):
        inputs = InferenceInput(
            alpha=float(alpha),
            num_inference_steps=num_inference_steps,
            seed_image_id="og_beat",
            start=prompt_input_a,
            end=prompt_input_b,
        )

        if i == 0:
            with st.expander("Example input JSON", expanded=False):
                st.json(dataclasses.asdict(inputs))

        image, audio_bytes = run_interpolation(
            inputs=inputs,
            init_image=init_image,
            device=device,
            extension=extension,
        )

        if show_individual_outputs:
            st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
            if show_images:
                st.image(image)
            st.audio(audio_bytes)

        image_list.append(image)
        audio_bytes_list.append(audio_bytes)

    st.write("#### Final Output")

    # TODO(hayk): Concatenate with overlap and better blending like in audio to audio
    audio_segments = [
        pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list
    ]
    concat_segment = audio_segments[0]
    for segment in audio_segments[1:]:
        concat_segment = concat_segment.append(segment, crossfade=0)
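    # Note: pydub's crossfade argument is in milliseconds; a nonzero value (e.g. crossfade=100)
    # would blend adjacent clips instead of hard-cutting between them, per the TODO above.
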
    audio_bytes = io.BytesIO()
    concat_segment.export(audio_bytes, format=extension)
    audio_bytes.seek(0)

    st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
    st.audio(audio_bytes)

    output_name = (
        f"{prompt_input_a.prompt.replace(' ', '_')}_"
        f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
    )
    st.download_button(
        output_name,
        data=audio_bytes,
        file_name=output_name,
        mime=f"audio/{extension}",
    )


def get_prompt_inputs(
    key: str,
    include_negative_prompt: bool = False,
    cols: bool = False,
    denoising_default: float = 0.5,
) -> T.Dict[str, T.Any]:
    """
    Compute prompt inputs from widgets.
    """
    p: T.Dict[str, T.Any] = {}
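    # The returned dict is unpacked into PromptInput(**...) by the caller, so its keys must
    # match the constructor fields of PromptInput in riffusion.datatypes.
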
    # Optionally use columns
    left, right = T.cast(T.Any, st.columns(2) if cols else (st, st))
    visibility = "visible" if cols else "collapsed"

    p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}")

    if include_negative_prompt:
        p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}")

    p["seed"] = T.cast(
        int,
        left.number_input(
            "Seed",
            value=42,
            key=f"seed_{key}",
            help="Integer used to generate a random result. Vary this to explore alternatives.",
        ),
    )

    p["denoising"] = right.number_input(
        "Denoising",
        value=denoising_default,
        key=f"denoising_{key}",
        help="How much to modify the seed image",
    )

    return p
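

# run_interpolation is wrapped in st.cache_data, so rerunning the app with the same prompts,
# seeds, and alphas reuses prior generations instead of re-running the model.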
@st.cache_data
def run_interpolation(
    inputs: InferenceInput,
    init_image: Image.Image,
    checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
    device: str = "cuda",
    extension: str = "mp3",
) -> T.Tuple[Image.Image, io.BytesIO]:
    """
    Cached function for riffusion interpolation.
    """
    pipeline = streamlit_util.load_riffusion_checkpoint(
        device=device,
        checkpoint=checkpoint,
        # No trace so we can have variable width
        no_traced_unet=True,
    )

    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=None,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct from image to audio
    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
        output_format=extension,
    )

    return image, audio_bytes