Audio to audio improvements (some WIP)

Topic: audio_to_audio
Hayk Martiros 2023-01-05 04:44:15 +00:00
parent 503c5e4e84
commit 83b2792b27
1 changed file with 54 additions and 21 deletions


@@ -4,19 +4,12 @@ import typing as T
import numpy as np
import pydub
import streamlit as st
import torch
from PIL import Image
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
-@st.experimental_memo
-def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
-return pydub.AudioSegment.from_file(audio_file)
def render_audio_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")
@@ -31,7 +24,7 @@ def render_audio_to_audio() -> None:
audio_file = st.file_uploader(
"Upload audio",
type=["mp3", "ogg", "wav", "flac"],
type=["mp3", "m4a", "ogg", "wav", "flac"],
label_visibility="collapsed",
)
@@ -39,10 +32,14 @@ def render_audio_to_audio() -> None:
st.info("Upload audio to get started")
return
st.write("#### Original Audio")
st.write("#### Original")
st.audio(audio_file)
-segment = load_audio_file(audio_file)
+segment = streamlit_util.load_audio_file(audio_file)
+# TODO(hayk): Fix
+segment = segment.set_frame_rate(44100)
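+# Resampled to 44.1 kHz to match the sample rate SpectrogramParams assumes.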
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
if "counter" not in st.session_state:
st.session_state.counter = 0
@@ -59,7 +56,6 @@ def render_audio_to_audio() -> None:
duration_s = cols[1].number_input(
"Duration [s]",
min_value=0.0,
-max_value=segment.duration_seconds,
value=15.0,
)
clip_duration_s = cols[2].number_input(
@@ -75,12 +71,14 @@ def render_audio_to_audio() -> None:
value=0.2,
)
+duration_s = min(duration_s, segment.duration_seconds - start_time_s)
increment_s = clip_duration_s - overlap_duration_s
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
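# Starts advance by increment_s, e.g. 5.0 s clips with 0.2 s overlap begin every
# 4.8 s: 0.0, 4.8, 9.6 for a 15 s selection (offsets relative to start_time_s).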
st.write(
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
f"with overlap {overlap_duration_s}s."
)
+st.write(clip_start_times)
with st.form("Conversion Params"):
@@ -92,7 +90,7 @@ def render_audio_to_audio() -> None:
"Denoising Strength",
min_value=0.0,
max_value=1.0,
-value=0.65,
+value=0.45,
)
guidance_scale = cols[1].number_input(
"Guidance Scale",
@@ -108,27 +106,37 @@ def render_audio_to_audio() -> None:
value=50,
)
)
seed = int(
cols[3].number_input(
"Seed",
-min_value=-1,
-value=-1,
+min_value=0,
+value=42,
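+# A fixed non-negative seed makes runs reproducible; -1 previously meant "random".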
)
)
# TODO replace seed -1 with random
submit_button = st.form_submit_button("Convert", on_click=increment_counter)
# TODO fix
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
+show_difference = st.sidebar.checkbox("Show Difference", False)
clip_segments: T.List[pydub.AudioSegment] = []
-for clip_start_time_s in clip_start_times:
+for i, clip_start_time_s in enumerate(clip_start_times):
clip_start_time_ms = int(clip_start_time_s * 1000)
clip_duration_ms = int(clip_duration_s * 1000)
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
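+# A pydub slice past the end of the segment is silently truncated, so the
+# final clip can come up short; pad it with silence to keep all clips equal length.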
+if i == len(clip_start_times) - 1:
+silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
+st.write(f"Last clip: {clip_duration_ms=}ms")
+st.write(f"Last clip: {clip_start_time_ms=}ms")
+st.write(f"Last clip: {clip_segment.duration_seconds=:.2f}s")
+st.write(f"Last clip: {silence_ms=}ms")
+if silence_ms > 0:
+clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms), crossfade=0)
clip_segments.append(clip_segment)
if not prompt:
@@ -154,6 +162,13 @@ def render_audio_to_audio() -> None:
device=device,
)
+# TODO(hayk): Roll this into spectrogram_image_from_audio?
+# TODO(hayk): Scale something when computing audio
+closest_width = int(np.ceil(init_image.width / 32) * 32)
+closest_height = int(np.ceil(init_image.height / 32) * 32)
+init_image = init_image.resize((closest_width, closest_height), Image.BICUBIC)
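+# Rounding up to multiples of 32 keeps the dimensions compatible with the
+# diffusion UNet's repeated downsampling.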
+progress_callback = None
if show_clip_details:
left, right = st.columns(2)
@@ -166,6 +181,7 @@ def render_audio_to_audio() -> None:
with empty_bin.container():
st.info("Riffing...")
progress = st.progress(0.0)
+progress_callback = progress.progress
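+# The bar's .progress(fraction) method doubles as the per-step callback;
+# it stays None when clip details are hidden.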
image = streamlit_util.run_img2img(
prompt=prompt,
@@ -175,10 +191,11 @@ def render_audio_to_audio() -> None:
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
seed=seed,
-progress_callback=progress.progress,
+progress_callback=progress_callback,
device=device,
)
+st.write(init_image.size)
+st.write(image.size)
result_images.append(image)
if show_clip_details:
@@ -191,13 +208,29 @@ def render_audio_to_audio() -> None:
device=device,
)
result_segments.append(riffed_segment)
+st.write(clip_segment.duration_seconds)
+st.write(riffed_segment.duration_seconds)
audio_bytes = io.BytesIO()
riffed_segment.export(audio_bytes, format="wav")
if show_clip_details:
right.audio(audio_bytes)
+if show_clip_details and show_difference:
+diff_np = np.maximum(0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32))
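+# np.maximum clamps at zero, keeping only the spectrogram energy that
+# img2img removed from the original; the image is inverted below before sonifying.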
+st.write(diff_np.shape)
+diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
+st.image(diff_image)
+diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
+image=diff_image,
+params=params,
+device=device,
+)
+audio_bytes = io.BytesIO()
+diff_segment.export(audio_bytes, format="wav")
+st.audio(audio_bytes)
# Combine clips with a crossfade based on overlap
crossfade_ms = int(overlap_duration_s * 1000)
combined_segment = result_segments[0]
@@ -207,7 +240,7 @@ def render_audio_to_audio() -> None:
audio_bytes = io.BytesIO()
combined_segment.export(audio_bytes, format="mp3")
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
-st.audio(audio_bytes)
+st.audio(audio_bytes, format="audio/mp3")
@st.cache