Streamlit fixes and batch text to audio with multiple models
Topic: batch_text_with_multiple_models
This commit is contained in:
parent
8d44cbbc64
commit
a861b48232
|
@ -43,7 +43,9 @@ def render() -> None:
|
|||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
checkpoint = streamlit_util.select_checkpoint(st.sidebar)
|
||||
|
||||
use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
|
||||
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
|
||||
|
||||
with st.sidebar:
|
||||
|
@ -183,7 +185,19 @@ def render() -> None:
|
|||
|
||||
st.write(f"## Counter: {counter.value}")
|
||||
|
||||
params = SpectrogramParams()
|
||||
if use_20k:
|
||||
params = SpectrogramParams(
|
||||
min_frequency=10,
|
||||
max_frequency=20000,
|
||||
sample_rate=44100,
|
||||
stereo=True,
|
||||
)
|
||||
else:
|
||||
params = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
stereo=False,
|
||||
)
|
||||
|
||||
if interpolate:
|
||||
# TODO(hayk): Make not linspace
|
||||
|
@ -237,6 +251,7 @@ def render() -> None:
|
|||
inputs=inputs,
|
||||
init_image=init_image_resized,
|
||||
device=device,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
elif use_magic_mix:
|
||||
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
|
||||
|
@ -251,6 +266,7 @@ def render() -> None:
|
|||
mix_factor=magic_mix_mix_factor,
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
else:
|
||||
image = streamlit_util.run_img2img(
|
||||
|
@ -264,6 +280,7 @@ def render() -> None:
|
|||
progress_callback=progress_callback,
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
# Resize back to original size
|
||||
|
|
|
@ -241,13 +241,18 @@ def get_prompt_inputs(
|
|||
|
||||
@st.cache_data
|
||||
def run_interpolation(
|
||||
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
|
||||
inputs: InferenceInput,
|
||||
init_image: Image.Image,
|
||||
checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
|
||||
device: str = "cuda",
|
||||
extension: str = "mp3",
|
||||
) -> T.Tuple[Image.Image, io.BytesIO]:
|
||||
"""
|
||||
Cached function for riffusion interpolation.
|
||||
"""
|
||||
pipeline = streamlit_util.load_riffusion_checkpoint(
|
||||
device=device,
|
||||
checkpoint=checkpoint,
|
||||
# No trace so we can have variable width
|
||||
no_traced_unet=True,
|
||||
)
|
||||
|
|
|
@ -55,7 +55,7 @@ def render() -> None:
|
|||
st.form_submit_button("Riff", type="primary")
|
||||
|
||||
with st.sidebar:
|
||||
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=25))
|
||||
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30))
|
||||
width = T.cast(int, st.number_input("Width", value=512))
|
||||
guidance = st.number_input(
|
||||
"Guidance", value=7.0, help="How much the model listens to the text prompt"
|
||||
|
|
|
@ -11,21 +11,25 @@ from riffusion.streamlit import util as streamlit_util
|
|||
EXAMPLE_INPUT = """
|
||||
{
|
||||
"params": {
|
||||
"seed": 42,
|
||||
"checkpoint": "riffusion/riffusion-model-v1",
|
||||
"scheduler": "DPMSolverMultistepScheduler",
|
||||
"num_inference_steps": 50,
|
||||
"guidance": 7.0,
|
||||
"width": 512
|
||||
"width": 512,
|
||||
},
|
||||
"entries": [
|
||||
{
|
||||
"prompt": "Church bells"
|
||||
"prompt": "Church bells",
|
||||
"seed": 42
|
||||
},
|
||||
{
|
||||
"prompt": "electronic beats",
|
||||
"negative_prompt": "drums"
|
||||
"negative_prompt": "drums",
|
||||
"seed": 100
|
||||
},
|
||||
{
|
||||
"prompt": "classical violin concerto"
|
||||
"prompt": "classical violin concerto",
|
||||
"seed": 4
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -71,10 +75,22 @@ def render() -> None:
|
|||
with st.expander("Input Data", expanded=False):
|
||||
st.json(data)
|
||||
|
||||
params = data["params"]
|
||||
# Params can either be a list or a single entry
|
||||
if isinstance(data["params"], list):
|
||||
param_sets = data["params"]
|
||||
else:
|
||||
param_sets = [data["params"]]
|
||||
|
||||
entries = data["entries"]
|
||||
|
||||
show_images = st.sidebar.checkbox("Show Images", False)
|
||||
show_images = st.sidebar.checkbox("Show Images", True)
|
||||
num_seeds = st.sidebar.number_input(
|
||||
"Num Seeds",
|
||||
value=1,
|
||||
min_value=1,
|
||||
max_value=10,
|
||||
help="When > 1, increments the seed and runs multiple for each entry",
|
||||
)
|
||||
|
||||
# Optionally specify an output directory
|
||||
output_dir = st.sidebar.text_input("Output Directory", "")
|
||||
|
@ -83,26 +99,51 @@ def render() -> None:
|
|||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, entry in enumerate(entries):
|
||||
st.write(f"#### Entry {i + 1} / {len(entries)}")
|
||||
# Write title cards for each param set
|
||||
title_cols = st.columns(len(param_sets))
|
||||
for i, params in enumerate(param_sets):
|
||||
col = title_cols[i]
|
||||
|
||||
if "name" not in params:
|
||||
params["name"] = f"params[{i}]"
|
||||
|
||||
col.write(f"## Param Set {i}")
|
||||
col.json(params)
|
||||
|
||||
for entry_i, entry in enumerate(entries):
|
||||
st.write("---")
|
||||
print(entry)
|
||||
prompt = entry["prompt"]
|
||||
negative_prompt = entry.get("negative_prompt", None)
|
||||
|
||||
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
|
||||
base_seed = entry.get("seed", 42)
|
||||
|
||||
text = f"##### ({base_seed}) {prompt}"
|
||||
if negative_prompt:
|
||||
text += f" \n**Negative**: {negative_prompt}"
|
||||
st.write(text)
|
||||
|
||||
for seed in range(base_seed, base_seed + num_seeds):
|
||||
cols = st.columns(len(param_sets))
|
||||
for i, params in enumerate(param_sets):
|
||||
col = cols[i]
|
||||
col.write(params["name"])
|
||||
|
||||
image = streamlit_util.run_txt2img(
|
||||
prompt=entry["prompt"],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=params.get("seed", 42),
|
||||
seed=seed,
|
||||
num_inference_steps=params.get("num_inference_steps", 50),
|
||||
guidance=params.get("guidance", 7.0),
|
||||
width=params.get("width", 512),
|
||||
checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT),
|
||||
scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]),
|
||||
height=512,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if show_images:
|
||||
st.image(image)
|
||||
col.image(image)
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
p_spectrogram = SpectrogramParams(
|
||||
|
@ -117,18 +158,21 @@ def render() -> None:
|
|||
device=device,
|
||||
output_format=output_format,
|
||||
)
|
||||
st.audio(audio_bytes)
|
||||
col.audio(audio_bytes)
|
||||
|
||||
if output_path:
|
||||
prompt_slug = entry["prompt"].replace(" ", "_")
|
||||
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
|
||||
|
||||
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
|
||||
image_path = (
|
||||
output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
|
||||
)
|
||||
image.save(image_path, format="JPEG")
|
||||
entry["image_path"] = str(image_path)
|
||||
|
||||
audio_path = (
|
||||
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
|
||||
output_path
|
||||
/ f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
|
||||
)
|
||||
audio_path.write_bytes(audio_bytes.getbuffer())
|
||||
entry["audio_path"] = str(audio_path)
|
||||
|
|
|
@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
|
|||
return pipeline
|
||||
|
||||
|
||||
@st.cache_data
|
||||
@st.cache_data(persist=True)
|
||||
def run_txt2img(
|
||||
prompt: str,
|
||||
num_inference_steps: int,
|
||||
|
@ -281,12 +281,11 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
|
|||
"""
|
||||
Provide a custom model checkpoint.
|
||||
"""
|
||||
custom_checkpoint = container.text_input(
|
||||
return container.text_input(
|
||||
"Custom Checkpoint",
|
||||
value="",
|
||||
value=DEFAULT_CHECKPOINT,
|
||||
help="Provide a custom model checkpoint",
|
||||
)
|
||||
return custom_checkpoint or DEFAULT_CHECKPOINT
|
||||
|
||||
|
||||
@st.cache_data
|
||||
|
|
Loading…
Reference in New Issue