# riffusion-inference/riffusion/streamlit/pages/audio_to_audio_interpolate.py
#
# 255 lines
# 7.8 KiB
# Python

import io
import typing as T
import numpy as np
import pydub
import streamlit as st
from PIL import Image
from riffusion.datatypes import InferenceInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
def render_audio_to_audio_interpolate() -> None:
    """
    Render the "audio to audio interpolation" Streamlit page.

    Flow:
      1. The user uploads an audio file and picks a time window, a clip duration and an overlap.
      2. The window is sliced into equal-length clips that start every
         ``clip_duration_s - overlap_duration_s`` seconds.
      3. On submit, each clip is turned into a spectrogram image and passed to
         ``run_interpolation`` with an ``alpha`` that steps linearly from 0 (prompt A, first
         clip) to 1 (prompt B, last clip).
      4. The riffed clips are concatenated with a crossfade equal to the overlap and shown as a
         single mp3.

    The function returns early (rendering only part of the page) when no file is uploaded, the
    clip settings produce no clips, either prompt is empty, or the form has not been submitted.
    """
    st.set_page_config(layout="wide", page_icon="🎸")

    st.subheader(":wave: Audio to Audio Interpolation")
    st.write(
        """
Audio to audio with interpolation.
"""
    )

    with st.expander("Help", False):
        st.write(
            """
TODO
"""
        )

    device = streamlit_util.select_device(st.sidebar)

    num_inference_steps = T.cast(
        int,
        st.sidebar.number_input(
            "Steps per sample", value=50, help="Number of denoising steps per model run"
        ),
    )

    audio_file = st.file_uploader(
        "Upload audio",
        type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
        label_visibility="collapsed",
    )

    if not audio_file:
        st.info("Upload audio to get started")
        return

    st.write("#### Original")
    st.audio(audio_file)

    segment = streamlit_util.load_audio_file(audio_file)

    # TODO(hayk): Fix
    segment = segment.set_frame_rate(44100)
    st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")

    # NOTE(review): this page never reads or updates the counter; the initialization is kept
    # unchanged in case other code relies on it existing in session state.
    if "counter" not in st.session_state:
        st.session_state.counter = 0

    cols = st.columns(4)
    start_time_s = cols[0].number_input(
        "Start Time [s]",
        min_value=0.0,
        value=0.0,
    )
    duration_s = cols[1].number_input(
        "Duration [s]",
        min_value=0.0,
        value=15.0,
    )
    clip_duration_s = cols[2].number_input(
        "Clip Duration [s]",
        min_value=3.0,
        max_value=10.0,
        value=5.0,
    )
    overlap_duration_s = cols[3].number_input(
        "Overlap Duration [s]",
        min_value=0.0,
        max_value=10.0,
        value=0.2,
    )

    # Clamp the window so it does not run past the end of the uploaded audio.
    duration_s = min(duration_s, segment.duration_seconds - start_time_s)
    increment_s = clip_duration_s - overlap_duration_s

    # np.arange raises on a zero step and yields nothing for a negative one, so reject an overlap
    # that is not strictly shorter than the clip.
    if increment_s <= 0:
        st.warning("Overlap duration must be shorter than the clip duration.")
        return

    # np.arange excludes its stop value, so every clip ends inside the selected window.
    clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)

    # Without this guard, an empty slice list crashes later at result_segments[0].
    if len(clip_start_times) == 0:
        st.warning(
            "The selected window is too short: it must be longer than the clip duration. "
            "Pick an earlier start time, a longer duration, or a shorter clip duration."
        )
        return

    st.write(
        f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
        f"with overlap {overlap_duration_s}s."
    )

    with st.expander("Clip Times"):
        st.dataframe(
            {
                "Start Time [s]": clip_start_times,
                "End Time [s]": clip_start_times + clip_duration_s,
                "Duration [s]": clip_duration_s,
            }
        )

    with st.form(key="interpolation_form"):
        left, right = st.columns(2)

        with left:
            st.write("##### Prompt A")
            prompt_input_a = get_prompt_inputs(key="a")

        with right:
            st.write("##### Prompt B")
            prompt_input_b = get_prompt_inputs(key="b")

        submit_button = st.form_submit_button("Generate", type="primary")

    show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
    show_difference = st.sidebar.checkbox("Show Difference", False)

    clip_duration_ms = int(clip_duration_s * 1000)
    clip_segments: T.List[pydub.AudioSegment] = []
    for i, clip_start_time_s in enumerate(clip_start_times):
        clip_start_time_ms = int(clip_start_time_s * 1000)
        clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]

        # Pad the final clip with silence so it is as long as the others. crossfade=0 is
        # required: pydub's append() crossfades 100 ms by default, which would shorten the pad,
        # fade out the clip's tail, and raise ValueError for pads shorter than 100 ms.
        if i == len(clip_start_times) - 1:
            silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
            if silence_ms > 0:
                clip_segment = clip_segment.append(
                    pydub.AudioSegment.silent(duration=silence_ms), crossfade=0
                )

        clip_segments.append(clip_segment)

    if not prompt_input_a.prompt or not prompt_input_b.prompt:
        st.info("Enter both prompts to interpolate between them")
        return

    if not submit_button:
        return

    params = SpectrogramParams()

    # TODO(hayk): Make not linspace
    alphas = list(np.linspace(0, 1, len(clip_segments)))
    alphas_str = ", ".join(f"{alpha:.2f}" for alpha in alphas)
    st.write(f"**Alphas** : [{alphas_str}]")

    result_segments: T.List[pydub.AudioSegment] = []
    for i, clip_segment in enumerate(clip_segments):
        st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")

        clip_audio_bytes = io.BytesIO()
        clip_segment.export(clip_audio_bytes, format="wav")

        init_image = streamlit_util.spectrogram_image_from_audio(
            clip_segment,
            params=params,
            device=device,
        )

        # TODO(hayk): Roll this into spectrogram_image_from_audio?
        # TODO(hayk): Scale something when computing audio
        # Round both dimensions up to a multiple of 32 for the model input; the output is
        # resized back to the original size below.
        closest_width = int(np.ceil(init_image.width / 32) * 32)
        closest_height = int(np.ceil(init_image.height / 32) * 32)
        init_image_resized = init_image.resize((closest_width, closest_height), Image.BICUBIC)

        if show_clip_details:
            left, right = st.columns(2)

            left.write("##### Source Clip")
            left.image(init_image, use_column_width=False)
            left.audio(clip_audio_bytes)

            right.write("##### Riffed Clip")
            # Placeholder that is cleared once the riffed image is ready.
            empty_bin = right.empty()
            with empty_bin.container():
                st.info("Riffing...")

        inputs = InferenceInput(
            alpha=float(alphas[i]),
            num_inference_steps=num_inference_steps,
            seed_image_id="og_beat",
            start=prompt_input_a,
            end=prompt_input_b,
        )

        image, riffed_audio_bytes = run_interpolation(
            inputs=inputs,
            init_image=init_image_resized,
            device=device,
        )

        # Resize back to original size
        image = image.resize(init_image.size, Image.BICUBIC)

        if show_clip_details:
            empty_bin.empty()
            right.image(image, use_column_width=False)

        riffed_segment = streamlit_util.audio_segment_from_spectrogram_image(
            image=image,
            params=params,
            device=device,
        )
        result_segments.append(riffed_segment)

        if show_clip_details:
            right.audio(riffed_audio_bytes)

        if show_clip_details and show_difference:
            # Keep only where the source spectrogram is brighter than the riffed one, then invert.
            diff_np = np.maximum(
                0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32)
            )
            diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
            diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
                image=diff_image,
                params=params,
                device=device,
            )

            diff_audio_bytes = io.BytesIO()
            diff_segment.export(diff_audio_bytes, format="wav")
            st.audio(diff_audio_bytes)

    # Combine clips with a crossfade based on overlap. The crossfade is shorter than one clip
    # because the increment_s > 0 guard above ensures overlap < clip duration.
    crossfade_ms = int(overlap_duration_s * 1000)
    combined_segment = result_segments[0]
    for riffed in result_segments[1:]:
        combined_segment = combined_segment.append(riffed, crossfade=crossfade_ms)

    final_audio_bytes = io.BytesIO()
    combined_segment.export(final_audio_bytes, format="mp3")
    st.write(f"#### Final Audio ({combined_segment.duration_seconds:.2f}s)")
    st.audio(final_audio_bytes, format="audio/mp3")
@st.cache
def test(segment: pydub.AudioSegment, counter: int) -> int:
    """
    Debug helper: show a "Trimmed" heading with the segment's length, then echo ``counter``.

    Args:
        segment: Audio whose ``duration_seconds`` is written to the page.
        counter: Value returned unchanged.

    Returns:
        ``counter``.

    NOTE(review): the function is wrapped in ``st.cache`` yet calls ``st.write``; presumably the
    writes only happen on a cache miss — confirm this is intended. Nothing in this listing
    calls it.
    """
    seconds = segment.duration_seconds
    st.write("#### Trimmed")
    st.write(seconds)
    return counter
if __name__ == "__main__":
    # Render the page when this module is executed directly as a script.
    render_audio_to_audio_interpolate()