Streamlit fixes and batch text to audio with multiple models
Topic: batch_text_with_multiple_models
This commit is contained in:
parent
8d44cbbc64
commit
84c5847eff
|
@ -43,7 +43,9 @@ def render() -> None:
|
||||||
|
|
||||||
device = streamlit_util.select_device(st.sidebar)
|
device = streamlit_util.select_device(st.sidebar)
|
||||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||||
|
checkpoint = streamlit_util.select_checkpoint(st.sidebar)
|
||||||
|
|
||||||
|
use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
|
||||||
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
|
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
|
@ -183,7 +185,19 @@ def render() -> None:
|
||||||
|
|
||||||
st.write(f"## Counter: {counter.value}")
|
st.write(f"## Counter: {counter.value}")
|
||||||
|
|
||||||
params = SpectrogramParams()
|
if use_20k:
|
||||||
|
params = SpectrogramParams(
|
||||||
|
min_frequency=10,
|
||||||
|
max_frequency=20000,
|
||||||
|
sample_rate=44100,
|
||||||
|
stereo=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
params = SpectrogramParams(
|
||||||
|
min_frequency=0,
|
||||||
|
max_frequency=10000,
|
||||||
|
stereo=False,
|
||||||
|
)
|
||||||
|
|
||||||
if interpolate:
|
if interpolate:
|
||||||
# TODO(hayk): Make not linspace
|
# TODO(hayk): Make not linspace
|
||||||
|
@ -237,6 +251,7 @@ def render() -> None:
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
init_image=init_image_resized,
|
init_image=init_image_resized,
|
||||||
device=device,
|
device=device,
|
||||||
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
elif use_magic_mix:
|
elif use_magic_mix:
|
||||||
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
|
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
|
||||||
|
@ -251,6 +266,7 @@ def render() -> None:
|
||||||
mix_factor=magic_mix_mix_factor,
|
mix_factor=magic_mix_mix_factor,
|
||||||
device=device,
|
device=device,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = streamlit_util.run_img2img(
|
image = streamlit_util.run_img2img(
|
||||||
|
@ -264,6 +280,7 @@ def render() -> None:
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
device=device,
|
device=device,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resize back to original size
|
# Resize back to original size
|
||||||
|
|
|
@ -241,13 +241,18 @@ def get_prompt_inputs(
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def run_interpolation(
|
def run_interpolation(
|
||||||
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
|
inputs: InferenceInput,
|
||||||
|
init_image: Image.Image,
|
||||||
|
checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
|
||||||
|
device: str = "cuda",
|
||||||
|
extension: str = "mp3",
|
||||||
) -> T.Tuple[Image.Image, io.BytesIO]:
|
) -> T.Tuple[Image.Image, io.BytesIO]:
|
||||||
"""
|
"""
|
||||||
Cached function for riffusion interpolation.
|
Cached function for riffusion interpolation.
|
||||||
"""
|
"""
|
||||||
pipeline = streamlit_util.load_riffusion_checkpoint(
|
pipeline = streamlit_util.load_riffusion_checkpoint(
|
||||||
device=device,
|
device=device,
|
||||||
|
checkpoint=checkpoint,
|
||||||
# No trace so we can have variable width
|
# No trace so we can have variable width
|
||||||
no_traced_unet=True,
|
no_traced_unet=True,
|
||||||
)
|
)
|
||||||
|
|
|
@ -55,7 +55,7 @@ def render() -> None:
|
||||||
st.form_submit_button("Riff", type="primary")
|
st.form_submit_button("Riff", type="primary")
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=25))
|
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30))
|
||||||
width = T.cast(int, st.number_input("Width", value=512))
|
width = T.cast(int, st.number_input("Width", value=512))
|
||||||
guidance = st.number_input(
|
guidance = st.number_input(
|
||||||
"Guidance", value=7.0, help="How much the model listens to the text prompt"
|
"Guidance", value=7.0, help="How much the model listens to the text prompt"
|
||||||
|
|
|
@ -11,21 +11,25 @@ from riffusion.streamlit import util as streamlit_util
|
||||||
EXAMPLE_INPUT = """
|
EXAMPLE_INPUT = """
|
||||||
{
|
{
|
||||||
"params": {
|
"params": {
|
||||||
"seed": 42,
|
"checkpoint": "riffusion/riffusion-model-v1",
|
||||||
|
"scheduler": "DPMSolverMultistepScheduler",
|
||||||
"num_inference_steps": 50,
|
"num_inference_steps": 50,
|
||||||
"guidance": 7.0,
|
"guidance": 7.0,
|
||||||
"width": 512
|
"width": 512,
|
||||||
},
|
},
|
||||||
"entries": [
|
"entries": [
|
||||||
{
|
{
|
||||||
"prompt": "Church bells"
|
"prompt": "Church bells",
|
||||||
|
"seed": 42
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"prompt": "electronic beats",
|
"prompt": "electronic beats",
|
||||||
"negative_prompt": "drums"
|
"negative_prompt": "drums",
|
||||||
|
"seed": 100
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"prompt": "classical violin concerto"
|
"prompt": "classical violin concerto",
|
||||||
|
"seed": 4
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -71,10 +75,22 @@ def render() -> None:
|
||||||
with st.expander("Input Data", expanded=False):
|
with st.expander("Input Data", expanded=False):
|
||||||
st.json(data)
|
st.json(data)
|
||||||
|
|
||||||
params = data["params"]
|
# Params can either be a list or a single entry
|
||||||
|
if isinstance(data["params"], list):
|
||||||
|
param_sets = data["params"]
|
||||||
|
else:
|
||||||
|
param_sets = [data["params"]]
|
||||||
|
|
||||||
entries = data["entries"]
|
entries = data["entries"]
|
||||||
|
|
||||||
show_images = st.sidebar.checkbox("Show Images", False)
|
show_images = st.sidebar.checkbox("Show Images", True)
|
||||||
|
num_seeds = st.sidebar.number_input(
|
||||||
|
"Num Seeds",
|
||||||
|
value=1,
|
||||||
|
min_value=1,
|
||||||
|
max_value=10,
|
||||||
|
help="When > 1, increments the seed and runs multiple for each entry",
|
||||||
|
)
|
||||||
|
|
||||||
# Optionally specify an output directory
|
# Optionally specify an output directory
|
||||||
output_dir = st.sidebar.text_input("Output Directory", "")
|
output_dir = st.sidebar.text_input("Output Directory", "")
|
||||||
|
@ -83,26 +99,51 @@ def render() -> None:
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
output_path.mkdir(parents=True, exist_ok=True)
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
for i, entry in enumerate(entries):
|
# Write title cards for each param set
|
||||||
st.write(f"#### Entry {i + 1} / {len(entries)}")
|
title_cols = st.columns(len(param_sets))
|
||||||
|
for i, params in enumerate(param_sets):
|
||||||
|
col = title_cols[i]
|
||||||
|
|
||||||
|
if "name" not in params:
|
||||||
|
params["name"] = f"params[{i}]"
|
||||||
|
|
||||||
|
col.write(f"## Param Set {i}")
|
||||||
|
col.json(params)
|
||||||
|
|
||||||
|
for entry_i, entry in enumerate(entries):
|
||||||
|
st.write("---")
|
||||||
|
print(entry)
|
||||||
|
prompt = entry["prompt"]
|
||||||
negative_prompt = entry.get("negative_prompt", None)
|
negative_prompt = entry.get("negative_prompt", None)
|
||||||
|
|
||||||
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
|
base_seed = entry.get("seed", 42)
|
||||||
|
|
||||||
|
text = f"##### ({base_seed}) {prompt}"
|
||||||
|
if negative_prompt:
|
||||||
|
text += f" \n**Negative**: {negative_prompt}"
|
||||||
|
st.write(text)
|
||||||
|
|
||||||
|
for seed in range(base_seed, base_seed + num_seeds):
|
||||||
|
cols = st.columns(len(param_sets))
|
||||||
|
for i, params in enumerate(param_sets):
|
||||||
|
col = cols[i]
|
||||||
|
col.write(params["name"])
|
||||||
|
|
||||||
image = streamlit_util.run_txt2img(
|
image = streamlit_util.run_txt2img(
|
||||||
prompt=entry["prompt"],
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=params.get("seed", 42),
|
seed=seed,
|
||||||
num_inference_steps=params.get("num_inference_steps", 50),
|
num_inference_steps=params.get("num_inference_steps", 50),
|
||||||
guidance=params.get("guidance", 7.0),
|
guidance=params.get("guidance", 7.0),
|
||||||
width=params.get("width", 512),
|
width=params.get("width", 512),
|
||||||
|
checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT),
|
||||||
|
scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]),
|
||||||
height=512,
|
height=512,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if show_images:
|
if show_images:
|
||||||
st.image(image)
|
col.image(image)
|
||||||
|
|
||||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||||
p_spectrogram = SpectrogramParams(
|
p_spectrogram = SpectrogramParams(
|
||||||
|
@ -117,18 +158,21 @@ def render() -> None:
|
||||||
device=device,
|
device=device,
|
||||||
output_format=output_format,
|
output_format=output_format,
|
||||||
)
|
)
|
||||||
st.audio(audio_bytes)
|
col.audio(audio_bytes)
|
||||||
|
|
||||||
if output_path:
|
if output_path:
|
||||||
prompt_slug = entry["prompt"].replace(" ", "_")
|
prompt_slug = entry["prompt"].replace(" ", "_")
|
||||||
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
|
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
|
||||||
|
|
||||||
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
|
image_path = (
|
||||||
|
output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
|
||||||
|
)
|
||||||
image.save(image_path, format="JPEG")
|
image.save(image_path, format="JPEG")
|
||||||
entry["image_path"] = str(image_path)
|
entry["image_path"] = str(image_path)
|
||||||
|
|
||||||
audio_path = (
|
audio_path = (
|
||||||
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
|
output_path
|
||||||
|
/ f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
|
||||||
)
|
)
|
||||||
audio_path.write_bytes(audio_bytes.getbuffer())
|
audio_path.write_bytes(audio_bytes.getbuffer())
|
||||||
entry["audio_path"] = str(audio_path)
|
entry["audio_path"] = str(audio_path)
|
||||||
|
|
|
@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data(persist=True)
|
||||||
def run_txt2img(
|
def run_txt2img(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
|
@ -281,12 +281,11 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
|
||||||
"""
|
"""
|
||||||
Provide a custom model checkpoint.
|
Provide a custom model checkpoint.
|
||||||
"""
|
"""
|
||||||
custom_checkpoint = container.text_input(
|
return container.text_input(
|
||||||
"Custom Checkpoint",
|
"Custom Checkpoint",
|
||||||
value="",
|
value=DEFAULT_CHECKPOINT,
|
||||||
help="Provide a custom model checkpoint",
|
help="Provide a custom model checkpoint",
|
||||||
)
|
)
|
||||||
return custom_checkpoint or DEFAULT_CHECKPOINT
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
|
|
Loading…
Reference in New Issue