Improve interpolation playground

Topic: playground_improvements
This commit is contained in:
Hayk Martiros 2022-12-27 15:22:02 +00:00
parent f7288f8cd3
commit 0610a45e80
1 changed files with 35 additions and 13 deletions

View File

@ -48,11 +48,19 @@ def render_interpolation_demo() -> None:
init_image_name = st.sidebar.selectbox( init_image_name = st.sidebar.selectbox(
"Seed image", "Seed image",
# TODO(hayk): Read from directory # TODO(hayk): Read from directory
options=["og_beat", "agile", "marim", "motorway", "vibes"], options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
index=0, index=0,
help="Which seed image to use for img2img", help="Which seed image to use for img2img",
) )
assert init_image_name is not None assert init_image_name is not None
if init_image_name == "custom":
init_image_file = st.sidebar.file_uploader(
"Upload a custom seed image",
type=["png", "jpg", "jpeg"],
label_visibility="collapsed",
)
if init_image_file:
st.sidebar.image(init_image_file)
show_individual_outputs = st.sidebar.checkbox( show_individual_outputs = st.sidebar.checkbox(
"Show individual outputs", "Show individual outputs",
@ -67,24 +75,34 @@ def render_interpolation_demo() -> None:
# Prompt inputs A and B in two columns # Prompt inputs A and B in two columns
with st.form(key="interpolation_form"):
left, right = st.columns(2) left, right = st.columns(2)
with left.expander("Input A", expanded=True): with left:
st.write("##### Prompt A")
prompt_input_a = get_prompt_inputs(key="a") prompt_input_a = get_prompt_inputs(key="a")
with right.expander("Input B", expanded=True): with right:
st.write("##### Prompt B")
prompt_input_b = get_prompt_inputs(key="b") prompt_input_b = get_prompt_inputs(key="b")
st.form_submit_button("Generate", type="primary")
if not prompt_input_a.prompt or not prompt_input_b.prompt: if not prompt_input_a.prompt or not prompt_input_b.prompt:
st.info("Enter both prompts to interpolate between them") st.info("Enter both prompts to interpolate between them")
return return
# TODO(hayk): Make not linspace
alphas = list(np.linspace(0, 1, num_interpolation_steps)) alphas = list(np.linspace(0, 1, num_interpolation_steps))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas]) alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]") st.write(f"**Alphas** : [{alphas_str}]")
# TODO(hayk): Upload your own seed image. if init_image_name == "custom":
if not init_image_file:
st.info("Upload a custom seed image")
return
init_image = Image.open(init_image_file).convert("RGB")
else:
init_image_path = ( init_image_path = (
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png" Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
) )
@ -168,7 +186,11 @@ def run_interpolation(
""" """
Cached function for riffusion interpolation. Cached function for riffusion interpolation.
""" """
pipeline = streamlit_util.load_riffusion_checkpoint(device=device) pipeline = streamlit_util.load_riffusion_checkpoint(
device=device,
# No trace so we can have variable width
no_traced_unet=True,
)
image = pipeline.riffuse( image = pipeline.riffuse(
inputs, inputs,