diff --git a/riffusion/streamlit/pages/audio_to_audio.py b/riffusion/streamlit/pages/audio_to_audio.py
index 71062b0..35067c6 100644
--- a/riffusion/streamlit/pages/audio_to_audio.py
+++ b/riffusion/streamlit/pages/audio_to_audio.py
@@ -1,5 +1,6 @@
 import io
 import typing as T
+from pathlib import Path

 import numpy as np
 import pydub
@@ -19,7 +20,7 @@ def render_audio_to_audio() -> None:
     st.subheader(":wave: Audio to Audio")
     st.write(
         """
-        Modify existing audio from a text prompt.
+        Modify existing audio from a text prompt or interpolate between two.
         """
     )

@@ -35,10 +36,15 @@ def render_audio_to_audio() -> None:
         modification. The best specific denoising depends on how different the prompt is
         from the source audio. You can play with the seed to get infinite variations.
         Currently the same seed is used for all clips along the track.
+
+        If the Interpolate checkbox is enabled, you can enter two sets of prompt, seed,
+        and denoising value, and the tool smoothly blends between them along the selected
+        duration of the audio. This is a great way to create a transition.
         """
     )

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     num_inference_steps = T.cast(
         int,
@@ -55,7 +61,7 @@ def render_audio_to_audio() -> None:

     audio_file = st.file_uploader(
         "Upload audio",
-        type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
+        type=streamlit_util.AUDIO_EXTENSIONS,
         label_visibility="collapsed",
     )

@@ -89,7 +95,14 @@ def render_audio_to_audio() -> None:
         overlap_duration_s=overlap_duration_s,
     )

-    interpolate = st.checkbox("Interpolate between two settings", False)
+    interpolate = st.checkbox(
+        "Interpolate between two endpoints",
+        value=False,
+        help="Interpolate between two prompts, seeds, or denoising values along the "
+        "duration of the segment",
+    )
+
+    counter = streamlit_util.StreamlitCounter()

     with st.form("audio to audio form"):
         if interpolate:
@@ -109,7 +122,7 @@ def render_audio_to_audio() -> None:
                 **get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
             )

-        submit_button = st.form_submit_button("Riff", type="primary")
+        st.form_submit_button("Riff", type="primary", on_click=counter.increment)

     show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
     show_difference = st.sidebar.checkbox("Show Difference", False)
@@ -124,9 +137,11 @@ def render_audio_to_audio() -> None:
         st.info("Enter a prompt")
         return

-    if not submit_button:
+    if counter.value == 0:
         return

+    st.write(f"## Counter: {counter.value}")
+
     params = SpectrogramParams()

     if interpolate:
@@ -228,16 +243,17 @@
             )

             audio_bytes = io.BytesIO()
-            diff_segment.export(audio_bytes, format="wav")
+            diff_segment.export(audio_bytes, format=extension)
             st.audio(audio_bytes)

     # Combine clips with a crossfade based on overlap
     combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)

-    audio_bytes = io.BytesIO()
-    combined_segment.export(audio_bytes, format="mp3")
     st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
-    st.audio(audio_bytes, format="audio/mp3")
+
+    input_name = Path(audio_file.name).stem
+    output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
+    streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)


 def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
diff --git a/riffusion/streamlit/pages/image_to_audio.py b/riffusion/streamlit/pages/image_to_audio.py
index 6736a93..d6ba406 100644
--- a/riffusion/streamlit/pages/image_to_audio.py
+++ b/riffusion/streamlit/pages/image_to_audio.py
@@ -1,4 +1,5 @@
 import dataclasses
+from pathlib import Path

 import streamlit as st
 from PIL import Image
@@ -29,10 +30,11 @@ def render_image_to_audio() -> None:
     )

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     image_file = st.file_uploader(
         "Upload a file",
-        type=["png", "jpg", "jpeg"],
+        type=streamlit_util.IMAGE_EXTENSIONS,
         label_visibility="collapsed",
     )
     if not image_file:
@@ -55,13 +57,17 @@ def render_image_to_audio() -> None:
     with st.expander("Spectrogram Parameters", expanded=False):
         st.json(dataclasses.asdict(params))

-    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
+    segment = streamlit_util.audio_segment_from_spectrogram_image(
         image=image.copy(),
         params=params,
         device=device,
-        output_format="mp3",
     )
-    st.audio(audio_bytes)
+
+    streamlit_util.display_and_download_audio(
+        segment,
+        name=Path(image_file.name).stem,
+        extension=extension,
+    )


 if __name__ == "__main__":
diff --git a/riffusion/streamlit/pages/interpolation.py b/riffusion/streamlit/pages/interpolation.py
index 9cde099..d58e3dc 100644
--- a/riffusion/streamlit/pages/interpolation.py
+++ b/riffusion/streamlit/pages/interpolation.py
@@ -42,6 +42,7 @@ def render_interpolation() -> None:
     # Sidebar params

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     num_interpolation_steps = T.cast(
         int,
@@ -78,7 +79,7 @@ def render_interpolation() -> None:
     if init_image_name == "custom":
         init_image_file = st.sidebar.file_uploader(
             "Upload a custom seed image",
-            type=["png", "jpg", "jpeg"],
+            type=streamlit_util.IMAGE_EXTENSIONS,
             label_visibility="collapsed",
         )
         if init_image_file:
@@ -154,6 +155,7 @@ def render_interpolation() -> None:
             inputs=inputs,
             init_image=init_image,
             device=device,
+            extension=extension,
         )

         if show_individual_outputs:
@@ -167,19 +169,30 @@ def render_interpolation() -> None:

     st.write("#### Final Output")

-    # TODO(hayk): Concatenate with better blending
+    # TODO(hayk): Concatenate with overlap and better blending like in audio to audio
     audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
     concat_segment = audio_segments[0]
     for segment in audio_segments[1:]:
         concat_segment = concat_segment.append(segment, crossfade=0)

     audio_bytes = io.BytesIO()
-    concat_segment.export(audio_bytes, format="mp3")
+    concat_segment.export(audio_bytes, format=extension)
     audio_bytes.seek(0)

     st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
     st.audio(audio_bytes)

+    output_name = (
+        f"{prompt_input_a.prompt.replace(' ', '_')}_"
+        f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
+    )
+    st.download_button(
+        output_name,
+        data=audio_bytes,
+        file_name=output_name,
+        mime=f"audio/{extension}",
+    )
+

 def get_prompt_inputs(
     key: str,
@@ -222,7 +235,7 @@ def get_prompt_inputs(

 @st.experimental_memo
 def run_interpolation(
-    inputs: InferenceInput, init_image: Image.Image, device: str = "cuda"
+    inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
 ) -> T.Tuple[Image.Image, io.BytesIO]:
     """
     Cached function for riffusion interpolation.
