Improve interpolation playground
Topic: playground_improvements
This commit is contained in:
parent
f7288f8cd3
commit
0610a45e80
|
@ -48,11 +48,19 @@ def render_interpolation_demo() -> None:
|
|||
init_image_name = st.sidebar.selectbox(
|
||||
"Seed image",
|
||||
# TODO(hayk): Read from directory
|
||||
options=["og_beat", "agile", "marim", "motorway", "vibes"],
|
||||
options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
|
||||
index=0,
|
||||
help="Which seed image to use for img2img",
|
||||
)
|
||||
assert init_image_name is not None
|
||||
if init_image_name == "custom":
|
||||
init_image_file = st.sidebar.file_uploader(
|
||||
"Upload a custom seed image",
|
||||
type=["png", "jpg", "jpeg"],
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if init_image_file:
|
||||
st.sidebar.image(init_image_file)
|
||||
|
||||
show_individual_outputs = st.sidebar.checkbox(
|
||||
"Show individual outputs",
|
||||
|
@ -67,28 +75,38 @@ def render_interpolation_demo() -> None:
|
|||
|
||||
# Prompt inputs A and B in two columns
|
||||
|
||||
left, right = st.columns(2)
|
||||
with st.form(key="interpolation_form"):
|
||||
left, right = st.columns(2)
|
||||
|
||||
with left.expander("Input A", expanded=True):
|
||||
prompt_input_a = get_prompt_inputs(key="a")
|
||||
with left:
|
||||
st.write("##### Prompt A")
|
||||
prompt_input_a = get_prompt_inputs(key="a")
|
||||
|
||||
with right.expander("Input B", expanded=True):
|
||||
prompt_input_b = get_prompt_inputs(key="b")
|
||||
with right:
|
||||
st.write("##### Prompt B")
|
||||
prompt_input_b = get_prompt_inputs(key="b")
|
||||
|
||||
st.form_submit_button("Generate", type="primary")
|
||||
|
||||
if not prompt_input_a.prompt or not prompt_input_b.prompt:
|
||||
st.info("Enter both prompts to interpolate between them")
|
||||
return
|
||||
|
||||
# TODO(hayk): Make not linspace
|
||||
alphas = list(np.linspace(0, 1, num_interpolation_steps))
|
||||
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
||||
st.write(f"**Alphas** : [{alphas_str}]")
|
||||
|
||||
# TODO(hayk): Upload your own seed image.
|
||||
|
||||
init_image_path = (
|
||||
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
|
||||
)
|
||||
init_image = Image.open(str(init_image_path)).convert("RGB")
|
||||
if init_image_name == "custom":
|
||||
if not init_image_file:
|
||||
st.info("Upload a custom seed image")
|
||||
return
|
||||
init_image = Image.open(init_image_file).convert("RGB")
|
||||
else:
|
||||
init_image_path = (
|
||||
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
|
||||
)
|
||||
init_image = Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
|
||||
image_list: T.List[Image.Image] = []
|
||||
|
@ -168,7 +186,11 @@ def run_interpolation(
|
|||
"""
|
||||
Cached function for riffusion interpolation.
|
||||
"""
|
||||
pipeline = streamlit_util.load_riffusion_checkpoint(device=device)
|
||||
pipeline = streamlit_util.load_riffusion_checkpoint(
|
||||
device=device,
|
||||
# No trace so we can have variable width
|
||||
no_traced_unet=True,
|
||||
)
|
||||
|
||||
image = pipeline.riffuse(
|
||||
inputs,
|
||||
|
|
Loading…
Reference in New Issue