Audio download buttons and proper extension handling across the app

* Add buttons that download audio segments with the proper name, and display the name
 * Add a helper that displays the audio bar and the download button
 * Create a sidebar widget helper for choosing the output extension
 * Use this extension widget in all pages to dicate output types
 * Add a streamlit session state counter object to help with reruns
 * Improve UI in various places with small fixes

Topic: audio_download_extensions_ui
This commit is contained in:
Hayk Martiros 2023-01-14 21:59:36 +00:00
parent 8b07a5a45f
commit 75c67e1ea5
9 changed files with 203 additions and 66 deletions

View File

@ -1,5 +1,6 @@
import io
import typing as T
from pathlib import Path
import numpy as np
import pydub
@ -19,7 +20,7 @@ def render_audio_to_audio() -> None:
st.subheader(":wave: Audio to Audio")
st.write(
"""
Modify existing audio from a text prompt.
Modify existing audio from a text prompt or interpolate between two.
"""
)
@ -35,10 +36,15 @@ def render_audio_to_audio() -> None:
modification. The best specific denoising depends on how different the prompt is
from the source audio. You can play with the seed to get infinite variations.
Currently the same seed is used for all clips along the track.
If the Interpolation check box is enabled, supports entering two sets of prompt,
seed, and denoising value and smoothly blends between them along the selected
duration of the audio. This is a great way to create a transition.
"""
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
num_inference_steps = T.cast(
int,
@ -55,7 +61,7 @@ def render_audio_to_audio() -> None:
audio_file = st.file_uploader(
"Upload audio",
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
type=streamlit_util.AUDIO_EXTENSIONS,
label_visibility="collapsed",
)
@ -89,7 +95,14 @@ def render_audio_to_audio() -> None:
overlap_duration_s=overlap_duration_s,
)
interpolate = st.checkbox("Interpolate between two settings", False)
interpolate = st.checkbox(
"Interpolate between two endpoints",
value=False,
help="Interpolate between two prompts, seeds, or denoising values along the"
"duration of the segment",
)
counter = streamlit_util.StreamlitCounter()
with st.form("audio to audio form"):
if interpolate:
@ -109,7 +122,7 @@ def render_audio_to_audio() -> None:
**get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
)
submit_button = st.form_submit_button("Riff", type="primary")
st.form_submit_button("Riff", type="primary", on_click=counter.increment)
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
show_difference = st.sidebar.checkbox("Show Difference", False)
@ -124,9 +137,11 @@ def render_audio_to_audio() -> None:
st.info("Enter a prompt")
return
if not submit_button:
if counter.value == 0:
return
st.write(f"## Counter: {counter.value}")
params = SpectrogramParams()
if interpolate:
@ -228,16 +243,17 @@ def render_audio_to_audio() -> None:
)
audio_bytes = io.BytesIO()
diff_segment.export(audio_bytes, format="wav")
diff_segment.export(audio_bytes, format=extension)
st.audio(audio_bytes)
# Combine clips with a crossfade based on overlap
combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
audio_bytes = io.BytesIO()
combined_segment.export(audio_bytes, format="mp3")
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
st.audio(audio_bytes, format="audio/mp3")
input_name = Path(audio_file.name).stem
output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)
def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:

View File

@ -1,4 +1,5 @@
import dataclasses
from pathlib import Path
import streamlit as st
from PIL import Image
@ -29,10 +30,11 @@ def render_image_to_audio() -> None:
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
image_file = st.file_uploader(
"Upload a file",
type=["png", "jpg", "jpeg"],
type=streamlit_util.IMAGE_EXTENSIONS,
label_visibility="collapsed",
)
if not image_file:
@ -55,13 +57,17 @@ def render_image_to_audio() -> None:
with st.expander("Spectrogram Parameters", expanded=False):
st.json(dataclasses.asdict(params))
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
segment = streamlit_util.audio_segment_from_spectrogram_image(
image=image.copy(),
params=params,
device=device,
output_format="mp3",
)
st.audio(audio_bytes)
streamlit_util.display_and_download_audio(
segment,
name=Path(image_file.name).stem,
extension=extension,
)
if __name__ == "__main__":

View File

@ -42,6 +42,7 @@ def render_interpolation() -> None:
# Sidebar params
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
num_interpolation_steps = T.cast(
int,
@ -78,7 +79,7 @@ def render_interpolation() -> None:
if init_image_name == "custom":
init_image_file = st.sidebar.file_uploader(
"Upload a custom seed image",
type=["png", "jpg", "jpeg"],
type=streamlit_util.IMAGE_EXTENSIONS,
label_visibility="collapsed",
)
if init_image_file:
@ -154,6 +155,7 @@ def render_interpolation() -> None:
inputs=inputs,
init_image=init_image,
device=device,
extension=extension,
)
if show_individual_outputs:
@ -167,19 +169,30 @@ def render_interpolation() -> None:
st.write("#### Final Output")
# TODO(hayk): Concatenate with better blending
# TODO(hayk): Concatenate with overlap and better blending like in audio to audio
audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
concat_segment = audio_segments[0]
for segment in audio_segments[1:]:
concat_segment = concat_segment.append(segment, crossfade=0)
audio_bytes = io.BytesIO()
concat_segment.export(audio_bytes, format="mp3")
concat_segment.export(audio_bytes, format=extension)
audio_bytes.seek(0)
st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
st.audio(audio_bytes)
output_name = (
f"{prompt_input_a.prompt.replace(' ', '_')}_"
f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
)
st.download_button(
output_name,
data=audio_bytes,
file_name=output_name,
mime=f"audio/{extension}",
)
def get_prompt_inputs(
key: str,
@ -222,7 +235,7 @@ def get_prompt_inputs(
@st.experimental_memo
def run_interpolation(
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
) -> T.Tuple[Image.Image, io.BytesIO]:
"""
Cached function for riffusion interpolation.
@ -250,7 +263,7 @@ def run_interpolation(
image=image,
params=params,
device=device,
output_format="mp3",
output_format=extension,
)
return image, audio_bytes