@@ -250,7 +263,7 @@ def run_interpolation(
         image=image,
         params=params,
         device=device,
-        output_format="mp3",
+        output_format=extension,
     )

     return image, audio_bytes
diff --git a/riffusion/streamlit/pages/sample_clips.py b/riffusion/streamlit/pages/sample_clips.py
index 5fd01ea..43df1ff 100644
--- a/riffusion/streamlit/pages/sample_clips.py
+++ b/riffusion/streamlit/pages/sample_clips.py
@@ -6,6 +6,8 @@
 import numpy as np
 import pydub
 import streamlit as st

+from riffusion.streamlit import util as streamlit_util
+

 def render_sample_clips() -> None:
     st.set_page_config(layout="wide", page_icon="🎸")
@@ -28,7 +30,7 @@ def render_sample_clips() -> None:

     audio_file = st.file_uploader(
         "Upload a file",
-        type=["wav", "mp3", "ogg"],
+        type=streamlit_util.AUDIO_EXTENSIONS,
         label_visibility="collapsed",
     )
     if not audio_file:
@@ -49,22 +51,26 @@ def render_sample_clips() -> None:
         )
     )

-    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
-    duration_ms = T.cast(int, st.sidebar.number_input("Duration (ms)", value=5000))
+    extension = streamlit_util.select_audio_extension(st.sidebar)
+    save_to_disk = st.sidebar.checkbox("Save to Disk", False)
     export_as_mono = st.sidebar.checkbox("Export as Mono", False)
-    num_clips = T.cast(int, st.sidebar.number_input("Number of Clips", value=3))
-    extension = st.sidebar.selectbox("Extension", ["mp3", "wav", "ogg"])
-    assert extension is not None

-    # Optionally specify an output directory
-    output_dir = st.text_input("Output Directory")
-    if not output_dir:
-        tmp_dir = tempfile.mkdtemp(prefix="sample_clips_")
-        st.info(f"Specify an output directory. Suggested: `{tmp_dir}`")
+    row = st.columns(4)
+    num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
+    duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
+    seed = T.cast(int, row[2].number_input("Seed", value=42))
+
+    counter = streamlit_util.StreamlitCounter()
+    st.button("Sample Clips", type="primary", on_click=counter.increment)
+    if counter.value == 0:
         return

