Audio to audio handles interpolation within it
Kill the separate page. Topic: audio_to_audio_interpolation
This commit is contained in:
parent
40bf61e949
commit
8b07a5a45f
|
@ -27,6 +27,7 @@ jobs:
|
||||||
|
|
||||||
- name: Install system packages
|
- name: Install system packages
|
||||||
run: |
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
sudo apt-get install -y ffmpeg libsndfile1
|
sudo apt-get install -y ffmpeg libsndfile1
|
||||||
|
|
||||||
- name: Install pip packages from requirements.txt
|
- name: Install pip packages from requirements.txt
|
||||||
|
|
|
@ -12,6 +12,9 @@ __pycache__/
|
||||||
# Cog
|
# Cog
|
||||||
.cog/
|
.cog/
|
||||||
|
|
||||||
|
# Random stuff I don't care about
|
||||||
|
.graveyard/
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
.Python
|
.Python
|
||||||
build/
|
build/
|
||||||
|
|
|
@ -19,6 +19,9 @@ class PromptInput:
|
||||||
# Random seed for denoising
|
# Random seed for denoising
|
||||||
seed: int
|
seed: int
|
||||||
|
|
||||||
|
# Negative prompt to avoid (optional)
|
||||||
|
negative_prompt: T.Optional[str] = None
|
||||||
|
|
||||||
# Denoising strength
|
# Denoising strength
|
||||||
denoising: float = 0.75
|
denoising: float = 0.75
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,11 @@ import pydub
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from riffusion.datatypes import InferenceInput, PromptInput
|
||||||
from riffusion.spectrogram_params import SpectrogramParams
|
from riffusion.spectrogram_params import SpectrogramParams
|
||||||
from riffusion.streamlit import util as streamlit_util
|
from riffusion.streamlit import util as streamlit_util
|
||||||
|
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
|
||||||
|
from riffusion.util import audio_util
|
||||||
|
|
||||||
|
|
||||||
def render_audio_to_audio() -> None:
|
def render_audio_to_audio() -> None:
|
||||||
|
@ -37,6 +40,19 @@ def render_audio_to_audio() -> None:
|
||||||
|
|
||||||
device = streamlit_util.select_device(st.sidebar)
|
device = streamlit_util.select_device(st.sidebar)
|
||||||
|
|
||||||
|
num_inference_steps = T.cast(
|
||||||
|
int,
|
||||||
|
st.sidebar.number_input(
|
||||||
|
"Steps per sample", value=50, help="Number of denoising steps per model run"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
guidance = st.sidebar.number_input(
|
||||||
|
"Guidance",
|
||||||
|
value=7.0,
|
||||||
|
help="How much the model listens to the text prompt",
|
||||||
|
)
|
||||||
|
|
||||||
audio_file = st.file_uploader(
|
audio_file = st.file_uploader(
|
||||||
"Upload audio",
|
"Upload audio",
|
||||||
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
||||||
|
@ -53,113 +69,58 @@ def render_audio_to_audio() -> None:
|
||||||
segment = streamlit_util.load_audio_file(audio_file)
|
segment = streamlit_util.load_audio_file(audio_file)
|
||||||
|
|
||||||
# TODO(hayk): Fix
|
# TODO(hayk): Fix
|
||||||
segment = segment.set_frame_rate(44100)
|
if segment.frame_rate != 44100:
|
||||||
|
st.warning("Audio must be 44100Hz. Converting")
|
||||||
|
segment = segment.set_frame_rate(44100)
|
||||||
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
|
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
|
||||||
|
|
||||||
if "counter" not in st.session_state:
|
clip_p = get_clip_params()
|
||||||
st.session_state.counter = 0
|
start_time_s = clip_p["start_time_s"]
|
||||||
|
clip_duration_s = clip_p["clip_duration_s"]
|
||||||
|
overlap_duration_s = clip_p["overlap_duration_s"]
|
||||||
|
|
||||||
def increment_counter():
|
duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s)
|
||||||
st.session_state.counter += 1
|
|
||||||
|
|
||||||
cols = st.columns(4)
|
|
||||||
start_time_s = cols[0].number_input(
|
|
||||||
"Start Time [s]",
|
|
||||||
min_value=0.0,
|
|
||||||
value=0.0,
|
|
||||||
)
|
|
||||||
duration_s = cols[1].number_input(
|
|
||||||
"Duration [s]",
|
|
||||||
min_value=0.0,
|
|
||||||
value=15.0,
|
|
||||||
)
|
|
||||||
clip_duration_s = cols[2].number_input(
|
|
||||||
"Clip Duration [s]",
|
|
||||||
min_value=3.0,
|
|
||||||
max_value=10.0,
|
|
||||||
value=5.0,
|
|
||||||
)
|
|
||||||
overlap_duration_s = cols[3].number_input(
|
|
||||||
"Overlap Duration [s]",
|
|
||||||
min_value=0.0,
|
|
||||||
max_value=10.0,
|
|
||||||
value=0.2,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration_s = min(duration_s, segment.duration_seconds - start_time_s)
|
|
||||||
increment_s = clip_duration_s - overlap_duration_s
|
increment_s = clip_duration_s - overlap_duration_s
|
||||||
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
|
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
|
||||||
st.write(
|
|
||||||
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
|
write_clip_details(
|
||||||
f"with overlap {overlap_duration_s}s."
|
clip_start_times=clip_start_times,
|
||||||
|
clip_duration_s=clip_duration_s,
|
||||||
|
overlap_duration_s=overlap_duration_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
with st.expander("Clip Times"):
|
interpolate = st.checkbox("Interpolate between two settings", False)
|
||||||
st.dataframe(
|
|
||||||
{
|
|
||||||
"Start Time [s]": clip_start_times,
|
|
||||||
"End Time [s]": clip_start_times + clip_duration_s,
|
|
||||||
"Duration [s]": clip_duration_s,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with st.form("Conversion Params"):
|
with st.form("audio to audio form"):
|
||||||
|
if interpolate:
|
||||||
|
left, right = st.columns(2)
|
||||||
|
|
||||||
prompt = st.text_input("Text Prompt")
|
with left:
|
||||||
negative_prompt = st.text_input("Negative Prompt")
|
st.write("##### Prompt A")
|
||||||
|
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
|
||||||
|
|
||||||
cols = st.columns(4)
|
with right:
|
||||||
denoising_strength = cols[0].number_input(
|
st.write("##### Prompt B")
|
||||||
"Denoising Strength",
|
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
|
||||||
min_value=0.0,
|
|
||||||
max_value=1.0,
|
else:
|
||||||
value=0.45,
|
prompt_input_a = PromptInput(
|
||||||
)
|
guidance=guidance,
|
||||||
guidance_scale = cols[1].number_input(
|
**get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
|
||||||
"Guidance Scale",
|
|
||||||
min_value=0.0,
|
|
||||||
max_value=20.0,
|
|
||||||
value=7.0,
|
|
||||||
)
|
|
||||||
num_inference_steps = int(
|
|
||||||
cols[2].number_input(
|
|
||||||
"Num Inference Steps",
|
|
||||||
min_value=1,
|
|
||||||
max_value=150,
|
|
||||||
value=50,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
seed = int(
|
submit_button = st.form_submit_button("Riff", type="primary")
|
||||||
cols[3].number_input(
|
|
||||||
"Seed",
|
|
||||||
min_value=0,
|
|
||||||
value=42,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
submit_button = st.form_submit_button("Convert", on_click=increment_counter)
|
|
||||||
|
|
||||||
# TODO fix
|
|
||||||
|
|
||||||
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
|
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
|
||||||
show_difference = st.sidebar.checkbox("Show Difference", False)
|
show_difference = st.sidebar.checkbox("Show Difference", False)
|
||||||
|
|
||||||
clip_segments: T.List[pydub.AudioSegment] = []
|
clip_segments = slice_audio_into_clips(
|
||||||
for i, clip_start_time_s in enumerate(clip_start_times):
|
segment=segment,
|
||||||
clip_start_time_ms = int(clip_start_time_s * 1000)
|
clip_start_times=clip_start_times,
|
||||||
clip_duration_ms = int(clip_duration_s * 1000)
|
clip_duration_s=clip_duration_s,
|
||||||
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
|
)
|
||||||
|
|
||||||
# TODO(hayk): I don't think this is working properly
|
if not prompt_input_a.prompt:
|
||||||
if i == len(clip_start_times) - 1:
|
|
||||||
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
|
|
||||||
if silence_ms > 0:
|
|
||||||
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
|
|
||||||
|
|
||||||
clip_segments.append(clip_segment)
|
|
||||||
|
|
||||||
if not prompt:
|
|
||||||
st.info("Enter a prompt")
|
st.info("Enter a prompt")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -168,10 +129,16 @@ def render_audio_to_audio() -> None:
|
||||||
|
|
||||||
params = SpectrogramParams()
|
params = SpectrogramParams()
|
||||||
|
|
||||||
|
if interpolate:
|
||||||
|
# TODO(hayk): Make not linspace
|
||||||
|
alphas = list(np.linspace(0, 1, len(clip_segments)))
|
||||||
|
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
||||||
|
st.write(f"**Alphas** : [{alphas_str}]")
|
||||||
|
|
||||||
result_images: T.List[Image.Image] = []
|
result_images: T.List[Image.Image] = []
|
||||||
result_segments: T.List[pydub.AudioSegment] = []
|
result_segments: T.List[pydub.AudioSegment] = []
|
||||||
for i, clip_segment in enumerate(clip_segments):
|
for i, clip_segment in enumerate(clip_segments):
|
||||||
st.write(f"### Clip {i} at {clip_start_times[i]}s")
|
st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")
|
||||||
|
|
||||||
audio_bytes = io.BytesIO()
|
audio_bytes = io.BytesIO()
|
||||||
clip_segment.export(audio_bytes, format="wav")
|
clip_segment.export(audio_bytes, format="wav")
|
||||||
|
@ -183,10 +150,7 @@ def render_audio_to_audio() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(hayk): Roll this into spectrogram_image_from_audio?
|
# TODO(hayk): Roll this into spectrogram_image_from_audio?
|
||||||
# TODO(hayk): Scale something when computing audio
|
init_image_resized = scale_image_to_32_stride(init_image)
|
||||||
closest_width = int(np.ceil(init_image.width / 32) * 32)
|
|
||||||
closest_height = int(np.ceil(init_image.height / 32) * 32)
|
|
||||||
init_image_resized = init_image.resize((closest_width, closest_height), Image.BICUBIC)
|
|
||||||
|
|
||||||
progress_callback = None
|
progress_callback = None
|
||||||
if show_clip_details:
|
if show_clip_details:
|
||||||
|
@ -203,17 +167,32 @@ def render_audio_to_audio() -> None:
|
||||||
progress = st.progress(0.0)
|
progress = st.progress(0.0)
|
||||||
progress_callback = progress.progress
|
progress_callback = progress.progress
|
||||||
|
|
||||||
image = streamlit_util.run_img2img(
|
if interpolate:
|
||||||
prompt=prompt,
|
inputs = InferenceInput(
|
||||||
init_image=init_image_resized,
|
alpha=float(alphas[i]),
|
||||||
denoising_strength=denoising_strength,
|
num_inference_steps=num_inference_steps,
|
||||||
num_inference_steps=num_inference_steps,
|
seed_image_id="og_beat",
|
||||||
guidance_scale=guidance_scale,
|
start=prompt_input_a,
|
||||||
negative_prompt=negative_prompt,
|
end=prompt_input_b,
|
||||||
seed=seed,
|
)
|
||||||
progress_callback=progress_callback,
|
|
||||||
device=device,
|
image, audio_bytes = run_interpolation(
|
||||||
)
|
inputs=inputs,
|
||||||
|
init_image=init_image_resized,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image = streamlit_util.run_img2img(
|
||||||
|
prompt=prompt_input_a.prompt,
|
||||||
|
init_image=init_image_resized,
|
||||||
|
denoising_strength=prompt_input_a.denoising,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
guidance_scale=guidance,
|
||||||
|
negative_prompt=prompt_input_a.negative_prompt,
|
||||||
|
seed=prompt_input_a.seed,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
# Resize back to original size
|
# Resize back to original size
|
||||||
image = image.resize(init_image.size, Image.BICUBIC)
|
image = image.resize(init_image.size, Image.BICUBIC)
|
||||||
|
@ -253,10 +232,7 @@ def render_audio_to_audio() -> None:
|
||||||
st.audio(audio_bytes)
|
st.audio(audio_bytes)
|
||||||
|
|
||||||
# Combine clips with a crossfade based on overlap
|
# Combine clips with a crossfade based on overlap
|
||||||
crossfade_ms = int(overlap_duration_s * 1000)
|
combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
|
||||||
combined_segment = result_segments[0]
|
|
||||||
for segment in result_segments[1:]:
|
|
||||||
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
|
|
||||||
|
|
||||||
audio_bytes = io.BytesIO()
|
audio_bytes = io.BytesIO()
|
||||||
combined_segment.export(audio_bytes, format="mp3")
|
combined_segment.export(audio_bytes, format="mp3")
|
||||||
|
@ -264,11 +240,99 @@ def render_audio_to_audio() -> None:
|
||||||
st.audio(audio_bytes, format="audio/mp3")
|
st.audio(audio_bytes, format="audio/mp3")
|
||||||
|
|
||||||
|
|
||||||
@st.cache
|
def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
|
||||||
def test(segment: pydub.AudioSegment, counter: int) -> int:
|
"""
|
||||||
st.write("#### Trimmed")
|
Render the parameters of slicing audio into clips.
|
||||||
st.write(segment.duration_seconds)
|
"""
|
||||||
return counter
|
p: T.Dict[str, T.Any] = {}
|
||||||
|
|
||||||
|
cols = st.columns(4)
|
||||||
|
|
||||||
|
p["start_time_s"] = cols[0].number_input(
|
||||||
|
"Start Time [s]",
|
||||||
|
min_value=0.0,
|
||||||
|
value=0.0,
|
||||||
|
)
|
||||||
|
p["duration_s"] = cols[1].number_input(
|
||||||
|
"Duration [s]",
|
||||||
|
min_value=0.0,
|
||||||
|
value=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if advanced:
|
||||||
|
p["clip_duration_s"] = cols[2].number_input(
|
||||||
|
"Clip Duration [s]",
|
||||||
|
min_value=3.0,
|
||||||
|
max_value=10.0,
|
||||||
|
value=5.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
p["clip_duration_s"] = 5.0
|
||||||
|
|
||||||
|
if advanced:
|
||||||
|
p["overlap_duration_s"] = cols[3].number_input(
|
||||||
|
"Overlap Duration [s]",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=10.0,
|
||||||
|
value=0.2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
p["overlap_duration_s"] = 0.2
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def write_clip_details(
|
||||||
|
clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Write details of the clips to be sliced from an audio segment.
|
||||||
|
"""
|
||||||
|
clip_details_text = (
|
||||||
|
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
|
||||||
|
f"with overlap {overlap_duration_s}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
with st.expander(clip_details_text):
|
||||||
|
st.dataframe(
|
||||||
|
{
|
||||||
|
"Start Time [s]": clip_start_times,
|
||||||
|
"End Time [s]": clip_start_times + clip_duration_s,
|
||||||
|
"Duration [s]": clip_duration_s,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def slice_audio_into_clips(
|
||||||
|
segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float
|
||||||
|
) -> T.List[pydub.AudioSegment]:
|
||||||
|
"""
|
||||||
|
Slice an audio segment into a list of clips of a given duration at the given start times.
|
||||||
|
"""
|
||||||
|
clip_segments: T.List[pydub.AudioSegment] = []
|
||||||
|
for i, clip_start_time_s in enumerate(clip_start_times):
|
||||||
|
clip_start_time_ms = int(clip_start_time_s * 1000)
|
||||||
|
clip_duration_ms = int(clip_duration_s * 1000)
|
||||||
|
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
|
||||||
|
|
||||||
|
# TODO(hayk): I don't think this is working properly
|
||||||
|
if i == len(clip_start_times) - 1:
|
||||||
|
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
|
||||||
|
if silence_ms > 0:
|
||||||
|
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
|
||||||
|
|
||||||
|
clip_segments.append(clip_segment)
|
||||||
|
|
||||||
|
return clip_segments
|
||||||
|
|
||||||
|
|
||||||
|
def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Scale an image to a size that is a multiple of 32.
|
||||||
|
"""
|
||||||
|
closest_width = int(np.ceil(image.width / 32) * 32)
|
||||||
|
closest_height = int(np.ceil(image.height / 32) * 32)
|
||||||
|
return image.resize((closest_width, closest_height), Image.BICUBIC)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,247 +0,0 @@
|
||||||
import io
|
|
||||||
import typing as T
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pydub
|
|
||||||
import streamlit as st
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from riffusion.datatypes import InferenceInput
|
|
||||||
from riffusion.spectrogram_params import SpectrogramParams
|
|
||||||
from riffusion.streamlit import util as streamlit_util
|
|
||||||
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
|
|
||||||
|
|
||||||
|
|
||||||
def render_audio_to_audio_interpolate() -> None:
|
|
||||||
st.set_page_config(layout="wide", page_icon="🎸")
|
|
||||||
|
|
||||||
st.subheader(":wave: Audio to Audio Inteprolation")
|
|
||||||
st.write(
|
|
||||||
"""
|
|
||||||
Audio to audio with interpolation.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
with st.expander("Help", False):
|
|
||||||
st.write(
|
|
||||||
"""
|
|
||||||
TODO
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
device = streamlit_util.select_device(st.sidebar)
|
|
||||||
|
|
||||||
num_inference_steps = T.cast(
|
|
||||||
int,
|
|
||||||
st.sidebar.number_input(
|
|
||||||
"Steps per sample", value=50, help="Number of denoising steps per model run"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_file = st.file_uploader(
|
|
||||||
"Upload audio",
|
|
||||||
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
|
|
||||||
label_visibility="collapsed",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not audio_file:
|
|
||||||
st.info("Upload audio to get started")
|
|
||||||
return
|
|
||||||
|
|
||||||
st.write("#### Original")
|
|
||||||
st.audio(audio_file)
|
|
||||||
|
|
||||||
segment = streamlit_util.load_audio_file(audio_file)
|
|
||||||
|
|
||||||
# TODO(hayk): Fix
|
|
||||||
segment = segment.set_frame_rate(44100)
|
|
||||||
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
|
|
||||||
|
|
||||||
if "counter" not in st.session_state:
|
|
||||||
st.session_state.counter = 0
|
|
||||||
|
|
||||||
def increment_counter():
|
|
||||||
st.session_state.counter += 1
|
|
||||||
|
|
||||||
cols = st.columns(4)
|
|
||||||
start_time_s = cols[0].number_input(
|
|
||||||
"Start Time [s]",
|
|
||||||
min_value=0.0,
|
|
||||||
value=0.0,
|
|
||||||
)
|
|
||||||
duration_s = cols[1].number_input(
|
|
||||||
"Duration [s]",
|
|
||||||
min_value=0.0,
|
|
||||||
value=15.0,
|
|
||||||
)
|
|
||||||
clip_duration_s = cols[2].number_input(
|
|
||||||
"Clip Duration [s]",
|
|
||||||
min_value=3.0,
|
|
||||||
max_value=10.0,
|
|
||||||
value=5.0,
|
|
||||||
)
|
|
||||||
overlap_duration_s = cols[3].number_input(
|
|
||||||
"Overlap Duration [s]",
|
|
||||||
min_value=0.0,
|
|
||||||
max_value=10.0,
|
|
||||||
value=0.2,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration_s = min(duration_s, segment.duration_seconds - start_time_s)
|
|
||||||
increment_s = clip_duration_s - overlap_duration_s
|
|
||||||
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
|
|
||||||
st.write(
|
|
||||||
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
|
|
||||||
f"with overlap {overlap_duration_s}s."
|
|
||||||
)
|
|
||||||
|
|
||||||
with st.expander("Clip Times"):
|
|
||||||
st.dataframe(
|
|
||||||
{
|
|
||||||
"Start Time [s]": clip_start_times,
|
|
||||||
"End Time [s]": clip_start_times + clip_duration_s,
|
|
||||||
"Duration [s]": clip_duration_s,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with st.form(key="interpolation_form"):
|
|
||||||
left, right = st.columns(2)
|
|
||||||
|
|
||||||
with left:
|
|
||||||
st.write("##### Prompt A")
|
|
||||||
prompt_input_a = get_prompt_inputs(key="a")
|
|
||||||
|
|
||||||
with right:
|
|
||||||
st.write("##### Prompt B")
|
|
||||||
prompt_input_b = get_prompt_inputs(key="b")
|
|
||||||
|
|
||||||
submit_button = st.form_submit_button("Generate", type="primary")
|
|
||||||
|
|
||||||
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
|
|
||||||
show_difference = st.sidebar.checkbox("Show Difference", False)
|
|
||||||
|
|
||||||
clip_segments: T.List[pydub.AudioSegment] = []
|
|
||||||
for i, clip_start_time_s in enumerate(clip_start_times):
|
|
||||||
clip_start_time_ms = int(clip_start_time_s * 1000)
|
|
||||||
clip_duration_ms = int(clip_duration_s * 1000)
|
|
||||||
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
|
|
||||||
|
|
||||||
# TODO(hayk): I don't think this is working properly
|
|
||||||
if i == len(clip_start_times) - 1:
|
|
||||||
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
|
|
||||||
if silence_ms > 0:
|
|
||||||
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
|
|
||||||
|
|
||||||
clip_segments.append(clip_segment)
|
|
||||||
|
|
||||||
if not prompt_input_a.prompt or not prompt_input_b.prompt:
|
|
||||||
st.info("Enter both prompts to interpolate between them")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not submit_button:
|
|
||||||
return
|
|
||||||
|
|
||||||
params = SpectrogramParams()
|
|
||||||
|
|
||||||
# TODO(hayk): Make not linspace
|
|
||||||
alphas = list(np.linspace(0, 1, len(clip_segments)))
|
|
||||||
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
|
||||||
st.write(f"**Alphas** : [{alphas_str}]")
|
|
||||||
|
|
||||||
result_images: T.List[Image.Image] = []
|
|
||||||
result_segments: T.List[pydub.AudioSegment] = []
|
|
||||||
for i, clip_segment in enumerate(clip_segments):
|
|
||||||
st.write(f"### Clip {i} at {clip_start_times[i]}s")
|
|
||||||
|
|
||||||
audio_bytes = io.BytesIO()
|
|
||||||
clip_segment.export(audio_bytes, format="wav")
|
|
||||||
|
|
||||||
init_image = streamlit_util.spectrogram_image_from_audio(
|
|
||||||
clip_segment,
|
|
||||||
params=params,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(hayk): Roll this into spectrogram_image_from_audio?
|
|
||||||
# TODO(hayk): Scale something when computing audio
|
|
||||||
closest_width = int(np.ceil(init_image.width / 32) * 32)
|
|
||||||
closest_height = int(np.ceil(init_image.height / 32) * 32)
|
|
||||||
init_image_resized = init_image.resize((closest_width, closest_height), Image.BICUBIC)
|
|
||||||
|
|
||||||
# progress_callback = None
|
|
||||||
if show_clip_details:
|
|
||||||
left, right = st.columns(2)
|
|
||||||
|
|
||||||
left.write("##### Source Clip")
|
|
||||||
left.image(init_image, use_column_width=False)
|
|
||||||
left.audio(audio_bytes)
|
|
||||||
|
|
||||||
right.write("##### Riffed Clip")
|
|
||||||
empty_bin = right.empty()
|
|
||||||
with empty_bin.container():
|
|
||||||
st.info("Riffing...")
|
|
||||||
# progress = st.progress(0.0)
|
|
||||||
# progress_callback = progress.progress
|
|
||||||
|
|
||||||
inputs = InferenceInput(
|
|
||||||
alpha=float(alphas[i]),
|
|
||||||
num_inference_steps=num_inference_steps,
|
|
||||||
seed_image_id="og_beat",
|
|
||||||
start=prompt_input_a,
|
|
||||||
end=prompt_input_b,
|
|
||||||
)
|
|
||||||
|
|
||||||
image, audio_bytes = run_interpolation(
|
|
||||||
inputs=inputs,
|
|
||||||
init_image=init_image_resized,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resize back to original size
|
|
||||||
image = image.resize(init_image.size, Image.BICUBIC)
|
|
||||||
|
|
||||||
result_images.append(image)
|
|
||||||
|
|
||||||
if show_clip_details:
|
|
||||||
empty_bin.empty()
|
|
||||||
right.image(image, use_column_width=False)
|
|
||||||
|
|
||||||
riffed_segment = streamlit_util.audio_segment_from_spectrogram_image(
|
|
||||||
image=image,
|
|
||||||
params=params,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
result_segments.append(riffed_segment)
|
|
||||||
|
|
||||||
if show_clip_details:
|
|
||||||
right.audio(audio_bytes)
|
|
||||||
|
|
||||||
if show_clip_details and show_difference:
|
|
||||||
diff_np = np.maximum(
|
|
||||||
0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32)
|
|
||||||
)
|
|
||||||
diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
|
|
||||||
diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
|
|
||||||
image=diff_image,
|
|
||||||
params=params,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_bytes = io.BytesIO()
|
|
||||||
diff_segment.export(audio_bytes, format="wav")
|
|
||||||
st.audio(audio_bytes)
|
|
||||||
|
|
||||||
# Combine clips with a crossfade based on overlap
|
|
||||||
crossfade_ms = int(overlap_duration_s * 1000)
|
|
||||||
combined_segment = result_segments[0]
|
|
||||||
for segment in result_segments[1:]:
|
|
||||||
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
|
|
||||||
|
|
||||||
audio_bytes = io.BytesIO()
|
|
||||||
combined_segment.export(audio_bytes, format="mp3")
|
|
||||||
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
|
|
||||||
st.audio(audio_bytes, format="audio/mp3")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
render_audio_to_audio_interpolate()
|
|
|
@ -61,6 +61,12 @@ def render_interpolation() -> None:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
guidance = st.sidebar.number_input(
|
||||||
|
"Guidance",
|
||||||
|
value=7.0,
|
||||||
|
help="How much the model listens to the text prompt",
|
||||||
|
)
|
||||||
|
|
||||||
init_image_name = st.sidebar.selectbox(
|
init_image_name = st.sidebar.selectbox(
|
||||||
"Seed image",
|
"Seed image",
|
||||||
# TODO(hayk): Read from directory
|
# TODO(hayk): Read from directory
|
||||||
|
@ -96,11 +102,11 @@ def render_interpolation() -> None:
|
||||||
|
|
||||||
with left:
|
with left:
|
||||||
st.write("##### Prompt A")
|
st.write("##### Prompt A")
|
||||||
prompt_input_a = get_prompt_inputs(key="a")
|
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
|
||||||
|
|
||||||
with right:
|
with right:
|
||||||
st.write("##### Prompt B")
|
st.write("##### Prompt B")
|
||||||
prompt_input_b = get_prompt_inputs(key="b")
|
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
|
||||||
|
|
||||||
st.form_submit_button("Generate", type="primary")
|
st.form_submit_button("Generate", type="primary")
|
||||||
|
|
||||||
|
@ -108,11 +114,15 @@ def render_interpolation() -> None:
|
||||||
st.info("Enter both prompts to interpolate between them")
|
st.info("Enter both prompts to interpolate between them")
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO(hayk): Make not linspace
|
|
||||||
alphas = list(np.linspace(0, 1, num_interpolation_steps))
|
alphas = list(np.linspace(0, 1, num_interpolation_steps))
|
||||||
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
||||||
st.write(f"**Alphas** : [{alphas_str}]")
|
st.write(f"**Alphas** : [{alphas_str}]")
|
||||||
|
|
||||||
|
# TODO(hayk): Apply scaling to alphas like this
|
||||||
|
# T_shifted = T * 2 - 1
|
||||||
|
# T_sample = (np.abs(T_shifted)**t_scale_power * np.sign(T_shifted) + 1) / 2
|
||||||
|
# T_sample = T_sample * (t_end - t_start) + t_start
|
||||||
|
|
||||||
if init_image_name == "custom":
|
if init_image_name == "custom":
|
||||||
if not init_image_file:
|
if not init_image_file:
|
||||||
st.info("Upload a custom seed image")
|
st.info("Upload a custom seed image")
|
||||||
|
@ -171,36 +181,43 @@ def render_interpolation() -> None:
|
||||||
st.audio(audio_bytes)
|
st.audio(audio_bytes)
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_inputs(key: str) -> PromptInput:
|
def get_prompt_inputs(
|
||||||
|
key: str,
|
||||||
|
include_negative_prompt: bool = False,
|
||||||
|
cols: bool = False,
|
||||||
|
) -> T.Dict[str, T.Any]:
|
||||||
"""
|
"""
|
||||||
Compute prompt inputs from widgets.
|
Compute prompt inputs from widgets.
|
||||||
"""
|
"""
|
||||||
prompt = st.text_input("Prompt", label_visibility="collapsed", key=f"prompt_{key}")
|
p: T.Dict[str, T.Any] = {}
|
||||||
seed = T.cast(
|
|
||||||
|
# Optionally use columns
|
||||||
|
left, right = T.cast(T.Any, st.columns(2) if cols else (st, st))
|
||||||
|
|
||||||
|
visibility = "visible" if cols else "collapsed"
|
||||||
|
p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}")
|
||||||
|
|
||||||
|
if include_negative_prompt:
|
||||||
|
p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}")
|
||||||
|
|
||||||
|
p["seed"] = T.cast(
|
||||||
int,
|
int,
|
||||||
st.number_input(
|
left.number_input(
|
||||||
"Seed",
|
"Seed",
|
||||||
value=42,
|
value=42,
|
||||||
key=f"seed_{key}",
|
key=f"seed_{key}",
|
||||||
help="Integer used to generate a random result. Vary this to explore alternatives.",
|
help="Integer used to generate a random result. Vary this to explore alternatives.",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
denoising = st.number_input(
|
|
||||||
"Denoising", value=0.75, key=f"denoising_{key}", help="How much to modify the seed image"
|
p["denoising"] = right.number_input(
|
||||||
)
|
"Denoising",
|
||||||
guidance = st.number_input(
|
value=0.5,
|
||||||
"Guidance",
|
key=f"denoising_{key}",
|
||||||
value=7.0,
|
help="How much to modify the seed image",
|
||||||
key=f"guidance_{key}",
|
|
||||||
help="How much the model listens to the text prompt",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return PromptInput(
|
return p
|
||||||
prompt=prompt,
|
|
||||||
seed=seed,
|
|
||||||
denoising=denoising,
|
|
||||||
guidance=guidance,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@st.experimental_memo
|
@st.experimental_memo
|
||||||
|
|
|
@ -13,7 +13,7 @@ def render_split_audio() -> None:
|
||||||
st.subheader(":scissors: Audio Splitter")
|
st.subheader(":scissors: Audio Splitter")
|
||||||
st.write(
|
st.write(
|
||||||
"""
|
"""
|
||||||
Split an audio into stems of {vocals, drums, bass, other}.
|
Split an audio into stems of {vocals, drums, bass, piano, guitar, other}.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -194,8 +194,8 @@ def run_img2img(
|
||||||
denoising_strength: float,
|
denoising_strength: float,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
guidance_scale: float,
|
guidance_scale: float,
|
||||||
negative_prompt: str,
|
|
||||||
seed: int,
|
seed: int,
|
||||||
|
negative_prompt: T.Optional[str] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
|
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
|
|
@ -3,6 +3,7 @@ Audio utility functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import typing as T
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pydub
|
import pydub
|
||||||
|
@ -69,3 +70,16 @@ def apply_filters(segment: pydub.AudioSegment, compression: bool = False) -> pyd
|
||||||
)
|
)
|
||||||
|
|
||||||
return segment
|
return segment
|
||||||
|
|
||||||
|
|
||||||
|
def stitch_segments(
|
||||||
|
segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
|
||||||
|
) -> pydub.AudioSegment:
|
||||||
|
"""
|
||||||
|
Stitch together a sequence of audio segments with a crossfade between each segment.
|
||||||
|
"""
|
||||||
|
crossfade_ms = int(crossfade_s * 1000)
|
||||||
|
combined_segment = segments[0]
|
||||||
|
for segment in segments[1:]:
|
||||||
|
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
|
||||||
|
return combined_segment
|
||||||
|
|
Loading…
Reference in New Issue