diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 346544d..bbc2aca 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -27,6 +27,7 @@ jobs:
       - name: Install system packages
         run: |
+          sudo apt-get update
           sudo apt-get install -y ffmpeg libsndfile1
 
       - name: Install pip packages from requirements.txt
diff --git a/.gitignore b/.gitignore
index 936865c..60dae11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,9 @@ __pycache__/
 # Cog
 .cog/
 
+# Random stuff I don't care about
+.graveyard/
+
 # Distribution / packaging
 .Python
 build/
diff --git a/riffusion/datatypes.py b/riffusion/datatypes.py
index 70c77e0..99f1c2d 100644
--- a/riffusion/datatypes.py
+++ b/riffusion/datatypes.py
@@ -19,6 +19,9 @@ class PromptInput:
     # Random seed for denoising
     seed: int
 
+    # Negative prompt to avoid (optional)
+    negative_prompt: T.Optional[str] = None
+
     # Denoising strength
     denoising: float = 0.75
 
diff --git a/riffusion/streamlit/pages/audio_to_audio.py b/riffusion/streamlit/pages/audio_to_audio.py
index 4aa1fbc..71062b0 100644
--- a/riffusion/streamlit/pages/audio_to_audio.py
+++ b/riffusion/streamlit/pages/audio_to_audio.py
@@ -6,8 +6,11 @@ import pydub
 import streamlit as st
 from PIL import Image
 
+from riffusion.datatypes import InferenceInput, PromptInput
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
+from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
+from riffusion.util import audio_util
 
 
 def render_audio_to_audio() -> None:
@@ -37,6 +40,19 @@ def render_audio_to_audio() -> None:
 
     device = streamlit_util.select_device(st.sidebar)
 
+    num_inference_steps = T.cast(
+        int,
+        st.sidebar.number_input(
+            "Steps per sample", value=50, help="Number of denoising steps per model run"
+        ),
+    )
+
+    guidance = st.sidebar.number_input(
+        "Guidance",
+        value=7.0,
+        help="How much the model listens to the text prompt",
+    )
+
     audio_file = st.file_uploader(
         "Upload audio",
         type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
@@ -53,113 +69,58 @@ def render_audio_to_audio() -> None:
     segment = streamlit_util.load_audio_file(audio_file)
 
     # TODO(hayk): Fix
-    segment = segment.set_frame_rate(44100)
+    if segment.frame_rate != 44100:
+        st.warning("Audio must be 44100Hz. Converting")
+        segment = segment.set_frame_rate(44100)
     st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
 
-    if "counter" not in st.session_state:
-        st.session_state.counter = 0
+    clip_p = get_clip_params()
+    start_time_s = clip_p["start_time_s"]
+    clip_duration_s = clip_p["clip_duration_s"]
+    overlap_duration_s = clip_p["overlap_duration_s"]
 
-    def increment_counter():
-        st.session_state.counter += 1
-
-    cols = st.columns(4)
-    start_time_s = cols[0].number_input(
-        "Start Time [s]",
-        min_value=0.0,
-        value=0.0,
-    )
-    duration_s = cols[1].number_input(
-        "Duration [s]",
-        min_value=0.0,
-        value=15.0,
-    )
-    clip_duration_s = cols[2].number_input(
-        "Clip Duration [s]",
-        min_value=3.0,
-        max_value=10.0,
-        value=5.0,
-    )
-    overlap_duration_s = cols[3].number_input(
-        "Overlap Duration [s]",
-        min_value=0.0,
-        max_value=10.0,
-        value=0.2,
-    )
-
-    duration_s = min(duration_s, segment.duration_seconds - start_time_s)
+    duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s)
     increment_s = clip_duration_s - overlap_duration_s
     clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
