Streamlit fixes and batch text to audio with multiple models

Topic: batch_text_with_multiple_models
This commit is contained in:
Hayk Martiros 2023-04-11 04:26:30 +00:00
parent 8d44cbbc64
commit a861b48232
5 changed files with 118 additions and 53 deletions

View File

@ -43,7 +43,9 @@ def render() -> None:
device = streamlit_util.select_device(st.sidebar) device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar) extension = streamlit_util.select_audio_extension(st.sidebar)
checkpoint = streamlit_util.select_checkpoint(st.sidebar)
use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False) use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
with st.sidebar: with st.sidebar:
@ -183,7 +185,19 @@ def render() -> None:
st.write(f"## Counter: {counter.value}") st.write(f"## Counter: {counter.value}")
params = SpectrogramParams() if use_20k:
params = SpectrogramParams(
min_frequency=10,
max_frequency=20000,
sample_rate=44100,
stereo=True,
)
else:
params = SpectrogramParams(
min_frequency=0,
max_frequency=10000,
stereo=False,
)
if interpolate: if interpolate:
# TODO(hayk): Make not linspace # TODO(hayk): Make not linspace
@ -237,6 +251,7 @@ def render() -> None:
inputs=inputs, inputs=inputs,
init_image=init_image_resized, init_image=init_image_resized,
device=device, device=device,
checkpoint=checkpoint,
) )
elif use_magic_mix: elif use_magic_mix:
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix" assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
@ -251,6 +266,7 @@ def render() -> None:
mix_factor=magic_mix_mix_factor, mix_factor=magic_mix_mix_factor,
device=device, device=device,
scheduler=scheduler, scheduler=scheduler,
checkpoint=checkpoint,
) )
else: else:
image = streamlit_util.run_img2img( image = streamlit_util.run_img2img(
@ -264,6 +280,7 @@ def render() -> None:
progress_callback=progress_callback, progress_callback=progress_callback,
device=device, device=device,
scheduler=scheduler, scheduler=scheduler,
checkpoint=checkpoint,
) )
# Resize back to original size # Resize back to original size

View File

@ -241,13 +241,18 @@ def get_prompt_inputs(
@st.cache_data @st.cache_data
def run_interpolation( def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3" inputs: InferenceInput,
init_image: Image.Image,
checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
device: str = "cuda",
extension: str = "mp3",
) -> T.Tuple[Image.Image, io.BytesIO]: ) -> T.Tuple[Image.Image, io.BytesIO]:
""" """
Cached function for riffusion interpolation. Cached function for riffusion interpolation.
""" """
pipeline = streamlit_util.load_riffusion_checkpoint( pipeline = streamlit_util.load_riffusion_checkpoint(
device=device, device=device,
checkpoint=checkpoint,
# No trace so we can have variable width # No trace so we can have variable width
no_traced_unet=True, no_traced_unet=True,
) )

View File