-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
+    # Optionally pick an output directory
+    if save_to_disk:
+        output_dir = tempfile.mkdtemp(prefix="sample_clips_")
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        st.info(f"Output directory: `{output_dir}`")

     if seed >= 0:
         np.random.seed(seed)
@@ -78,16 +84,22 @@ def render_sample_clips() -> None:
         clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
         clip = segment[clip_start_ms : clip_start_ms + duration_ms]

-        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
+        clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"
         st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")

-        clip_path = output_path / clip_name
-        clip.export(clip_path, format=extension)
+        streamlit_util.display_and_download_audio(
+            clip,
+            name=clip_name,
+            extension=extension,
+        )

-        st.audio(str(clip_path))
+        if save_to_disk:
+            clip_path = output_path / f"{clip_name}.{extension}"
+            clip.export(clip_path, format=extension)

-    st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
+    if save_to_disk:
+        st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")


 if __name__ == "__main__":
diff --git a/riffusion/streamlit/pages/split_audio.py b/riffusion/streamlit/pages/split_audio.py
index a69e1b2..69720a5 100644
--- a/riffusion/streamlit/pages/split_audio.py
+++ b/riffusion/streamlit/pages/split_audio.py
@@ -1,10 +1,12 @@
-import io
+import typing as T
+from pathlib import Path

 import pydub
 import streamlit as st

 from riffusion.audio_splitter import split_audio
 from riffusion.streamlit import util as streamlit_util
+from riffusion.util import audio_util


 def render_split_audio() -> None:
@@ -13,7 +15,7 @@ def render_split_audio() -> None:
     st.subheader(":scissors: Audio Splitter")
     st.write(
         """
-        Split an audio into stems of {vocals, drums, bass, piano, guitar, other}.
+        Split audio into individual instrument stems.
         """
     )

@@ -32,13 +34,21 @@ def render_split_audio() -> None:

     device = streamlit_util.select_device(st.sidebar)

+    extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
+    extension = st.sidebar.selectbox(
+        "Output format",
+        options=extension_options,
+        index=extension_options.index("mp3"),
+    )
+    assert extension is not None
+
     audio_file = st.file_uploader(
         "Upload audio",
-        type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
+        type=extension_options,
         label_visibility="collapsed",
     )

-    stem_options = ["vocals", "drums", "bass", "guitar", "piano", "other"]
+    stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
     recombine = st.sidebar.multiselect(
         "Recombine",
         options=stem_options,
@@ -50,39 +60,45 @@ def render_split_audio() -> None:
         st.info("Upload audio to get started")
         return

-    st.write("#### original")
-    # TODO(hayk): This might be bogus, it can be other formats..
-    st.audio(audio_file, format="audio/mp3")
+    st.write("#### Original")
+    st.audio(audio_file)

-    if not st.button("Split", type="primary"):
+    counter = streamlit_util.StreamlitCounter()
+    st.button("Split", type="primary", on_click=counter.increment)
+    if counter.value == 0:
         return

     segment = streamlit_util.load_audio_file(audio_file)

     # Split
-    stems = split_audio(segment, device=device)
+    stems = split_audio_cached(segment, device=device)
+
+    input_name = Path(audio_file.name).stem

     # Display each
-    for name, stem in stems.items():
-        st.write(f"#### {name}")
-        audio_bytes = io.BytesIO()
-        stem.export(audio_bytes, format="mp3")
-        st.audio(audio_bytes, format="audio/mp3")
+    for name in stem_options:
+        stem = stems[name.lower()]
+        st.write(f"#### Stem: {name}")
+
+        output_name = f"{input_name}_{name.lower()}"
+        streamlit_util.display_and_download_audio(stem, output_name, extension=extension)

     if recombine:
-        recombined: pydub.AudioSegment = None
-        for name, stem in stems.items():
-            if name in recombine:
-                if recombined is None:
-                    recombined = stem
-                else:
-                    recombined = recombined.overlay(stem)
+        recombine_lower = [r.lower() for r in recombine]
+        segments = [s for name, s in stems.items() if name in recombine_lower]
+        recombined = audio_util.overlay_segments(segments)

         # Display
-        st.write("#### recombined")
-        audio_bytes = io.BytesIO()
-        recombined.export(audio_bytes, format="mp3")
-        st.audio(audio_bytes, format="audio/mp3")
+        st.write(f"#### Recombined: {', '.join(recombine)}")
+        output_name = f"{input_name}_{'_'.join(recombine_lower)}"
+        streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
+
+
+@st.cache
+def split_audio_cached(
+    segment: pydub.AudioSegment, device: str = "cuda"
+) -> T.Dict[str, pydub.AudioSegment]:
+    return split_audio(segment, device=device)


 if __name__ == "__main__":
diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py
index 16aeca2..48037e1 100644
--- a/riffusion/streamlit/pages/text_to_audio.py
+++ b/riffusion/streamlit/pages/text_to_audio.py
@@ -27,6 +27,7 @@ def render_text_to_audio() -> None:
     )

     device = streamlit_util.select_device(st.sidebar)
