Improve interpolation playground

Topic: playground_improvements
This commit is contained in:
Hayk Martiros 2022-12-27 15:22:02 +00:00
parent f7288f8cd3
commit 0610a45e80
1 changed files with 35 additions and 13 deletions

View File

@ -48,11 +48,19 @@ def render_interpolation_demo() -> None:
init_image_name = st.sidebar.selectbox(
"Seed image",
# TODO(hayk): Read from directory
options=["og_beat", "agile", "marim", "motorway", "vibes"],
options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
index=0,
help="Which seed image to use for img2img",
)
assert init_image_name is not None
if init_image_name == "custom":
init_image_file = st.sidebar.file_uploader(
"Upload a custom seed image",
type=["png", "jpg", "jpeg"],
label_visibility="collapsed",
)
if init_image_file:
st.sidebar.image(init_image_file)
show_individual_outputs = st.sidebar.checkbox(
"Show individual outputs",
@ -67,28 +75,38 @@ def render_interpolation_demo() -> None:
# Prompt inputs A and B in two columns
left, right = st.columns(2)
with st.form(key="interpolation_form"):
left, right = st.columns(2)
with left.expander("Input A", expanded=True):
prompt_input_a = get_prompt_inputs(key="a")
with left:
st.write("##### Prompt A")
prompt_input_a = get_prompt_inputs(key="a")
with right.expander("Input B", expanded=True):
prompt_input_b = get_prompt_inputs(key="b")
with right:
st.write("##### Prompt B")
prompt_input_b = get_prompt_inputs(key="b")
st.form_submit_button("Generate", type="primary")
if not prompt_input_a.prompt or not prompt_input_b.prompt:
st.info("Enter both prompts to interpolate between them")
return
# TODO(hayk): Make not linspace
alphas = list(np.linspace(0, 1, num_interpolation_steps))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]")
# TODO(hayk): Upload your own seed image.
init_image_path = (
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
)
init_image = Image.open(str(init_image_path)).convert("RGB")
if init_image_name == "custom":
if not init_image_file:
st.info("Upload a custom seed image")
return
init_image = Image.open(init_image_file).convert("RGB")
else:
init_image_path = (
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
)
init_image = Image.open(str(init_image_path)).convert("RGB")
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
image_list: T.List[Image.Image] = []
@ -168,7 +186,11 @@ def run_interpolation(
"""
Cached function for riffusion interpolation.
"""
pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
pipeline = streamlit_util.load_riffusion_checkpoint(
device=device,
# No trace so we can have variable width
no_traced_unet=True,
)
image = pipeline.riffuse(
inputs,