Improve interpolation playground
Topic: playground_improvements
This commit is contained in:
parent
f7288f8cd3
commit
0610a45e80
|
@ -48,11 +48,19 @@ def render_interpolation_demo() -> None:
|
||||||
init_image_name = st.sidebar.selectbox(
|
init_image_name = st.sidebar.selectbox(
|
||||||
"Seed image",
|
"Seed image",
|
||||||
# TODO(hayk): Read from directory
|
# TODO(hayk): Read from directory
|
||||||
options=["og_beat", "agile", "marim", "motorway", "vibes"],
|
options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
|
||||||
index=0,
|
index=0,
|
||||||
help="Which seed image to use for img2img",
|
help="Which seed image to use for img2img",
|
||||||
)
|
)
|
||||||
assert init_image_name is not None
|
assert init_image_name is not None
|
||||||
|
if init_image_name == "custom":
|
||||||
|
init_image_file = st.sidebar.file_uploader(
|
||||||
|
"Upload a custom seed image",
|
||||||
|
type=["png", "jpg", "jpeg"],
|
||||||
|
label_visibility="collapsed",
|
||||||
|
)
|
||||||
|
if init_image_file:
|
||||||
|
st.sidebar.image(init_image_file)
|
||||||
|
|
||||||
show_individual_outputs = st.sidebar.checkbox(
|
show_individual_outputs = st.sidebar.checkbox(
|
||||||
"Show individual outputs",
|
"Show individual outputs",
|
||||||
|
@ -67,28 +75,38 @@ def render_interpolation_demo() -> None:
|
||||||
|
|
||||||
# Prompt inputs A and B in two columns
|
# Prompt inputs A and B in two columns
|
||||||
|
|
||||||
left, right = st.columns(2)
|
with st.form(key="interpolation_form"):
|
||||||
|
left, right = st.columns(2)
|
||||||
|
|
||||||
with left.expander("Input A", expanded=True):
|
with left:
|
||||||
prompt_input_a = get_prompt_inputs(key="a")
|
st.write("##### Prompt A")
|
||||||
|
prompt_input_a = get_prompt_inputs(key="a")
|
||||||
|
|
||||||
with right.expander("Input B", expanded=True):
|
with right:
|
||||||
prompt_input_b = get_prompt_inputs(key="b")
|
st.write("##### Prompt B")
|
||||||
|
prompt_input_b = get_prompt_inputs(key="b")
|
||||||
|
|
||||||
|
st.form_submit_button("Generate", type="primary")
|
||||||
|
|
||||||
if not prompt_input_a.prompt or not prompt_input_b.prompt:
|
if not prompt_input_a.prompt or not prompt_input_b.prompt:
|
||||||
st.info("Enter both prompts to interpolate between them")
|
st.info("Enter both prompts to interpolate between them")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# TODO(hayk): Make not linspace
|
||||||
alphas = list(np.linspace(0, 1, num_interpolation_steps))
|
alphas = list(np.linspace(0, 1, num_interpolation_steps))
|
||||||
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
||||||
st.write(f"**Alphas** : [{alphas_str}]")
|
st.write(f"**Alphas** : [{alphas_str}]")
|
||||||
|
|
||||||
# TODO(hayk): Upload your own seed image.
|
if init_image_name == "custom":
|
||||||
|
if not init_image_file:
|
||||||
init_image_path = (
|
st.info("Upload a custom seed image")
|
||||||
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
|
return
|
||||||
)
|
init_image = Image.open(init_image_file).convert("RGB")
|
||||||
init_image = Image.open(str(init_image_path)).convert("RGB")
|
else:
|
||||||
|
init_image_path = (
|
||||||
|
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
|
||||||
|
)
|
||||||
|
init_image = Image.open(str(init_image_path)).convert("RGB")
|
||||||
|
|
||||||
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
|
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
|
||||||
image_list: T.List[Image.Image] = []
|
image_list: T.List[Image.Image] = []
|
||||||
|
@ -168,7 +186,11 @@ def run_interpolation(
|
||||||
"""
|
"""
|
||||||
Cached function for riffusion interpolation.
|
Cached function for riffusion interpolation.
|
||||||
"""
|
"""
|
||||||
pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
|
pipeline = streamlit_util.load_riffusion_checkpoint(
|
||||||
|
device=device,
|
||||||
|
# No trace so we can have variable width
|
||||||
|
no_traced_unet=True,
|
||||||
|
)
|
||||||
|
|
||||||
image = pipeline.riffuse(
|
image = pipeline.riffuse(
|
||||||
inputs,
|
inputs,
|
||||||
|
|
Loading…
Reference in New Issue