Audio to audio improvements (some WIP)

Topic: audio_to_audio
Hayk Martiros 2023-01-05 04:44:15 +00:00
parent 503c5e4e84
commit 83b2792b27
1 changed file with 54 additions and 21 deletions


@@ -4,19 +4,12 @@ import typing as T
 import numpy as np
 import pydub
 import streamlit as st
-import torch
 from PIL import Image
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util

-@st.experimental_memo
-def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
-    return pydub.AudioSegment.from_file(audio_file)

 def render_audio_to_audio() -> None:
     st.set_page_config(layout="wide", page_icon="🎸")
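
Note: the page-local loader moves into the shared streamlit util module. Its new body is not shown in this diff, but given the removed lines above, the shared helper is presumably the same memoized one-liner (a sketch, assuming it kept the name load_audio_file):

import io

import pydub
import streamlit as st

# Memoize on the uploaded bytes so re-runs of the page don't re-decode the file.
@st.experimental_memo
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
    return pydub.AudioSegment.from_file(audio_file)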
@@ -31,7 +24,7 @@ def render_audio_to_audio() -> None:
     audio_file = st.file_uploader(
         "Upload audio",
-        type=["mp3", "ogg", "wav", "flac"],
+        type=["mp3", "m4a", "ogg", "wav", "flac"],
         label_visibility="collapsed",
     )
@@ -39,10 +32,14 @@ def render_audio_to_audio() -> None:
         st.info("Upload audio to get started")
         return

-    st.write("#### Original Audio")
+    st.write("#### Original")
     st.audio(audio_file)

-    segment = load_audio_file(audio_file)
+    segment = streamlit_util.load_audio_file(audio_file)
+
+    # TODO(hayk): Fix
+    segment = segment.set_frame_rate(44100)
+    st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")

     if "counter" not in st.session_state:
         st.session_state.counter = 0
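
Note: pydub's set_frame_rate actually resamples the underlying samples (it is not just a header relabel), which is why the TODO-marked forced 44100 Hz is safe to feed into the spectrogram code. A standalone sketch (the file path is a placeholder):

import pydub

segment = pydub.AudioSegment.from_file("input.mp3")  # placeholder path
segment = segment.set_frame_rate(44100)  # resamples if the source rate differs
print(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")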
@@ -59,7 +56,6 @@ def render_audio_to_audio() -> None:
     duration_s = cols[1].number_input(
         "Duration [s]",
         min_value=0.0,
-        max_value=segment.duration_seconds,
         value=15.0,
     )
     clip_duration_s = cols[2].number_input(
@@ -75,12 +71,14 @@ def render_audio_to_audio() -> None:
         value=0.2,
     )

+    duration_s = min(duration_s, segment.duration_seconds - start_time_s)
     increment_s = clip_duration_s - overlap_duration_s
     clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
     st.write(
         f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
         f"with overlap {overlap_duration_s}s."
     )
+    st.write(clip_start_times)

     with st.form("Conversion Params"):
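
Note on the slicing arithmetic: clips start every clip_duration_s - overlap_duration_s seconds, and the new clamp keeps the requested duration inside the uploaded audio. A quick sketch with illustrative values (clip_duration_s=5.0 is assumed; that widget's default is not shown in this diff):

import numpy as np

start_time_s = 0.0
duration_s = 15.0
clip_duration_s = 5.0  # assumed for illustration
overlap_duration_s = 0.2

increment_s = clip_duration_s - overlap_duration_s  # 4.8s stride between clip starts
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
print(clip_start_times)  # [0.  4.8 9.6] -> 3 clips, adjacent pairs sharing 0.2s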
@@ -92,7 +90,7 @@ def render_audio_to_audio() -> None:
             "Denoising Strength",
             min_value=0.0,
             max_value=1.0,
-            value=0.65,
+            value=0.45,
         )
         guidance_scale = cols[1].number_input(
             "Guidance Scale",
@@ -108,27 +106,37 @@ def render_audio_to_audio() -> None:
                 value=50,
             )
         )
         seed = int(
             cols[3].number_input(
                 "Seed",
-                min_value=-1,
-                value=-1,
+                min_value=0,
+                value=42,
             )
         )
+        # TODO replace seed -1 with random

         submit_button = st.form_submit_button("Convert", on_click=increment_counter)

     # TODO fix
     show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
+    show_difference = st.sidebar.checkbox("Show Difference", False)

     clip_segments: T.List[pydub.AudioSegment] = []
-    for clip_start_time_s in clip_start_times:
+    for i, clip_start_time_s in enumerate(clip_start_times):
         clip_start_time_ms = int(clip_start_time_s * 1000)
         clip_duration_ms = int(clip_duration_s * 1000)
         clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
+
+        if i == len(clip_start_times) - 1:
+            silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
+            st.write(f"Last clip: {clip_duration_ms=}ms")
+            st.write(f"Last clip: {clip_start_time_ms=}ms")
+            st.write(f"Last clip: {clip_segment.duration_seconds=:.2f}s")
+            st.write(f"Last clip: {silence_ms=}ms")
+            if silence_ms > 0:
+                clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
+
         clip_segments.append(clip_segment)

     if not prompt:
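
Note: the last clip is the only one that can run off the end of the source audio, so it is padded back to full length with silence before spectrogram conversion. A standalone sketch of the padding (durations are illustrative; pydub's append crossfades 100ms by default, so the padded clip lands slightly short of the target unless crossfade=0 is passed, while the diff uses the default):

import pydub

clip_duration_ms = 5000
clip_segment = pydub.AudioSegment.silent(duration=3200)  # stand-in for a short tail clip

silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
if silence_ms > 0:
    # crossfade=0 keeps the length exact.
    clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms), crossfade=0)
print(len(clip_segment))  # 5000 ms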
@@ -154,6 +162,13 @@ def render_audio_to_audio() -> None:
             device=device,
         )

+        # TODO(hayk): Roll this into spectrogram_image_from_audio?
+        # TODO(hayk): Scale something when computing audio
+        closest_width = int(np.ceil(init_image.width / 32) * 32)
+        closest_height = int(np.ceil(init_image.height / 32) * 32)
+        init_image = init_image.resize((closest_width, closest_height), Image.BICUBIC)
+
+        progress_callback = None
         if show_clip_details:
             left, right = st.columns(2)
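
Note: Stable Diffusion's UNet downsamples its input several times, so img2img wants dimensions divisible by the model's total downsampling factor; rounding up to the next multiple of 32 guarantees that without cropping spectrogram content. The rounding in isolation:

import numpy as np
from PIL import Image

image = Image.new("RGB", (612, 500))  # stand-in for a spectrogram image

closest_width = int(np.ceil(image.width / 32) * 32)    # 612 -> 640
closest_height = int(np.ceil(image.height / 32) * 32)  # 500 -> 512
image = image.resize((closest_width, closest_height), Image.BICUBIC)
print(image.size)  # (640, 512)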
@@ -166,6 +181,7 @@ def render_audio_to_audio() -> None:
         with empty_bin.container():
             st.info("Riffing...")
             progress = st.progress(0.0)
+            progress_callback = progress.progress

         image = streamlit_util.run_img2img(
             prompt=prompt,
@@ -175,10 +191,11 @@ def render_audio_to_audio() -> None:
             guidance_scale=guidance_scale,
             negative_prompt=negative_prompt,
             seed=seed,
-            progress_callback=progress.progress,
+            progress_callback=progress_callback,
             device=device,
         )

+        st.write(init_image.size)
+        st.write(image.size)
         result_images.append(image)

         if show_clip_details:
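
Note: the img2img call now receives the callback through a variable that defaults to None and is only bound to the st.progress bar when one exists, so the generation code can treat progress reporting as optional. The shape of the pattern in miniature (the callable signature matches how it is used here; the rest is illustrative):

import typing as T

def run(progress_callback: T.Optional[T.Callable[[float], T.Any]] = None) -> None:
    for step in range(5):
        if progress_callback is not None:
            progress_callback((step + 1) / 5)  # report fraction complete

run()                                   # no progress UI, still works
run(lambda frac: print(f"{frac:.0%}"))  # prints 20% ... 100%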
@@ -191,13 +208,29 @@ def render_audio_to_audio() -> None:
             device=device,
         )
         result_segments.append(riffed_segment)
+        st.write(clip_segment.duration_seconds)
+        st.write(riffed_segment.duration_seconds)

         audio_bytes = io.BytesIO()
         riffed_segment.export(audio_bytes, format="wav")

         if show_clip_details:
             right.audio(audio_bytes)

+        if show_clip_details and show_difference:
+            diff_np = np.maximum(0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32))
+            st.write(diff_np.shape)
+            diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
+            st.image(diff_image)
+
+            diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
+                image=diff_image,
+                params=params,
+                device=device,
+            )
+            audio_bytes = io.BytesIO()
+            diff_segment.export(audio_bytes, format="wav")
+            st.audio(audio_bytes)

     # Combine clips with a crossfade based on overlap
     crossfade_ms = int(overlap_duration_s * 1000)
     combined_segment = result_segments[0]
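
Note on the final combine (the append loop itself presumably sits in unchanged lines after this hunk): because consecutive clips share overlap_duration_s of audio, appending with crossfade equal to the overlap keeps the seams smooth and the total length equal to the non-overlapped sum. A minimal sketch:

import pydub

overlap_duration_s = 0.2
crossfade_ms = int(overlap_duration_s * 1000)

# Stand-ins for the riffed clips; adjacent clips share crossfade_ms of audio.
result_segments = [pydub.AudioSegment.silent(duration=5000) for _ in range(3)]

combined_segment = result_segments[0]
for segment in result_segments[1:]:
    combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)

print(len(combined_segment))  # 3 * 5000 - 2 * 200 = 14600 ms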
@@ -207,7 +240,7 @@ def render_audio_to_audio() -> None:
     audio_bytes = io.BytesIO()
     combined_segment.export(audio_bytes, format="mp3")
     st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
-    st.audio(audio_bytes)
+    st.audio(audio_bytes, format="audio/mp3")

 @st.cache
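
Note: st.audio defaults to format="audio/wav", so the explicit "audio/mp3" mime type is needed for the browser to decode the MP3 export correctly. In isolation:

import io

import pydub
import streamlit as st

segment = pydub.AudioSegment.silent(duration=1000)  # stand-in for the combined result

audio_bytes = io.BytesIO()
segment.export(audio_bytes, format="mp3")  # requires ffmpeg on the PATH
st.audio(audio_bytes, format="audio/mp3")  # default format is "audio/wav"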