-    st.write(
-        f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
-        f"with overlap {overlap_duration_s}s."
+
+    write_clip_details(
+        clip_start_times=clip_start_times,
+        clip_duration_s=clip_duration_s,
+        overlap_duration_s=overlap_duration_s,
     )
 
-    with st.expander("Clip Times"):
-        st.dataframe(
-            {
-                "Start Time [s]": clip_start_times,
-                "End Time [s]": clip_start_times + clip_duration_s,
-                "Duration [s]": clip_duration_s,
-            }
-        )
+    interpolate = st.checkbox("Interpolate between two settings", False)
 
-    with st.form("Conversion Params"):
+    with st.form("audio to audio form"):
+        if interpolate:
+            left, right = st.columns(2)
 
-        prompt = st.text_input("Text Prompt")
-        negative_prompt = st.text_input("Negative Prompt")
+            with left:
+                st.write("##### Prompt A")
+                prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
 
-        cols = st.columns(4)
-        denoising_strength = cols[0].number_input(
-            "Denoising Strength",
-            min_value=0.0,
-            max_value=1.0,
-            value=0.45,
-        )
-        guidance_scale = cols[1].number_input(
-            "Guidance Scale",
-            min_value=0.0,
-            max_value=20.0,
-            value=7.0,
-        )
-        num_inference_steps = int(
-            cols[2].number_input(
-                "Num Inference Steps",
-                min_value=1,
-                max_value=150,
-                value=50,
+            with right:
+                st.write("##### Prompt B")
+                prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
+
+        else:
+            prompt_input_a = PromptInput(
+                guidance=guidance,
+                **get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
             )
-        )
-        seed = int(
-            cols[3].number_input(
-                "Seed",
-                min_value=0,
-                value=42,
-            )
-        )
-
-        submit_button = st.form_submit_button("Convert", on_click=increment_counter)
-
-        # TODO fix
 
+        submit_button = st.form_submit_button("Riff", type="primary")
 
     show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
     show_difference = st.sidebar.checkbox("Show Difference", False)
 
-    clip_segments: T.List[pydub.AudioSegment] = []
-    for i, clip_start_time_s in enumerate(clip_start_times):
-        clip_start_time_ms = int(clip_start_time_s * 1000)
-        clip_duration_ms = int(clip_duration_s * 1000)
-        clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
+    clip_segments = slice_audio_into_clips(
+        segment=segment,
+        clip_start_times=clip_start_times,
+        clip_duration_s=clip_duration_s,
+    )
 
-        # TODO(hayk): I don't think this is working properly
-        if i == len(clip_start_times) - 1:
-            silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
-            if silence_ms > 0:
-                clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
-
-        clip_segments.append(clip_segment)
-
-    if not prompt:
+    if not prompt_input_a.prompt:
         st.info("Enter a prompt")
         return
 
@@ -168,10 +129,16 @@ def render_audio_to_audio() -> None:
 
     params = SpectrogramParams()
 
+    if interpolate:
+        # TODO(hayk): Make not linspace
+        alphas = list(np.linspace(0, 1, len(clip_segments)))
+        alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
+        st.write(f"**Alphas** : [{alphas_str}]")
+
     result_images: T.List[Image.Image] = []
     result_segments: T.List[pydub.AudioSegment] = []
     for i, clip_segment in enumerate(clip_segments):
-        st.write(f"### Clip {i} at {clip_start_times[i]}s")
+        st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")
 
         audio_bytes = io.BytesIO()
         clip_segment.export(audio_bytes, format="wav")
@@ -183,10 +150,7 @@ def render_audio_to_audio() -> None:
         )
 
         # TODO(hayk): Roll this into spectrogram_image_from_audio?
-        # TODO(hayk): Scale something when computing audio
-        closest_width = int(np.ceil(init_image.width / 32) * 32)
-        closest_height = int(np.ceil(init_image.height / 32) * 32)
-        init_image_resized = init_image.resize((closest_width, closest_height), Image.BICUBIC)
+        init_image_resized = scale_image_to_32_stride(init_image)
 
         progress_callback = None
         if show_clip_details:
@@ -203,17 +167,32 @@ def render_audio_to_audio() -> None:
                 progress = st.progress(0.0)
                 progress_callback = progress.progress
 
-        image = streamlit_util.run_img2img(
-            prompt=prompt,
-            init_image=init_image_resized,
-            denoising_strength=denoising_strength,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            negative_prompt=negative_prompt,
-            seed=seed,
-            progress_callback=progress_callback,
-            device=device,
-        )
+        if interpolate:
+            inputs = InferenceInput(
+                alpha=float(alphas[i]),
+                num_inference_steps=num_inference_steps,
+                seed_image_id="og_beat",
+                start=prompt_input_a,
+                end=prompt_input_b,
+            )
+
+            image, audio_bytes = run_interpolation(
+                inputs=inputs,
+                init_image=init_image_resized,
+                device=device,
+            )
+        else:
+            image = streamlit_util.run_img2img(
+                prompt=prompt_input_a.prompt,
+                init_image=init_image_resized,
+                denoising_strength=prompt_input_a.denoising,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance,
+                negative_prompt=prompt_input_a.negative_prompt,
+                seed=prompt_input_a.seed,
+                progress_callback=progress_callback,
+                device=device,
+            )
 
         # Resize back to original size
         image = image.resize(init_image.size, Image.BICUBIC)
@@ -253,10 +232,7 @@ def render_audio_to_audio() -> None:
             st.audio(audio_bytes)
 
     # Combine clips with a crossfade based on overlap
-    crossfade_ms = int(overlap_duration_s * 1000)
-    combined_segment = result_segments[0]
-    for segment in result_segments[1:]:
-        combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
+    combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
 
     audio_bytes = io.BytesIO()
     combined_segment.export(audio_bytes, format="mp3")
@@ -264,11 +240,99 @@ def render_audio_to_audio() -> None:
     st.audio(audio_bytes, format="audio/mp3")
 
 
-@st.cache
-def test(segment: pydub.AudioSegment, counter: int) -> int:
-    st.write("#### Trimmed")
-    st.write(segment.duration_seconds)
-    return counter
+def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
+    """
+    Render the parameters of slicing audio into clips.
+    """
+    p: T.Dict[str, T.Any] = {}
+
+    cols = st.columns(4)
+
+    p["start_time_s"] = cols[0].number_input(
+        "Start Time [s]",
+        min_value=0.0,
+        value=0.0,
+    )
+    p["duration_s"] = cols[1].number_input(
+        "Duration [s]",
+        min_value=0.0,
+        value=15.0,
+    )
+
+    if advanced:
+        p["clip_duration_s"] = cols[2].number_input(
+            "Clip Duration [s]",
+            min_value=3.0,
+            max_value=10.0,
+            value=5.0,
+        )
+    else:
+        p["clip_duration_s"] = 5.0
+
+    if advanced:
+        p["overlap_duration_s"] = cols[3].number_input(
+            "Overlap Duration [s]",
+            min_value=0.0,
+            max_value=10.0,
+            value=0.2,
+        )
+    else:
+        p["overlap_duration_s"] = 0.2
+
+    return p
+
+
+def write_clip_details(
+    clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float
+):
+    """
+    Write details of the clips to be sliced from an audio segment.
+    """
+    clip_details_text = (
+        f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
+        f"with overlap {overlap_duration_s}s"
+    )
+
+    with st.expander(clip_details_text):
+        st.dataframe(
+            {
+                "Start Time [s]": clip_start_times,
+                "End Time [s]": clip_start_times + clip_duration_s,
+                "Duration [s]": clip_duration_s,
+            }
+        )
+
+
+def slice_audio_into_clips(
+    segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float
+) -> T.List[pydub.AudioSegment]:
+    """
+    Slice an audio segment into a list of clips of a given duration at the given start times.
+    """
+    clip_segments: T.List[pydub.AudioSegment] = []
+    for i, clip_start_time_s in enumerate(clip_start_times):
+        clip_start_time_ms = int(clip_start_time_s * 1000)
+        clip_duration_ms = int(clip_duration_s * 1000)
+        clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
+
+        # TODO(hayk): I don't think this is working properly
+        if i == len(clip_start_times) - 1:
+            silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
+            if silence_ms > 0:
+                clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
+
+        clip_segments.append(clip_segment)
+
+    return clip_segments
+
+
+def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
+    """
+    Scale an image to a size that is a multiple of 32.
+    """
+    closest_width = int(np.ceil(image.width / 32) * 32)
+    closest_height = int(np.ceil(image.height / 32) * 32)
+    return image.resize((closest_width, closest_height), Image.BICUBIC)
 
 
 if __name__ == "__main__":
diff --git a/riffusion/streamlit/pages/audio_to_audio_interpolate.py b/riffusion/streamlit/pages/audio_to_audio_interpolate.py
deleted file mode 100644
index 30db96a..0000000
--- a/riffusion/streamlit/pages/audio_to_audio_interpolate.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import io
-import typing as T
-
-import numpy as np
-import pydub
-import streamlit as st
-from PIL import Image
-
-from riffusion.datatypes import InferenceInput
-from riffusion.spectrogram_params import SpectrogramParams
-from riffusion.streamlit import util as streamlit_util
-from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
-
-
-def render_audio_to_audio_interpolate() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":wave: Audio to Audio Inteprolation")
-    st.write(
-        """
-    Audio to audio with interpolation.
-    """
-    )
-
-    with st.expander("Help", False):
-        st.write(
-            """
-            TODO
-            """
-        )
-
-    device = streamlit_util.select_device(st.sidebar)
-
-    num_inference_steps = T.cast(
-        int,
-        st.sidebar.number_input(
-            "Steps per sample", value=50, help="Number of denoising steps per model run"
-        ),
-    )
-
-    audio_file = st.file_uploader(
-        "Upload audio",
-        type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
-        label_visibility="collapsed",
-    )
-
-    if not audio_file:
-        st.info("Upload audio to get started")
-        return
-
-    st.write("#### Original")
-    st.audio(audio_file)
-
-    segment = streamlit_util.load_audio_file(audio_file)
-
-    # TODO(hayk): Fix
-    segment = segment.set_frame_rate(44100)
-    st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
-
-    if "counter" not in st.session_state:
-        st.session_state.counter = 0
-
-    def increment_counter():
-        st.session_state.counter += 1
-
-    cols = st.columns(4)
-    start_time_s = cols[0].number_input(
-        "Start Time [s]",
-        min_value=0.0,
-        value=0.0,
-    )
-    duration_s = cols[1].number_input(
-        "Duration [s]",
-        min_value=0.0,
-        value=15.0,
-    )
-    clip_duration_s = cols[2].number_input(
-        "Clip Duration [s]",
-        min_value=3.0,
-        max_value=10.0,
-        value=5.0,
-    )
-    overlap_duration_s = cols[3].number_input(
-        "Overlap Duration [s]",
-        min_value=0.0,
-        max_value=10.0,
-        value=0.2,
-    )
-
-    duration_s = min(duration_s, segment.duration_seconds - start_time_s)
-    increment_s = clip_duration_s - overlap_duration_s
-    clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
-    st.write(
-        f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
-        f"with overlap {overlap_duration_s}s."
-    )
-
-    with st.expander("Clip Times"):
-        st.dataframe(
-            {
-                "Start Time [s]": clip_start_times,
-                "End Time [s]": clip_start_times + clip_duration_s,
-                "Duration [s]": clip_duration_s,
-            }
-        )
-
-    with st.form(key="interpolation_form"):
-        left, right = st.columns(2)
-
-        with left:
-            st.write("##### Prompt A")
-            prompt_input_a = get_prompt_inputs(key="a")
-
-        with right:
-            st.write("##### Prompt B")
-            prompt_input_b = get_prompt_inputs(key="b")
-
-        submit_button = st.form_submit_button("Generate", type="primary")
-
-    show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
-    show_difference = st.sidebar.checkbox("Show Difference", False)
-
-    clip_segments: T.List[pydub.AudioSegment] = []
-    for i, clip_start_time_s in enumerate(clip_start_times):
-        clip_start_time_ms = int(clip_start_time_s * 1000)
-        clip_duration_ms = int(clip_duration_s * 1000)
-        clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
-
-        # TODO(hayk): I don't think this is working properly
-        if i == len(clip_start_times) - 1:
-            silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
-            if silence_ms > 0:
-                clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
-
-        clip_segments.append(clip_segment)
-
-    if not prompt_input_a.prompt or not prompt_input_b.prompt:
-        st.info("Enter both prompts to interpolate between them")
-        return
-
-    if not submit_button:
-        return
-
-    params = SpectrogramParams()
-
-    # TODO(hayk): Make not linspace
-    alphas = list(np.linspace(0, 1, len(clip_segments)))
-    alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
-    st.write(f"**Alphas** : [{alphas_str}]")
-
-    result_images: T.List[Image.Image] = []
-    result_segments: T.List[pydub.AudioSegment] = []
-    for i, clip_segment in enumerate(clip_segments):
-        st.write(f"### Clip {i} at {clip_start_times[i]}s")
-
-        audio_bytes = io.BytesIO()
-        clip_segment.export(audio_bytes, format="wav")
-
-        init_image = streamlit_util.spectrogram_image_from_audio(
-            clip_segment,
-            params=params,
-            device=device,
-        )
-
-        # TODO(hayk): Roll this into spectrogram_image_from_audio?
-        # TODO(hayk): Scale something when computing audio
-        closest_width = int(np.ceil(init_image.width / 32) * 32)
-        closest_height = int(np.ceil(init_image.height / 32) * 32)
-        init_image_resized = init_image.resize((closest_width, closest_height), Image.BICUBIC)
-
-        # progress_callback = None
-        if show_clip_details:
-            left, right = st.columns(2)
-
-            left.write("##### Source Clip")
-            left.image(init_image, use_column_width=False)
-            left.audio(audio_bytes)
-
-            right.write("##### Riffed Clip")
-            empty_bin = right.empty()
-            with empty_bin.container():
-                st.info("Riffing...")
-                # progress = st.progress(0.0)
-                # progress_callback = progress.progress
-
-        inputs = InferenceInput(
-            alpha=float(alphas[i]),
-            num_inference_steps=num_inference_steps,
-            seed_image_id="og_beat",
-            start=prompt_input_a,
-            end=prompt_input_b,
-        )
-
-        image, audio_bytes = run_interpolation(
-            inputs=inputs,
-            init_image=init_image_resized,
-            device=device,
-        )
-
-        # Resize back to original size
-        image = image.resize(init_image.size, Image.BICUBIC)
-
-        result_images.append(image)
-
-        if show_clip_details:
-            empty_bin.empty()
-            right.image(image, use_column_width=False)
-
-        riffed_segment = streamlit_util.audio_segment_from_spectrogram_image(
-            image=image,
-            params=params,
-            device=device,
-        )
-        result_segments.append(riffed_segment)
-
-        if show_clip_details:
-            right.audio(audio_bytes)
-
-        if show_clip_details and show_difference:
-            diff_np = np.maximum(
-                0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32)
-            )
-            diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
-            diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
-                image=diff_image,
-                params=params,
-                device=device,
-            )
-
-            audio_bytes = io.BytesIO()
-            diff_segment.export(audio_bytes, format="wav")
-            st.audio(audio_bytes)
-
-    # Combine clips with a crossfade based on overlap
-    crossfade_ms = int(overlap_duration_s * 1000)
-    combined_segment = result_segments[0]
-    for segment in result_segments[1:]:
-        combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
-
-    audio_bytes = io.BytesIO()
-    combined_segment.export(audio_bytes, format="mp3")
-    st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
-    st.audio(audio_bytes, format="audio/mp3")
-
-
-if __name__ == "__main__":
-    render_audio_to_audio_interpolate()
diff --git a/riffusion/streamlit/pages/interpolation.py b/riffusion/streamlit/pages/interpolation.py
index f42ef62..9cde099 100644
--- a/riffusion/streamlit/pages/interpolation.py
+++ b/riffusion/streamlit/pages/interpolation.py
@@ -61,6 +61,12 @@ def render_interpolation() -> None:
         ),
     )
 