View File

@ -6,6 +6,8 @@ import numpy as np
import pydub
import streamlit as st
from riffusion.streamlit import util as streamlit_util
def render_sample_clips() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
@ -28,7 +30,7 @@ def render_sample_clips() -> None:
audio_file = st.file_uploader(
"Upload a file",
type=["wav", "mp3", "ogg"],
type=streamlit_util.AUDIO_EXTENSIONS,
label_visibility="collapsed",
)
if not audio_file:
@ -49,22 +51,26 @@ def render_sample_clips() -> None:
)
)
seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
extension = streamlit_util.select_audio_extension(st.sidebar)
save_to_disk = st.sidebar.checkbox("Save to Disk", False)
export_as_mono = st.sidebar.checkbox("Export as Mono", False)
num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
assert extension is not None
# Optionally specify an output directory
output_dir = st.text_input("Output Directory")
if not output_dir:
tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
row = st.columns(4)
num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
seed = T.cast(int, row[2].number_input("Seed", value=42))
counter = streamlit_util.StreamlitCounter()
st.button("Sample Clips", type="primary", on_click=counter.increment)
if counter.value == 0:
return
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Optionally pick an output directory
if save_to_disk:
output_dir = tempfile.mkdtemp(prefix="sample_clips_")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
st.info(f"Output directory: `{output_dir}`")
if seed >= 0:
np.random.seed(seed)
@ -78,16 +84,22 @@ def render_sample_clips() -> None:
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"
st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
clip_path = output_path / clip_name
clip.export(clip_path, format=extension)
streamlit_util.display_and_download_audio(
clip,
name=clip_name,
extension=extension,
)
st.audio(str(clip_path))
if save_to_disk:
clip_path = output_path / f"clip_name.{extension}"
clip.export(clip_path, format=extension)
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
if save_to_disk:
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
if __name__ == "__main__":

View File

