Streamlit fixes and batch text to audio with multiple models

Topic: batch_text_with_multiple_models
This commit is contained in:
Hayk Martiros 2023-04-11 04:26:30 +00:00
parent 8d44cbbc64
commit a861b48232
5 changed files with 118 additions and 53 deletions

View File

@ -43,7 +43,9 @@ def render() -> None:
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
checkpoint = streamlit_util.select_checkpoint(st.sidebar)
use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
with st.sidebar:
@ -183,7 +185,19 @@ def render() -> None:
st.write(f"## Counter: {counter.value}")
params = SpectrogramParams()
if use_20k:
params = SpectrogramParams(
min_frequency=10,
max_frequency=20000,
sample_rate=44100,
stereo=True,
)
else:
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
stereo=False,
)
if interpolate:
# TODO(hayk): Make not linspace
@ -237,6 +251,7 @@ def render() -> None:
inputs=inputs,
init_image=init_image_resized,
device=device,
checkpoint=checkpoint,
)
elif use_magic_mix:
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
@ -251,6 +266,7 @@ def render() -> None:
mix_factor=magic_mix_mix_factor,
device=device,
scheduler=scheduler,
checkpoint=checkpoint,
)
else:
image = streamlit_util.run_img2img(
@ -264,6 +280,7 @@ def render() -> None:
progress_callback=progress_callback,
device=device,
scheduler=scheduler,
checkpoint=checkpoint,
)
# Resize back to original size

View File

@ -241,13 +241,18 @@ def get_prompt_inputs(
@st.cache_data
def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
inputs: InferenceInput,
init_image: Image.Image,
checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
device: str = "cuda",
extension: str = "mp3",
) -> T.Tuple[Image.Image, io.BytesIO]:
"""
Cached function for riffusion interpolation.
"""
pipeline = streamlit_util.load_riffusion_checkpoint(
device=device,
checkpoint=checkpoint,
# No trace so we can have variable width
no_traced_unet=True,
)

View File

@ -55,7 +55,7 @@ def render() -> None:
st.form_submit_button("Riff", type="primary")
with st.sidebar:
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=25))
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30))
width = T.cast(int, st.number_input("Width", value=512))
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"

View File

@ -11,21 +11,25 @@ from riffusion.streamlit import util as streamlit_util
EXAMPLE_INPUT = """
{
"params": {
"seed": 42,
"checkpoint": "riffusion/riffusion-model-v1",
"scheduler": "DPMSolverMultistepScheduler",
"num_inference_steps": 50,
"guidance": 7.0,
"width": 512
"width": 512,
},
"entries": [
{
"prompt": "Church bells"
"prompt": "Church bells",
"seed": 42
},
{
"prompt": "electronic beats",
"negative_prompt": "drums"
"negative_prompt": "drums",
"seed": 100
},
{
"prompt": "classical violin concerto"
"prompt": "classical violin concerto",
"seed": 4
}
]
}
@ -71,10 +75,22 @@ def render() -> None:
with st.expander("Input Data", expanded=False):
st.json(data)
params = data["params"]
# Params can either be a list or a single entry
if isinstance(data["params"], list):
param_sets = data["params"]
else:
param_sets = [data["params"]]
entries = data["entries"]
show_images = st.sidebar.checkbox("Show Images", False)
show_images = st.sidebar.checkbox("Show Images", True)
num_seeds = st.sidebar.number_input(
"Num Seeds",
value=1,
min_value=1,
max_value=10,
help="When > 1, increments the seed and runs multiple for each entry",
)
# Optionally specify an output directory
output_dir = st.sidebar.text_input("Output Directory", "")
@ -83,55 +99,83 @@ def render() -> None:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
for i, entry in enumerate(entries):
st.write(f"#### Entry {i + 1} / {len(entries)}")
# Write title cards for each param set
title_cols = st.columns(len(param_sets))
for i, params in enumerate(param_sets):
col = title_cols[i]
if "name" not in params:
params["name"] = f"params[{i}]"
col.write(f"## Param Set {i}")
col.json(params)
for entry_i, entry in enumerate(entries):
st.write("---")
print(entry)
prompt = entry["prompt"]
negative_prompt = entry.get("negative_prompt", None)
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
base_seed = entry.get("seed", 42)
image = streamlit_util.run_txt2img(
prompt=entry["prompt"],
negative_prompt=negative_prompt,
seed=params.get("seed", 42),
num_inference_steps=params.get("num_inference_steps", 50),
guidance=params.get("guidance", 7.0),
width=params.get("width", 512),
height=512,
device=device,
)
text = f"##### ({base_seed}) {prompt}"
if negative_prompt:
text += f" \n**Negative**: {negative_prompt}"
st.write(text)
if show_images:
st.image(image)
for seed in range(base_seed, base_seed + num_seeds):
cols = st.columns(len(param_sets))
for i, params in enumerate(param_sets):
col = cols[i]
col.write(params["name"])
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
p_spectrogram = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
image = streamlit_util.run_txt2img(
prompt=prompt,
negative_prompt=negative_prompt,
seed=seed,
num_inference_steps=params.get("num_inference_steps", 50),
guidance=params.get("guidance", 7.0),
width=params.get("width", 512),
checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT),
scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]),
height=512,
device=device,
)
output_format = "mp3"
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=p_spectrogram,
device=device,
output_format=output_format,
)
st.audio(audio_bytes)
if show_images:
col.image(image)
if output_path:
prompt_slug = entry["prompt"].replace(" ", "_")
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
p_spectrogram = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
)
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
image.save(image_path, format="JPEG")
entry["image_path"] = str(image_path)
output_format = "mp3"
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
image=image,
params=p_spectrogram,
device=device,
output_format=output_format,
)
col.audio(audio_bytes)
audio_path = (
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
)
audio_path.write_bytes(audio_bytes.getbuffer())
entry["audio_path"] = str(audio_path)
if output_path:
prompt_slug = entry["prompt"].replace(" ", "_")
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
image_path = (
output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
)
image.save(image_path, format="JPEG")
entry["image_path"] = str(image_path)
audio_path = (
output_path
/ f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
)
audio_path.write_bytes(audio_bytes.getbuffer())
entry["audio_path"] = str(audio_path)
if output_path:
output_json_path = output_path / "index.json"

View File

@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
return pipeline
@st.cache_data
@st.cache_data(persist=True)
def run_txt2img(
prompt: str,
num_inference_steps: int,
@ -281,12 +281,11 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
"""
Provide a custom model checkpoint.
"""
custom_checkpoint = container.text_input(
return container.text_input(
"Custom Checkpoint",
value="",
value=DEFAULT_CHECKPOINT,
help="Provide a custom model checkpoint",
)
return custom_checkpoint or DEFAULT_CHECKPOINT
@st.cache_data