+    guidance = st.sidebar.number_input(
+        "Guidance",
+        value=7.0,
+        help="How much the model listens to the text prompt",
+    )
+
     init_image_name = st.sidebar.selectbox(
         "Seed image",
         # TODO(hayk): Read from directory
@@ -96,11 +102,11 @@ def render_interpolation() -> None:
 
         with left:
             st.write("##### Prompt A")
-            prompt_input_a = get_prompt_inputs(key="a")
+            prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
 
         with right:
             st.write("##### Prompt B")
-            prompt_input_b = get_prompt_inputs(key="b")
+            prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
 
         st.form_submit_button("Generate", type="primary")
 
@@ -108,11 +114,15 @@ def render_interpolation() -> None:
         st.info("Enter both prompts to interpolate between them")
         return
 
-    # TODO(hayk): Make not linspace
     alphas = list(np.linspace(0, 1, num_interpolation_steps))
     alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
     st.write(f"**Alphas** : [{alphas_str}]")
 
+    # TODO(hayk): Apply scaling to alphas like this
+    # T_shifted = T * 2 - 1
+    # T_sample = (np.abs(T_shifted)**t_scale_power * np.sign(T_shifted) + 1) / 2
+    # T_sample = T_sample * (t_end - t_start) + t_start
+
     if init_image_name == "custom":
         if not init_image_file:
             st.info("Upload a custom seed image")
@@ -171,36 +181,43 @@ def render_interpolation() -> None:
     st.audio(audio_bytes)
 
 
-def get_prompt_inputs(key: str) -> PromptInput:
+def get_prompt_inputs(
+    key: str,
+    include_negative_prompt: bool = False,
+    cols: bool = False,
+) -> T.Dict[str, T.Any]:
     """
     Compute prompt inputs from widgets.
     """
 
-    prompt = st.text_input("Prompt", label_visibility="collapsed", key=f"prompt_{key}")
-    seed = T.cast(
+    p: T.Dict[str, T.Any] = {}
+
+    # Optionally use columns
+    left, right = T.cast(T.Any, st.columns(2) if cols else (st, st))
+
+    visibility = "visible" if cols else "collapsed"
+    p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}")
+
+    if include_negative_prompt:
+        p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}")
+
+    p["seed"] = T.cast(
         int,
-        st.number_input(
+        left.number_input(
             "Seed",
             value=42,
             key=f"seed_{key}",
             help="Integer used to generate a random result. Vary this to explore alternatives.",
        ),
     )
-    denoising = st.number_input(
-        "Denoising", value=0.75, key=f"denoising_{key}", help="How much to modify the seed image"
-    )
-    guidance = st.number_input(
-        "Guidance",
-        value=7.0,
-        key=f"guidance_{key}",
-        help="How much the model listens to the text prompt",
+
+    p["denoising"] = right.number_input(
+        "Denoising",
+        value=0.5,
+        key=f"denoising_{key}",
+        help="How much to modify the seed image",
     )
-    return PromptInput(
-        prompt=prompt,
-        seed=seed,
-        denoising=denoising,
-        guidance=guidance,
-    )
+    return p
 
 
 @st.experimental_memo
diff --git a/riffusion/streamlit/pages/split_audio.py b/riffusion/streamlit/pages/split_audio.py
index dc5c235..a69e1b2 100644
--- a/riffusion/streamlit/pages/split_audio.py
+++ b/riffusion/streamlit/pages/split_audio.py
@@ -13,7 +13,7 @@ def render_split_audio() -> None:
     st.subheader(":scissors: Audio Splitter")
     st.write(
         """
-    Split an audio into stems of {vocals, drums, bass, other}.
+    Split an audio into stems of {vocals, drums, bass, piano, guitar, other}.
     """
     )
 
diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py
index dca6e8f..7d4edc7 100644
--- a/riffusion/streamlit/util.py
+++ b/riffusion/streamlit/util.py
@@ -194,8 +194,8 @@ def run_img2img(
     denoising_strength: float,
     num_inference_steps: int,
     guidance_scale: float,
-    negative_prompt: str,
     seed: int,
+    negative_prompt: T.Optional[str] = None,
     device: str = "cuda",
     progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
 ) -> Image.Image:
diff --git a/riffusion/util/audio_util.py b/riffusion/util/audio_util.py
index 858c0d1..aa66e7a 100644
--- a/riffusion/util/audio_util.py
+++ b/riffusion/util/audio_util.py
@@ -3,6 +3,7 @@ Audio utility functions.
 """
 import io
+import typing as T
 
 import numpy as np
 import pydub
 
@@ -69,3 +70,16 @@ def apply_filters(segment: pydub.AudioSegment, compression: bool = False) -> pyd
     )
 
     return segment
+
+
+def stitch_segments(
+    segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
+) -> pydub.AudioSegment:
+    """
+    Stitch together a sequence of audio segments with a crossfade between each segment.
+    """
+    crossfade_ms = int(crossfade_s * 1000)
+    combined_segment = segments[0]
+    for segment in segments[1:]:
+        combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
+    return combined_segment