@ -1,10 +1,12 @@
import io
import typing as T
from pathlib import Path
import pydub
import streamlit as st
from riffusion.audio_splitter import split_audio
from riffusion.streamlit import util as streamlit_util
from riffusion.util import audio_util
def render_split_audio() -> None:
@ -13,7 +15,7 @@ def render_split_audio() -> None:
st.subheader(":scissors: Audio Splitter")
st.write(
"""
Split an audio into stems of {vocals, drums, bass, piano, guitar, other}.
Split audio into individual instrument stems.
"""
)
@ -32,13 +34,21 @@ def render_split_audio() -> None:
device = streamlit_util.select_device(st.sidebar)
extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
extension = st.sidebar.selectbox(
"Output format",
options=extension_options,
index=extension_options.index("mp3"),
)
assert extension is not None
audio_file = st.file_uploader(
"Upload audio",
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
type=extension_options,
label_visibility="collapsed",
)
stem_options = ["vocals", "drums", "bass", "guitar", "piano", "other"]
stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
recombine = st.sidebar.multiselect(
"Recombine",
options=stem_options,
@ -50,39 +60,45 @@ def render_split_audio() -> None:
st.info("Upload audio to get started")
return
st.write("#### original")
# TODO(hayk): This might be bogus, it can be other formats..
st.audio(audio_file, format="audio/mp3")
st.write("#### Original")
st.audio(audio_file)
if not st.button("Split", type="primary"):
counter = streamlit_util.StreamlitCounter()
st.button("Split", type="primary", on_click=counter.increment)
if counter.value == 0:
return
segment = streamlit_util.load_audio_file(audio_file)
# Split
stems = split_audio(segment, device=device)
stems = split_audio_cached(segment, device=device)
input_name = Path(audio_file.name).stem
# Display each
for name, stem in stems.items():
st.write(f"#### {name}")
audio_bytes = io.BytesIO()
stem.export(audio_bytes, format="mp3")
st.audio(audio_bytes, format="audio/mp3")
for name in stem_options:
stem = stems[name.lower()]
st.write(f"#### Stem: {name}")
output_name = f"{input_name}_{name.lower()}"
streamlit_util.display_and_download_audio(stem, output_name, extension=extension)
if recombine:
recombined: pydub.AudioSegment = None
for name, stem in stems.items():
if name in recombine:
if recombined is None:
recombined = stem
else:
recombined = recombined.overlay(stem)
recombine_lower = [r.lower() for r in recombine]
segments = [s for name, s in stems.items() if name in recombine_lower]
recombined = audio_util.overlay_segments(segments)
# Display
st.write("#### recombined")
audio_bytes = io.BytesIO()
recombined.export(audio_bytes, format="mp3")
st.audio(audio_bytes, format="audio/mp3")
st.write(f"#### Recombined: {', '.join(recombine)}")
output_name = f"{input_name}_{'_'.join(recombine_lower)}"
streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
@st.cache
def split_audio_cached(
segment: pydub.AudioSegment, device: str = "cuda"
) -> T.Dict[str, pydub.AudioSegment]:
return split_audio(segment, device=device)
if __name__ == "__main__":

View File

@ -27,6 +27,7 @@ def render_text_to_audio() -> None:
)
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
with st.form("Inputs"):
prompt = st.text_input("Prompt")
@ -87,13 +88,15 @@ def render_text_to_audio() -> None:
)
st.image(image)
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
segment = streamlit_util.audio_segment_from_spectrogram_image(
image=image,
params=params,
device=device,
output_format="mp3",
)
st.audio(audio_bytes)
streamlit_util.display_and_download_audio(
segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
)
seed += 1

View File

@ -13,13 +13,13 @@ def render_main():
st.write("Generate audio clips from text prompts.")
create_link(":wave: Audio to Audio", "/audio_to_audio")
st.write("Upload audio and modify with text prompt.")
st.write("Upload audio and modify with text prompt (interpolation supported).")
create_link(":performing_arts: Interpolation", "/interpolation")
st.write("Interpolate between prompts in the latent space.")
create_link(":scissors: Audio Splitter", "/split_audio")
st.write("Upload audio and split into vocals, bass, drums, and other.")
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
with right:
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")

View File

@ -17,6 +17,9 @@ from riffusion.spectrogram_params import SpectrogramParams
# TODO(hayk): Add URL params
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
@st.experimental_singleton
def load_riffusion_checkpoint(
@ -177,6 +180,20 @@ def select_device(container: T.Any = st.sidebar) -> str:
return device
def select_audio_extension(container: T.Any = st.sidebar) -> str:
"""
Dropdown to select an audio extension, with an intelligent default.
"""
default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
extension = container.selectbox(
"Output format",
options=AUDIO_EXTENSIONS,
index=AUDIO_EXTENSIONS.index(default),
)
assert extension is not None
return extension
@st.experimental_memo
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
return pydub.AudioSegment.from_file(audio_file)
@ -224,3 +241,43 @@ def run_img2img(
)
return result.images[0]
class StreamlitCounter:
"""
Simple counter stored in streamlit session state.
"""
def __init__(self, key="_counter"):
self.key = key
if not st.session_state.get(self.key):
st.session_state[self.key] = 0
def increment(self):
st.session_state[self.key] += 1
@property
def value(self):
return st.session_state[self.key]
def display_and_download_audio(
segment: pydub.AudioSegment,
name: str,
extension: str = "mp3",
) -> None:
"""
Display the given audio segment and provide a button to download it with
a proper file name, since st.audio doesn't support that.
"""
mime_type = f"audio/{extension}"
audio_bytes = io.BytesIO()
segment.export(audio_bytes, format=extension)
st.audio(audio_bytes, format=mime_type)
st.download_button(
f"{name}.{extension}",
data=audio_bytes,
file_name=f"{name}.{extension}",
mime=mime_type,
)

View File

@ -83,3 +83,17 @@ def stitch_segments(
for segment in segments[1:]:
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
return combined_segment
def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
"""
Overlay a sequence of audio segments on top of each other.
"""
assert len(segments) > 0
output: pydub.AudioSegment = None
for segment in segments:
if output is None:
output = segment
else:
output = output.overlay(segment)
return output