+    extension = streamlit_util.select_audio_extension(st.sidebar)

     with st.form("Inputs"):
         prompt = st.text_input("Prompt")
@@ -87,13 +88,15 @@ def render_text_to_audio() -> None:
             )
             st.image(image)

-            audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
+            segment = streamlit_util.audio_segment_from_spectrogram_image(
                 image=image,
                 params=params,
                 device=device,
-                output_format="mp3",
             )
-            st.audio(audio_bytes)
+
+            streamlit_util.display_and_download_audio(
+                segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
+            )

             seed += 1
diff --git a/riffusion/streamlit/playground.py b/riffusion/streamlit/playground.py
index 36c86ae..f705803 100644
--- a/riffusion/streamlit/playground.py
+++ b/riffusion/streamlit/playground.py
@@ -13,13 +13,13 @@ def render_main():
         st.write("Generate audio clips from text prompts.")

         create_link(":wave: Audio to Audio", "/audio_to_audio")
-        st.write("Upload audio and modify with text prompt.")
+        st.write("Upload audio and modify with text prompt (interpolation supported).")

         create_link(":performing_arts: Interpolation", "/interpolation")
         st.write("Interpolate between prompts in the latent space.")

         create_link(":scissors: Audio Splitter", "/split_audio")
-        st.write("Upload audio and split into vocals, bass, drums, and other.")
+        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")

     with right:
         create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py
index 7d4edc7..74e276b 100644
--- a/riffusion/streamlit/util.py
+++ b/riffusion/streamlit/util.py
@@ -17,6 +17,9 @@ from riffusion.spectrogram_params import SpectrogramParams

 # TODO(hayk): Add URL params

+AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
+IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
+

 @st.experimental_singleton
 def load_riffusion_checkpoint(
@@ -177,6 +180,20 @@ def select_device(container: T.Any = st.sidebar) -> str:
     return device


+def select_audio_extension(container: T.Any = st.sidebar) -> str:
+    """
+    Dropdown to select an audio extension, with an intelligent default.
+    """
+    default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
+    extension = container.selectbox(
+        "Output format",
+        options=AUDIO_EXTENSIONS,
+        index=AUDIO_EXTENSIONS.index(default),
+    )
+    assert extension is not None
+    return extension
+
+
 @st.experimental_memo
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)
@@ -224,3 +241,43 @@ def run_img2img(
     )

     return result.images[0]
+
+
+class StreamlitCounter:
+    """
+    Simple counter stored in streamlit session state.
+    """
+
+    def __init__(self, key="_counter"):
+        self.key = key
+        if not st.session_state.get(self.key):
+            st.session_state[self.key] = 0
+
+    def increment(self):
+        st.session_state[self.key] += 1
+
+    @property
+    def value(self):
+        return st.session_state[self.key]
+
+
+def display_and_download_audio(
+    segment: pydub.AudioSegment,
+    name: str,
+    extension: str = "mp3",
+) -> None:
+    """
+    Display the given audio segment and provide a button to download it with
+    a proper file name, since st.audio doesn't support that.
+ """ + mime_type = f"audio/{extension}" + audio_bytes = io.BytesIO() + segment.export(audio_bytes, format=extension) + st.audio(audio_bytes, format=mime_type) + + st.download_button( + f"{name}.{extension}", + data=audio_bytes, + file_name=f"{name}.{extension}", + mime=mime_type, + ) diff --git a/riffusion/util/audio_util.py b/riffusion/util/audio_util.py index aa66e7a..999a557 100644 --- a/riffusion/util/audio_util.py +++ b/riffusion/util/audio_util.py @@ -83,3 +83,17 @@ def stitch_segments( for segment in segments[1:]: combined_segment = combined_segment.append(segment, crossfade=crossfade_ms) return combined_segment + + +def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment: + """ + Overlay a sequence of audio segments on top of each other. + """ + assert len(segments) > 0 + output: pydub.AudioSegment = None + for segment in segments: + if output is None: + output = segment + else: + output = output.overlay(segment) + return output