@ -55,7 +55,7 @@ def render() -> None:
st.form_submit_button("Riff", type="primary") st.form_submit_button("Riff", type="primary")
with st.sidebar: with st.sidebar:
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=25)) num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30))
width = T.cast(int, st.number_input("Width", value=512)) width = T.cast(int, st.number_input("Width", value=512))
guidance = st.number_input( guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt" "Guidance", value=7.0, help="How much the model listens to the text prompt"

View File

@ -11,21 +11,25 @@ from riffusion.streamlit import util as streamlit_util
EXAMPLE_INPUT = """ EXAMPLE_INPUT = """
{ {
"params": { "params": {
"seed": 42, "checkpoint": "riffusion/riffusion-model-v1",
"scheduler": "DPMSolverMultistepScheduler",
"num_inference_steps": 50, "num_inference_steps": 50,
"guidance": 7.0, "guidance": 7.0,
"width": 512 "width": 512,
}, },
"entries": [ "entries": [
{ {
"prompt": "Church bells" "prompt": "Church bells",
"seed": 42
}, },
{ {
"prompt": "electronic beats", "prompt": "electronic beats",
"negative_prompt": "drums" "negative_prompt": "drums",
"seed": 100
}, },
{ {
"prompt": "classical violin concerto" "prompt": "classical violin concerto",
"seed": 4
} }
] ]
} }
@ -71,10 +75,22 @@ def render() -> None:
with st.expander("Input Data", expanded=False): with st.expander("Input Data", expanded=False):
st.json(data) st.json(data)
params = data["params"] # Params can either be a list or a single entry
if isinstance(data["params"], list):
param_sets = data["params"]
else:
param_sets = [data["params"]]
entries = data["entries"] entries = data["entries"]
show_images = st.sidebar.checkbox("Show Images", False) show_images = st.sidebar.checkbox("Show Images", True)
num_seeds = st.sidebar.number_input(
"Num Seeds",
value=1,
min_value=1,
max_value=10,
help="When > 1, increments the seed and runs multiple for each entry",
)
# Optionally specify an output directory # Optionally specify an output directory
output_dir = st.sidebar.text_input("Output Directory", "") output_dir = st.sidebar.text_input("Output Directory", "")
@ -83,26 +99,51 @@ def render() -> None:
output_path = Path(output_dir) output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True) output_path.mkdir(parents=True, exist_ok=True)
for i, entry in enumerate(entries): # Write title cards for each param set
st.write(f"#### Entry {i + 1} / {len(entries)}") title_cols = st.columns(len(param_sets))
for i, params in enumerate(param_sets):
col = title_cols[i]
if "name" not in params:
params["name"] = f"params[{i}]"
col.write(f"## Param Set {i}")
col.json(params)
for entry_i, entry in enumerate(entries):
st.write("---")
print(entry)
prompt = entry["prompt"]
negative_prompt = entry.get("negative_prompt", None) negative_prompt = entry.get("negative_prompt", None)
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}") base_seed = entry.get("seed", 42)
text = f"##### ({base_seed}) {prompt}"
if negative_prompt:
text += f" \n**Negative**: {negative_prompt}"
st.write(text)
for seed in range(base_seed, base_seed + num_seeds):
cols = st.columns(len(param_sets))
for i, params in enumerate(param_sets):
col = cols[i]
col.write(params["name"])
image = streamlit_util.run_txt2img( image = streamlit_util.run_txt2img(
prompt=entry["prompt"], prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
seed=params.get("seed", 42), seed=seed,
num_inference_steps=params.get("num_inference_steps", 50), num_inference_steps=params.get("num_inference_steps", 50),
guidance=params.get("guidance", 7.0), guidance=params.get("guidance", 7.0),
width=params.get("width", 512), width=params.get("width", 512),
checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT),
scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]),
height=512, height=512,
device=device, device=device,
) )
if show_images: if show_images:
st.image(image) col.image(image)
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
p_spectrogram = SpectrogramParams( p_spectrogram = SpectrogramParams(
@ -117,18 +158,21 @@ def render() -> None:
device=device, device=device,
output_format=output_format, output_format=output_format,
) )
st.audio(audio_bytes) col.audio(audio_bytes)
if output_path: if output_path:
prompt_slug = entry["prompt"].replace(" ", "_") prompt_slug = entry["prompt"].replace(" ", "_")
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_") negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg" image_path = (
output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
)
image.save(image_path, format="JPEG") image.save(image_path, format="JPEG")
entry["image_path"] = str(image_path) entry["image_path"] = str(image_path)
audio_path = ( audio_path = (
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}" output_path
/ f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
) )
audio_path.write_bytes(audio_bytes.getbuffer()) audio_path.write_bytes(audio_bytes.getbuffer())
entry["audio_path"] = str(audio_path) entry["audio_path"] = str(audio_path)

View File

@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
return pipeline return pipeline
@st.cache_data @st.cache_data(persist=True)
def run_txt2img( def run_txt2img(
prompt: str, prompt: str,
num_inference_steps: int, num_inference_steps: int,
@ -281,12 +281,11 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
""" """
Provide a custom model checkpoint. Provide a custom model checkpoint.
""" """
custom_checkpoint = container.text_input( return container.text_input(
"Custom Checkpoint", "Custom Checkpoint",
value="", value=DEFAULT_CHECKPOINT,
help="Provide a custom model checkpoint", help="Provide a custom model checkpoint",
) )
return custom_checkpoint or DEFAULT_CHECKPOINT
@st.cache_data @st.cache_data