Audio download buttons and proper extension handling across the app
* Add buttons that download audio segments with the proper name, and display the name
* Add a helper that displays the audio bar and the download button
* Create a sidebar widget helper for choosing the output extension
* Use this extension widget in all pages to dictate output types
* Add a streamlit session state counter object to help with reruns
* Improve UI in various places with small fixes

Topic: audio_download_extensions_ui
parent 8b07a5a45f
commit 75c67e1ea5
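The session-state counter this commit introduces addresses a Streamlit quirk: the script re-runs top to bottom on every widget interaction, and a form submit button's return value is only True on the rerun immediately following the click. A counter persisted in st.session_state lets later reruns (say, toggling a sidebar checkbox after riffing) remember that a submission already happened. A minimal sketch of the pattern, mirroring the StreamlitCounter class added in the util diff below:

    import streamlit as st

    if "_counter" not in st.session_state:
        st.session_state["_counter"] = 0

    def increment() -> None:
        st.session_state["_counter"] += 1

    st.button("Riff", on_click=increment)
    if st.session_state["_counter"] == 0:
        st.stop()  # no click yet, on this rerun or any earlier one
    st.write(f"Submitted {st.session_state['_counter']} time(s)")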
@@ -1,5 +1,6 @@
 import io
 import typing as T
+from pathlib import Path

 import numpy as np
 import pydub
@@ -19,7 +20,7 @@ def render_audio_to_audio() -> None:
     st.subheader(":wave: Audio to Audio")
     st.write(
         """
-        Modify existing audio from a text prompt.
+        Modify existing audio from a text prompt or interpolate between two.
         """
     )

@@ -35,10 +36,15 @@ def render_audio_to_audio() -> None:
         modification. The best specific denoising depends on how different the prompt is
         from the source audio. You can play with the seed to get infinite variations.
         Currently the same seed is used for all clips along the track.
+
+        If the Interpolation check box is enabled, supports entering two sets of prompt,
+        seed, and denoising value and smoothly blends between them along the selected
+        duration of the audio. This is a great way to create a transition.
         """
     )

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     num_inference_steps = T.cast(
         int,
@@ -55,7 +61,7 @@ def render_audio_to_audio() -> None:

     audio_file = st.file_uploader(
         "Upload audio",
-        type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
+        type=streamlit_util.AUDIO_EXTENSIONS,
         label_visibility="collapsed",
     )

@@ -89,7 +95,14 @@ def render_audio_to_audio() -> None:
         overlap_duration_s=overlap_duration_s,
     )

-    interpolate = st.checkbox("Interpolate between two settings", False)
+    interpolate = st.checkbox(
+        "Interpolate between two endpoints",
+        value=False,
+        help="Interpolate between two prompts, seeds, or denoising values along the"
+        "duration of the segment",
+    )
+
+    counter = streamlit_util.StreamlitCounter()

     with st.form("audio to audio form"):
         if interpolate:
@@ -109,7 +122,7 @@ def render_audio_to_audio() -> None:
             **get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
         )

-        submit_button = st.form_submit_button("Riff", type="primary")
+        st.form_submit_button("Riff", type="primary", on_click=counter.increment)

     show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
     show_difference = st.sidebar.checkbox("Show Difference", False)
@@ -124,9 +137,11 @@ def render_audio_to_audio() -> None:
         st.info("Enter a prompt")
         return

-    if not submit_button:
+    if counter.value == 0:
         return

+    st.write(f"## Counter: {counter.value}")
+
     params = SpectrogramParams()

     if interpolate:
@@ -228,16 +243,17 @@ def render_audio_to_audio() -> None:
         )

         audio_bytes = io.BytesIO()
-        diff_segment.export(audio_bytes, format="wav")
+        diff_segment.export(audio_bytes, format=extension)
         st.audio(audio_bytes)

     # Combine clips with a crossfade based on overlap
     combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)

-    audio_bytes = io.BytesIO()
-    combined_segment.export(audio_bytes, format="mp3")
     st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
-    st.audio(audio_bytes, format="audio/mp3")
+
+    input_name = Path(audio_file.name).stem
+    output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
+    streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)


 def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
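For reference, stitch_segments maps the seconds-based overlap to pydub's millisecond crossfade API; its loop body appears as context in the audio_util hunk at the end of this diff. A self-contained sketch (the seconds-to-milliseconds conversion is inferred from the crossfade_s call site and the crossfade_ms name):

    import typing as T

    import pydub

    def stitch_segments(
        segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
    ) -> pydub.AudioSegment:
        crossfade_ms = int(crossfade_s * 1000)  # pydub measures time in milliseconds
        combined_segment = segments[0]
        for segment in segments[1:]:
            combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
        return combined_segment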
@@ -1,4 +1,5 @@
 import dataclasses
+from pathlib import Path

 import streamlit as st
 from PIL import Image
@@ -29,10 +30,11 @@ def render_image_to_audio() -> None:
     )

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     image_file = st.file_uploader(
         "Upload a file",
-        type=["png", "jpg", "jpeg"],
+        type=streamlit_util.IMAGE_EXTENSIONS,
         label_visibility="collapsed",
     )
     if not image_file:
@@ -55,13 +57,17 @@ def render_image_to_audio() -> None:
     with st.expander("Spectrogram Parameters", expanded=False):
         st.json(dataclasses.asdict(params))

-    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
+    segment = streamlit_util.audio_segment_from_spectrogram_image(
         image=image.copy(),
         params=params,
         device=device,
-        output_format="mp3",
     )
-    st.audio(audio_bytes)
+
+    streamlit_util.display_and_download_audio(
+        segment,
+        name=Path(image_file.name).stem,
+        extension=extension,
+    )


 if __name__ == "__main__":
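A small aside on the download plumbing: display_and_download_audio (added in the util diff below) builds the MIME type as f"audio/{extension}". That works for wav and ogg, but mp3's registered type is audio/mpeg and m4a is usually served as audio/mp4. If a browser ever balks, an explicit mapping is an easy fix; the EXTENSION_TO_MIME helper here is hypothetical, not part of this commit:

    # Hypothetical helper, not in this commit: extension -> registered MIME type.
    EXTENSION_TO_MIME = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "ogg": "audio/ogg",
        "flac": "audio/flac",
        "webm": "audio/webm",
        "m4a": "audio/mp4",
    }

    def mime_for(extension: str) -> str:
        # Fall back to the commit's audio/<ext> convention for anything unmapped
        return EXTENSION_TO_MIME.get(extension, f"audio/{extension}")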
@@ -42,6 +42,7 @@ def render_interpolation() -> None:
     # Sidebar params

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     num_interpolation_steps = T.cast(
         int,
@@ -78,7 +79,7 @@ def render_interpolation() -> None:
     if init_image_name == "custom":
         init_image_file = st.sidebar.file_uploader(
             "Upload a custom seed image",
-            type=["png", "jpg", "jpeg"],
+            type=streamlit_util.IMAGE_EXTENSIONS,
             label_visibility="collapsed",
         )
         if init_image_file:
@@ -154,6 +155,7 @@ def render_interpolation() -> None:
         inputs=inputs,
         init_image=init_image,
         device=device,
+        extension=extension,
     )

     if show_individual_outputs:
@@ -167,19 +169,30 @@ def render_interpolation() -> None:

     st.write("#### Final Output")

-    # TODO(hayk): Concatenate with better blending
+    # TODO(hayk): Concatenate with overlap and better blending like in audio to audio
     audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
     concat_segment = audio_segments[0]
     for segment in audio_segments[1:]:
         concat_segment = concat_segment.append(segment, crossfade=0)

     audio_bytes = io.BytesIO()
-    concat_segment.export(audio_bytes, format="mp3")
+    concat_segment.export(audio_bytes, format=extension)
     audio_bytes.seek(0)

     st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
     st.audio(audio_bytes)

+    output_name = (
+        f"{prompt_input_a.prompt.replace(' ', '_')}_"
+        f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
+    )
+    st.download_button(
+        output_name,
+        data=audio_bytes,
+        file_name=output_name,
+        mime=f"audio/{extension}",
+    )


 def get_prompt_inputs(
     key: str,
@@ -222,7 +235,7 @@ def get_prompt_inputs(

 @st.experimental_memo
 def run_interpolation(
-    inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
+    inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
 ) -> T.Tuple[Image.Image, io.BytesIO]:
     """
     Cached function for riffusion interpolation.
@@ -250,7 +263,7 @@ def run_interpolation(
         image=image,
         params=params,
         device=device,
-        output_format="mp3",
+        output_format=extension,
     )

     return image, audio_bytes
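Note why extension is threaded into run_interpolation instead of converting afterwards: @st.experimental_memo keys its cache on all function arguments, so changing the output format correctly invalidates the cached audio bytes rather than serving bytes exported in the old format. A toy illustration of that cache-key behavior:

    import streamlit as st

    @st.experimental_memo
    def render(x: int, extension: str = "mp3") -> str:
        # The cache key includes both x and extension, so a format
        # change re-runs the function instead of returning stale output.
        return f"{x}.{extension}"

    render(1, "mp3")  # computed
    render(1, "mp3")  # served from cache
    render(1, "wav")  # computed again: different cache key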
@@ -6,6 +6,8 @@ import numpy as np
 import pydub
 import streamlit as st

+from riffusion.streamlit import util as streamlit_util
+

 def render_sample_clips() -> None:
     st.set_page_config(layout="wide", page_icon="🎸")
@@ -28,7 +30,7 @@ def render_sample_clips() -> None:

     audio_file = st.file_uploader(
         "Upload a file",
-        type=["wav", "mp3", "ogg"],
+        type=streamlit_util.AUDIO_EXTENSIONS,
         label_visibility="collapsed",
     )
     if not audio_file:
@@ -49,22 +51,26 @@ def render_sample_clips() -> None:
         )
     )

-    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
-    duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
+    extension = streamlit_util.select_audio_extension(st.sidebar)
+    save_to_disk = st.sidebar.checkbox("Save to Disk", False)
     export_as_mono = st.sidebar.checkbox("Export as Mono", False)
-    num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
-    extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
-    assert extension is not None

-    # Optionally specify an output directory
-    output_dir = st.text_input("Output Directory")
-    if not output_dir:
-        tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
-        st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
+    row = st.columns(4)
+    num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
+    duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
+    seed = T.cast(int, row[2].number_input("Seed", value=42))
+
+    counter = streamlit_util.StreamlitCounter()
+    st.button("Sample Clips", type="primary", on_click=counter.increment)
+    if counter.value == 0:
         return

-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
+    # Optionally pick an output directory
+    if save_to_disk:
+        output_dir = tempfile.mkdtemp(prefix="sample_clips_")
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        st.info(f"Output directory: `{output_dir}`")

     if seed >= 0:
         np.random.seed(seed)
@@ -78,16 +84,22 @@ def render_sample_clips() -> None:
         clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
         clip = segment[clip_start_ms : clip_start_ms + duration_ms]

-        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
+        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"

         st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")

-        clip_path = output_path / clip_name
-        clip.export(clip_path, format=extension)
+        streamlit_util.display_and_download_audio(
+            clip,
+            name=clip_name,
+            extension=extension,
+        )

-        st.audio(str(clip_path))
+        if save_to_disk:
+            clip_path = output_path / f"clip_name.{extension}"
+            clip.export(clip_path, format=extension)

-    st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
+    if save_to_disk:
+        st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")


 if __name__ == "__main__":
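Two notes on the sampling loop above: pydub slices AudioSegment objects in milliseconds, and the save-to-disk branch's f"clip_name.{extension}" looks like it is missing braces around clip_name, so every clip would be written to the same path. A self-contained sketch of the slicing behavior (synthetic silence stands in for the uploaded file):

    import numpy as np
    import pydub

    segment = pydub.AudioSegment.silent(duration=30_000)  # stand-in for uploaded audio
    duration_ms = 5_000

    np.random.seed(42)
    clip_start_ms = np.random.randint(0, len(segment) - duration_ms)
    clip = segment[clip_start_ms : clip_start_ms + duration_ms]  # slice is in ms
    assert len(clip) == duration_ms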
@@ -1,10 +1,12 @@
-import io
+import typing as T
+from pathlib import Path

 import pydub
 import streamlit as st

 from riffusion.audio_splitter import split_audio
 from riffusion.streamlit import util as streamlit_util
+from riffusion.util import audio_util


 def render_split_audio() -> None:
@@ -13,7 +15,7 @@ def render_split_audio() -> None:
     st.subheader(":scissors: Audio Splitter")
     st.write(
         """
-        Split an audio into stems of {vocals, drums, bass, piano, guitar, other}.
+        Split audio into individual instrument stems.
         """
     )

@@ -32,13 +34,21 @@ def render_split_audio() -> None:

     device = streamlit_util.select_device(st.sidebar)

+    extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
+    extension = st.sidebar.selectbox(
+        "Output format",
+        options=extension_options,
+        index=extension_options.index("mp3"),
+    )
+    assert extension is not None
+
     audio_file = st.file_uploader(
         "Upload audio",
-        type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
+        type=extension_options,
         label_visibility="collapsed",
     )

-    stem_options = ["vocals", "drums", "bass", "guitar", "piano", "other"]
+    stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
     recombine = st.sidebar.multiselect(
         "Recombine",
         options=stem_options,
@@ -50,39 +60,45 @@ def render_split_audio() -> None:
         st.info("Upload audio to get started")
         return

-    st.write("#### original")
-    # TODO(hayk): This might be bogus, it can be other formats..
-    st.audio(audio_file, format="audio/mp3")
+    st.write("#### Original")
+    st.audio(audio_file)

-    if not st.button("Split", type="primary"):
+    counter = streamlit_util.StreamlitCounter()
+    st.button("Split", type="primary", on_click=counter.increment)
+    if counter.value == 0:
         return

     segment = streamlit_util.load_audio_file(audio_file)

     # Split
-    stems = split_audio(segment, device=device)
+    stems = split_audio_cached(segment, device=device)

+    input_name = Path(audio_file.name).stem
+
     # Display each
-    for name, stem in stems.items():
-        st.write(f"#### {name}")
-        audio_bytes = io.BytesIO()
-        stem.export(audio_bytes, format="mp3")
-        st.audio(audio_bytes, format="audio/mp3")
+    for name in stem_options:
+        stem = stems[name.lower()]
+        st.write(f"#### Stem: {name}")
+
+        output_name = f"{input_name}_{name.lower()}"
+        streamlit_util.display_and_download_audio(stem, output_name, extension=extension)

     if recombine:
-        recombined: pydub.AudioSegment = None
-        for name, stem in stems.items():
-            if name in recombine:
-                if recombined is None:
-                    recombined = stem
-                else:
-                    recombined = recombined.overlay(stem)
+        recombine_lower = [r.lower() for r in recombine]
+        segments = [s for name, s in stems.items() if name in recombine_lower]
+        recombined = audio_util.overlay_segments(segments)

         # Display
-        st.write("#### recombined")
-        audio_bytes = io.BytesIO()
-        recombined.export(audio_bytes, format="mp3")
-        st.audio(audio_bytes, format="audio/mp3")
+        st.write(f"#### Recombined: {', '.join(recombine)}")
+        output_name = f"{input_name}_{'_'.join(recombine_lower)}"
+        streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
+
+
+@st.cache
+def split_audio_cached(
+    segment: pydub.AudioSegment, device: str = "cuda"
+) -> T.Dict[str, pydub.AudioSegment]:
+    return split_audio(segment, device=device)


 if __name__ == "__main__":
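split_audio_cached wraps the expensive stem separation in @st.cache so UI-only reruns (like changing the Recombine selection) don't re-run the model. One caveat: st.cache hashes function arguments, and a pydub.AudioSegment may need a custom hasher. A sketch under that assumption; the hash_funcs entry is illustrative and not part of this commit:

    import typing as T

    import pydub
    import streamlit as st

    from riffusion.audio_splitter import split_audio

    # Assumption: key the cache on the segment's raw PCM bytes.
    @st.cache(hash_funcs={pydub.AudioSegment: lambda s: hash(s.raw_data)})
    def split_audio_cached(
        segment: pydub.AudioSegment, device: str = "cuda"
    ) -> T.Dict[str, pydub.AudioSegment]:
        return split_audio(segment, device=device)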
@@ -27,6 +27,7 @@ def render_text_to_audio() -> None:
     )

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     with st.form("Inputs"):
         prompt = st.text_input("Prompt")
@@ -87,13 +88,15 @@ def render_text_to_audio() -> None:
         )
         st.image(image)

-        audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
+        segment = streamlit_util.audio_segment_from_spectrogram_image(
             image=image,
             params=params,
             device=device,
-            output_format="mp3",
         )
-        st.audio(audio_bytes)
+
+        streamlit_util.display_and_download_audio(
+            segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
+        )

         seed += 1

@@ -13,13 +13,13 @@ def render_main():
         st.write("Generate audio clips from text prompts.")

         create_link(":wave: Audio to Audio", "/audio_to_audio")
-        st.write("Upload audio and modify with text prompt.")
+        st.write("Upload audio and modify with text prompt (interpolation supported).")

         create_link(":performing_arts: Interpolation", "/interpolation")
         st.write("Interpolate between prompts in the latent space.")

         create_link(":scissors: Audio Splitter", "/split_audio")
-        st.write("Upload audio and split into vocals, bass, drums, and other.")
+        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")

     with right:
         create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
@@ -17,6 +17,9 @@ from riffusion.spectrogram_params import SpectrogramParams

 # TODO(hayk): Add URL params

+AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
+IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
+

 @st.experimental_singleton
 def load_riffusion_checkpoint(
@@ -177,6 +180,20 @@ def select_device(container: T.Any = st.sidebar) -> str:
     return device


+def select_audio_extension(container: T.Any = st.sidebar) -> str:
+    """
+    Dropdown to select an audio extension, with an intelligent default.
+    """
+    default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
+    extension = container.selectbox(
+        "Output format",
+        options=AUDIO_EXTENSIONS,
+        index=AUDIO_EXTENSIONS.index(default),
+    )
+    assert extension is not None
+    return extension
+
+
 @st.experimental_memo
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)
@@ -224,3 +241,43 @@ def run_img2img(
     )

     return result.images[0]
+
+
+class StreamlitCounter:
+    """
+    Simple counter stored in streamlit session state.
+    """
+
+    def __init__(self, key="_counter"):
+        self.key = key
+        if not st.session_state.get(self.key):
+            st.session_state[self.key] = 0
+
+    def increment(self):
+        st.session_state[self.key] += 1
+
+    @property
+    def value(self):
+        return st.session_state[self.key]
+
+
+def display_and_download_audio(
+    segment: pydub.AudioSegment,
+    name: str,
+    extension: str = "mp3",
+) -> None:
+    """
+    Display the given audio segment and provide a button to download it with
+    a proper file name, since st.audio doesn't support that.
+    """
+    mime_type = f"audio/{extension}"
+    audio_bytes = io.BytesIO()
+    segment.export(audio_bytes, format=extension)
+    st.audio(audio_bytes, format=mime_type)
+
+    st.download_button(
+        f"{name}.{extension}",
+        data=audio_bytes,
+        file_name=f"{name}.{extension}",
+        mime=mime_type,
+    )
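Taken together, the two new helpers reduce each page's output handling to a couple of lines. A minimal page sketch using only the APIs added in this diff (the silent segment stands in for whatever the page generates):

    import pydub
    import streamlit as st

    from riffusion.streamlit import util as streamlit_util

    extension = streamlit_util.select_audio_extension(st.sidebar)

    segment = pydub.AudioSegment.silent(duration=5_000)  # stand-in for generated audio

    streamlit_util.display_and_download_audio(segment, name="my_clip", extension=extension)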
@@ -83,3 +83,17 @@ def stitch_segments(
     for segment in segments[1:]:
         combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
     return combined_segment
+
+
+def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
+    """
+    Overlay a sequence of audio segments on top of each other.
+    """
+    assert len(segments) > 0
+    output: pydub.AudioSegment = None
+    for segment in segments:
+        if output is None:
+            output = segment
+        else:
+            output = output.overlay(segment)
+    return output
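One pydub detail behind overlay_segments: AudioSegment.overlay keeps the duration of the base segment, so if the inputs ever differed in length the result would be truncated to the first one. Stems split from the same source are equal length, so this is a non-issue here; if that ever changes, starting from the longest segment preserves everything. A hedged variant (overlay_segments_safe is hypothetical, not part of this commit):

    import typing as T

    import pydub

    def overlay_segments_safe(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
        # overlay() truncates to the base segment, so start from the longest
        ordered = sorted(segments, key=len, reverse=True)
        output = ordered[0]
        for segment in ordered[1:]:
            output = output.overlay(segment)
        return output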