Audio download buttons and proper extension handling across the app
* Add buttons that download audio segments with the proper name, and display the name * Add a helper that displays the audio bar and the download button * Create a sidebar widget helper for choosing the output extension * Use this extension widget in all pages to dicate output types * Add a streamlit session state counter object to help with reruns * Improve UI in various places with small fixes Topic: audio_download_extensions_ui
This commit is contained in:
parent
8b07a5a45f
commit
75c67e1ea5
|
@ -1,5 +1,6 @@
|
|||
import io
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
|
@ -19,7 +20,7 @@ def render_audio_to_audio() -> None:
|
|||
st.subheader(":wave: Audio to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Modify existing audio from a text prompt.
|
||||
Modify existing audio from a text prompt or interpolate between two.
|
||||
"""
|
||||
)
|
||||
|
||||
|
@ -35,10 +36,15 @@ def render_audio_to_audio() -> None:
|
|||
modification. The best specific denoising depends on how different the prompt is
|
||||
from the source audio. You can play with the seed to get infinite variations.
|
||||
Currently the same seed is used for all clips along the track.
|
||||
|
||||
If the Interpolation check box is enabled, supports entering two sets of prompt,
|
||||
seed, and denoising value and smoothly blends between them along the selected
|
||||
duration of the audio. This is a great way to create a transition.
|
||||
"""
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
num_inference_steps = T.cast(
|
||||
int,
|
||||
|
@ -55,7 +61,7 @@ def render_audio_to_audio() -> None:
|
|||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload audio",
|
||||
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
||||
type=streamlit_util.AUDIO_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
|
||||
|
@ -89,7 +95,14 @@ def render_audio_to_audio() -> None:
|
|||
overlap_duration_s=overlap_duration_s,
|
||||
)
|
||||
|
||||
interpolate = st.checkbox("Interpolate between two settings", False)
|
||||
interpolate = st.checkbox(
|
||||
"Interpolate between two endpoints",
|
||||
value=False,
|
||||
help="Interpolate between two prompts, seeds, or denoising values along the"
|
||||
"duration of the segment",
|
||||
)
|
||||
|
||||
counter = streamlit_util.StreamlitCounter()
|
||||
|
||||
with st.form("audio to audio form"):
|
||||
if interpolate:
|
||||
|
@ -109,7 +122,7 @@ def render_audio_to_audio() -> None:
|
|||
**get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
|
||||
)
|
||||
|
||||
submit_button = st.form_submit_button("Riff", type="primary")
|
||||
st.form_submit_button("Riff", type="primary", on_click=counter.increment)
|
||||
|
||||
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
|
||||
show_difference = st.sidebar.checkbox("Show Difference", False)
|
||||
|
@ -124,9 +137,11 @@ def render_audio_to_audio() -> None:
|
|||
st.info("Enter a prompt")
|
||||
return
|
||||
|
||||
if not submit_button:
|
||||
if counter.value == 0:
|
||||
return
|
||||
|
||||
st.write(f"## Counter: {counter.value}")
|
||||
|
||||
params = SpectrogramParams()
|
||||
|
||||
if interpolate:
|
||||
|
@ -228,16 +243,17 @@ def render_audio_to_audio() -> None:
|
|||
)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
diff_segment.export(audio_bytes, format="wav")
|
||||
diff_segment.export(audio_bytes, format=extension)
|
||||
st.audio(audio_bytes)
|
||||
|
||||
# Combine clips with a crossfade based on overlap
|
||||
combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
combined_segment.export(audio_bytes, format="mp3")
|
||||
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
|
||||
st.audio(audio_bytes, format="audio/mp3")
|
||||
|
||||
input_name = Path(audio_file.name).stem
|
||||
output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
|
||||
streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)
|
||||
|
||||
|
||||
def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import dataclasses
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
|
@ -29,10 +30,11 @@ def render_image_to_audio() -> None:
|
|||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
image_file = st.file_uploader(
|
||||
"Upload a file",
|
||||
type=["png", "jpg", "jpeg"],
|
||||
type=streamlit_util.IMAGE_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if not image_file:
|
||||
|
@ -55,13 +57,17 @@ def render_image_to_audio() -> None:
|
|||
with st.expander("Spectrogram Parameters", expanded=False):
|
||||
st.json(dataclasses.asdict(params))
|
||||
|
||||
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
|
||||
segment = streamlit_util.audio_segment_from_spectrogram_image(
|
||||
image=image.copy(),
|
||||
params=params,
|
||||
device=device,
|
||||
output_format="mp3",
|
||||
)
|
||||
st.audio(audio_bytes)
|
||||
|
||||
streamlit_util.display_and_download_audio(
|
||||
segment,
|
||||
name=Path(image_file.name).stem,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -42,6 +42,7 @@ def render_interpolation() -> None:
|
|||
# Sidebar params
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
num_interpolation_steps = T.cast(
|
||||
int,
|
||||
|
@ -78,7 +79,7 @@ def render_interpolation() -> None:
|
|||
if init_image_name == "custom":
|
||||
init_image_file = st.sidebar.file_uploader(
|
||||
"Upload a custom seed image",
|
||||
type=["png", "jpg", "jpeg"],
|
||||
type=streamlit_util.IMAGE_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if init_image_file:
|
||||
|
@ -154,6 +155,7 @@ def render_interpolation() -> None:
|
|||
inputs=inputs,
|
||||
init_image=init_image,
|
||||
device=device,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
if show_individual_outputs:
|
||||
|
@ -167,19 +169,30 @@ def render_interpolation() -> None:
|
|||
|
||||
st.write("#### Final Output")
|
||||
|
||||
# TODO(hayk): Concatenate with better blending
|
||||
# TODO(hayk): Concatenate with overlap and better blending like in audio to audio
|
||||
audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
|
||||
concat_segment = audio_segments[0]
|
||||
for segment in audio_segments[1:]:
|
||||
concat_segment = concat_segment.append(segment, crossfade=0)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
concat_segment.export(audio_bytes, format="mp3")
|
||||
concat_segment.export(audio_bytes, format=extension)
|
||||
audio_bytes.seek(0)
|
||||
|
||||
st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
|
||||
st.audio(audio_bytes)
|
||||
|
||||
output_name = (
|
||||
f"{prompt_input_a.prompt.replace(' ', '_')}_"
|
||||
f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
|
||||
)
|
||||
st.download_button(
|
||||
output_name,
|
||||
data=audio_bytes,
|
||||
file_name=output_name,
|
||||
mime=f"audio/{extension}",
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_inputs(
|
||||
key: str,
|
||||
|
@ -222,7 +235,7 @@ def get_prompt_inputs(
|
|||
|
||||
@st.experimental_memo
|
||||
def run_interpolation(
|
||||
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
|
||||
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
|
||||
) -> T.Tuple[Image.Image, io.BytesIO]:
|
||||
"""
|
||||
Cached function for riffusion interpolation.
|
||||
|
@ -250,7 +263,7 @@ def run_interpolation(
|
|||
image=image,
|
||||
params=params,
|
||||
device=device,
|
||||
output_format="mp3",
|
||||
output_format=extension,
|
||||
)
|
||||
|
||||
return image, audio_bytes
|
||||
|
|
|
@ -6,6 +6,8 @@ import numpy as np
|
|||
import pydub
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_sample_clips() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
@ -28,7 +30,7 @@ def render_sample_clips() -> None:
|
|||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload a file",
|
||||
type=["wav", "mp3", "ogg"],
|
||||
type=streamlit_util.AUDIO_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if not audio_file:
|
||||
|
@ -49,22 +51,26 @@ def render_sample_clips() -> None:
|
|||
)
|
||||
)
|
||||
|
||||
seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
|
||||
duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
save_to_disk = st.sidebar.checkbox("Save to Disk", False)
|
||||
export_as_mono = st.sidebar.checkbox("Export as Mono", False)
|
||||
num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
|
||||
extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
|
||||
assert extension is not None
|
||||
|
||||
# Optionally specify an output directory
|
||||
output_dir = st.text_input("Output Directory")
|
||||
if not output_dir:
|
||||
tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
|
||||
st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
|
||||
row = st.columns(4)
|
||||
num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
|
||||
duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
|
||||
seed = T.cast(int, row[2].number_input("Seed", value=42))
|
||||
|
||||
counter = streamlit_util.StreamlitCounter()
|
||||
st.button("Sample Clips", type="primary", on_click=counter.increment)
|
||||
if counter.value == 0:
|
||||
return
|
||||
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
# Optionally pick an output directory
|
||||
if save_to_disk:
|
||||
output_dir = tempfile.mkdtemp(prefix="sample_clips_")
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
st.info(f"Output directory: `{output_dir}`")
|
||||
|
||||
if seed >= 0:
|
||||
np.random.seed(seed)
|
||||
|
@ -78,16 +84,22 @@ def render_sample_clips() -> None:
|
|||
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
|
||||
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
|
||||
|
||||
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
|
||||
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"
|
||||
|
||||
st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
|
||||
|
||||
clip_path = output_path / clip_name
|
||||
clip.export(clip_path, format=extension)
|
||||
streamlit_util.display_and_download_audio(
|
||||
clip,
|
||||
name=clip_name,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
st.audio(str(clip_path))
|
||||
if save_to_disk:
|
||||
clip_path = output_path / f"clip_name.{extension}"
|
||||
clip.export(clip_path, format=extension)
|
||||
|
||||
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
|
||||
if save_to_disk:
|
||||
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import io
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import pydub
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.audio_splitter import split_audio
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def render_split_audio() -> None:
|
||||
|
@ -13,7 +15,7 @@ def render_split_audio() -> None:
|
|||
st.subheader(":scissors: Audio Splitter")
|
||||
st.write(
|
||||
"""
|
||||
Split an audio into stems of {vocals, drums, bass, piano, guitar, other}.
|
||||
Split audio into individual instrument stems.
|
||||
"""
|
||||
)
|
||||
|
||||
|
@ -32,13 +34,21 @@ def render_split_audio() -> None:
|
|||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
|
||||
extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
|
||||
extension = st.sidebar.selectbox(
|
||||
"Output format",
|
||||
options=extension_options,
|
||||
index=extension_options.index("mp3"),
|
||||
)
|
||||
assert extension is not None
|
||||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload audio",
|
||||
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
||||
type=extension_options,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
|
||||
stem_options = ["vocals", "drums", "bass", "guitar", "piano", "other"]
|
||||
stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
|
||||
recombine = st.sidebar.multiselect(
|
||||
"Recombine",
|
||||
options=stem_options,
|
||||
|
@ -50,39 +60,45 @@ def render_split_audio() -> None:
|
|||
st.info("Upload audio to get started")
|
||||
return
|
||||
|
||||
st.write("#### original")
|
||||
# TODO(hayk): This might be bogus, it can be other formats..
|
||||
st.audio(audio_file, format="audio/mp3")
|
||||
st.write("#### Original")
|
||||
st.audio(audio_file)
|
||||
|
||||
if not st.button("Split", type="primary"):
|
||||
counter = streamlit_util.StreamlitCounter()
|
||||
st.button("Split", type="primary", on_click=counter.increment)
|
||||
if counter.value == 0:
|
||||
return
|
||||
|
||||
segment = streamlit_util.load_audio_file(audio_file)
|
||||
|
||||
# Split
|
||||
stems = split_audio(segment, device=device)
|
||||
stems = split_audio_cached(segment, device=device)
|
||||
|
||||
input_name = Path(audio_file.name).stem
|
||||
|
||||
# Display each
|
||||
for name, stem in stems.items():
|
||||
st.write(f"#### {name}")
|
||||
audio_bytes = io.BytesIO()
|
||||
stem.export(audio_bytes, format="mp3")
|
||||
st.audio(audio_bytes, format="audio/mp3")
|
||||
for name in stem_options:
|
||||
stem = stems[name.lower()]
|
||||
st.write(f"#### Stem: {name}")
|
||||
|
||||
output_name = f"{input_name}_{name.lower()}"
|
||||
streamlit_util.display_and_download_audio(stem, output_name, extension=extension)
|
||||
|
||||
if recombine:
|
||||
recombined: pydub.AudioSegment = None
|
||||
for name, stem in stems.items():
|
||||
if name in recombine:
|
||||
if recombined is None:
|
||||
recombined = stem
|
||||
else:
|
||||
recombined = recombined.overlay(stem)
|
||||
recombine_lower = [r.lower() for r in recombine]
|
||||
segments = [s for name, s in stems.items() if name in recombine_lower]
|
||||
recombined = audio_util.overlay_segments(segments)
|
||||
|
||||
# Display
|
||||
st.write("#### recombined")
|
||||
audio_bytes = io.BytesIO()
|
||||
recombined.export(audio_bytes, format="mp3")
|
||||
st.audio(audio_bytes, format="audio/mp3")
|
||||
st.write(f"#### Recombined: {', '.join(recombine)}")
|
||||
output_name = f"{input_name}_{'_'.join(recombine_lower)}"
|
||||
streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
|
||||
|
||||
|
||||
@st.cache
|
||||
def split_audio_cached(
|
||||
segment: pydub.AudioSegment, device: str = "cuda"
|
||||
) -> T.Dict[str, pydub.AudioSegment]:
|
||||
return split_audio(segment, device=device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -27,6 +27,7 @@ def render_text_to_audio() -> None:
|
|||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
with st.form("Inputs"):
|
||||
prompt = st.text_input("Prompt")
|
||||
|
@ -87,13 +88,15 @@ def render_text_to_audio() -> None:
|
|||
)
|
||||
st.image(image)
|
||||
|
||||
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
|
||||
segment = streamlit_util.audio_segment_from_spectrogram_image(
|
||||
image=image,
|
||||
params=params,
|
||||
device=device,
|
||||
output_format="mp3",
|
||||
)
|
||||
st.audio(audio_bytes)
|
||||
|
||||
streamlit_util.display_and_download_audio(
|
||||
segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
|
||||
)
|
||||
|
||||
seed += 1
|
||||
|
||||
|
|
|
@ -13,13 +13,13 @@ def render_main():
|
|||
st.write("Generate audio clips from text prompts.")
|
||||
|
||||
create_link(":wave: Audio to Audio", "/audio_to_audio")
|
||||
st.write("Upload audio and modify with text prompt.")
|
||||
st.write("Upload audio and modify with text prompt (interpolation supported).")
|
||||
|
||||
create_link(":performing_arts: Interpolation", "/interpolation")
|
||||
st.write("Interpolate between prompts in the latent space.")
|
||||
|
||||
create_link(":scissors: Audio Splitter", "/split_audio")
|
||||
st.write("Upload audio and split into vocals, bass, drums, and other.")
|
||||
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
|
||||
|
||||
with right:
|
||||
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
|
||||
|
|
|
@ -17,6 +17,9 @@ from riffusion.spectrogram_params import SpectrogramParams
|
|||
|
||||
# TODO(hayk): Add URL params
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
|
||||
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def load_riffusion_checkpoint(
|
||||
|
@ -177,6 +180,20 @@ def select_device(container: T.Any = st.sidebar) -> str:
|
|||
return device
|
||||
|
||||
|
||||
def select_audio_extension(container: T.Any = st.sidebar) -> str:
|
||||
"""
|
||||
Dropdown to select an audio extension, with an intelligent default.
|
||||
"""
|
||||
default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
|
||||
extension = container.selectbox(
|
||||
"Output format",
|
||||
options=AUDIO_EXTENSIONS,
|
||||
index=AUDIO_EXTENSIONS.index(default),
|
||||
)
|
||||
assert extension is not None
|
||||
return extension
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
|
||||
return pydub.AudioSegment.from_file(audio_file)
|
||||
|
@ -224,3 +241,43 @@ def run_img2img(
|
|||
)
|
||||
|
||||
return result.images[0]
|
||||
|
||||
|
||||
class StreamlitCounter:
|
||||
"""
|
||||
Simple counter stored in streamlit session state.
|
||||
"""
|
||||
|
||||
def __init__(self, key="_counter"):
|
||||
self.key = key
|
||||
if not st.session_state.get(self.key):
|
||||
st.session_state[self.key] = 0
|
||||
|
||||
def increment(self):
|
||||
st.session_state[self.key] += 1
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return st.session_state[self.key]
|
||||
|
||||
|
||||
def display_and_download_audio(
|
||||
segment: pydub.AudioSegment,
|
||||
name: str,
|
||||
extension: str = "mp3",
|
||||
) -> None:
|
||||
"""
|
||||
Display the given audio segment and provide a button to download it with
|
||||
a proper file name, since st.audio doesn't support that.
|
||||
"""
|
||||
mime_type = f"audio/{extension}"
|
||||
audio_bytes = io.BytesIO()
|
||||
segment.export(audio_bytes, format=extension)
|
||||
st.audio(audio_bytes, format=mime_type)
|
||||
|
||||
st.download_button(
|
||||
f"{name}.{extension}",
|
||||
data=audio_bytes,
|
||||
file_name=f"{name}.{extension}",
|
||||
mime=mime_type,
|
||||
)
|
||||
|
|
|
@ -83,3 +83,17 @@ def stitch_segments(
|
|||
for segment in segments[1:]:
|
||||
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
|
||||
return combined_segment
|
||||
|
||||
|
||||
def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
|
||||
"""
|
||||
Overlay a sequence of audio segments on top of each other.
|
||||
"""
|
||||
assert len(segments) > 0
|
||||
output: pydub.AudioSegment = None
|
||||
for segment in segments:
|
||||
if output is None:
|
||||
output = segment
|
||||
else:
|
||||
output = output.overlay(segment)
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue