diff --git a/riffusion/cli.py b/riffusion/cli.py index 8fec114..266d12b 100644 --- a/riffusion/cli.py +++ b/riffusion/cli.py @@ -78,7 +78,7 @@ def image_to_audio(*, image: str, audio: str, device: str = "cuda"): try: params = SpectrogramParams.from_exif(exif=img_exif) - except KeyError: + except (KeyError, AttributeError): print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.") params = SpectrogramParams() diff --git a/riffusion/streamlit/pages/audio_to_audio.py b/riffusion/streamlit/pages/audio_to_audio.py index 3fbae16..384db3f 100644 --- a/riffusion/streamlit/pages/audio_to_audio.py +++ b/riffusion/streamlit/pages/audio_to_audio.py @@ -78,7 +78,7 @@ def render_audio_to_audio() -> None: increment_s = clip_duration_s - overlap_duration_s clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s) st.write( - f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s" + f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s " f"with overlap {overlap_duration_s}s." ) @@ -100,82 +100,103 @@ def render_audio_to_audio() -> None: max_value=20.0, value=7.0, ) - num_inference_steps = int(cols[2].number_input( - "Num Inference Steps", - min_value=1, - max_value=150, - value=50, - )) - seed = int(cols[3].number_input( - "Seed", - min_value=-1, - value=-1, - )) + num_inference_steps = int( + cols[2].number_input( + "Num Inference Steps", + min_value=1, + max_value=150, + value=50, + ) + ) + seed = int( + cols[3].number_input( + "Seed", + min_value=-1, + value=-1, + ) + ) # TODO replace seed -1 with random submit_button = st.form_submit_button("Convert", on_click=increment_counter) # TODO fix - pipeline = streamlit_util.load_stable_diffusion_img2img_pipeline( - checkpoint="/Users/hayk/.cache/huggingface/diffusers/models--riffusion--riffusion-model-v1/snapshots/79993436c342ff529802d1dabb016ebe15b5c4ae", - device=device, - # no_traced_unet=True, - ) - st.info("Slicing up audio into clips") + + show_clip_details = st.sidebar.checkbox("Show Clip Details", True) + clip_segments: T.List[pydub.AudioSegment] = [] - for i, clip_start_time_s in enumerate(clip_start_times): + for clip_start_time_s in clip_start_times: clip_start_time_ms = int(clip_start_time_s * 1000) clip_duration_ms = int(clip_duration_s * 1000) clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms] - clip_segments.append(clip_segment) - st.write(f"#### Clip {i} at {clip_start_time_s}s") - audio_bytes = io.BytesIO() - clip_segment.export(audio_bytes, format="wav") - st.audio(audio_bytes) + if not prompt: + st.info("Enter a prompt") + return if not submit_button: return - # TODO cache params = SpectrogramParams() - converter = SpectrogramImageConverter(params=params, device=device) - st.info("Converting audio clips into spectrogram images") - init_images = [converter.spectrogram_image_from_audio(s) for s in clip_segments] - st.info("Running img2img diffusion") - result_images : T.List[Image.Image] = [] - progress = st.progress(0.0) - for segment, init_image in zip(clip_segments, init_images): - generator = torch.Generator(device="cpu").manual_seed(seed) - num_expected_steps = max(int(num_inference_steps * denoising_strength), 1) - result = pipeline( - prompt=prompt, - image=init_image, - strength=denoising_strength, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt or None, - num_images_per_prompt=1, - generator=generator, - callback=lambda i, t, _: progress.progress(i / num_expected_steps), - callback_steps=1, + result_images: T.List[Image.Image] = [] + result_segments: T.List[pydub.AudioSegment] = [] + for i, clip_segment in enumerate(clip_segments): + st.write(f"### Clip {i} at {clip_start_times[i]}s") + + audio_bytes = io.BytesIO() + clip_segment.export(audio_bytes, format="wav") + + init_image = streamlit_util.spectrogram_image_from_audio( + clip_segment, + params=params, + device=device, + ) + + if show_clip_details: + left, right = st.columns(2) + + left.write("##### Source Clip") + left.image(init_image, use_column_width=False) + left.audio(audio_bytes) + + right.write("##### Riffed Clip") + empty_bin = right.empty() + with empty_bin.container(): + st.info("Riffing...") + progress = st.progress(0.0) + + image = streamlit_util.run_img2img( + prompt=prompt, + init_image=init_image, + denoising_strength=denoising_strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + seed=seed, + progress_callback=progress.progress, + device=device, ) - image = result.images[0] result_images.append(image) - row = st.columns(2) - st.write(init_image.size, image.size) - row[0].image(init_image) - row[1].image(image) + if show_clip_details: + empty_bin.empty() + right.image(image, use_column_width=False) - st.info("Converting back into audio clips") - result_segments : T.List[pydub.AudioSegment] = [] - for image in result_images: - result_segments.append(converter.audio_from_spectrogram_image(image)) + riffed_segment = streamlit_util.audio_segment_from_spectrogram_image( + image=image, + params=params, + device=device, + ) + result_segments.append(riffed_segment) + + audio_bytes = io.BytesIO() + riffed_segment.export(audio_bytes, format="wav") + + if show_clip_details: + right.audio(audio_bytes) # Combine clips with a crossfade based on overlap crossfade_ms = int(overlap_duration_s * 1000) diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py index 6512860..7c3e89d 100644 --- a/riffusion/streamlit/util.py +++ b/riffusion/streamlit/util.py @@ -57,7 +57,6 @@ def load_stable_diffusion_pipeline( ).to(device) - @st.experimental_singleton def load_stable_diffusion_img2img_pipeline( checkpoint: str = "riffusion/riffusion-model-v1", @@ -121,6 +120,26 @@ def spectrogram_image_converter( return SpectrogramImageConverter(params=params, device=device) +@st.cache +def spectrogram_image_from_audio( + segment: pydub.AudioSegment, + params: SpectrogramParams, + device: str = "cuda", +) -> Image.Image: + converter = spectrogram_image_converter(params=params, device=device) + return converter.spectrogram_image_from_audio(segment) + + +@st.experimental_memo +def audio_segment_from_spectrogram_image( + image: Image.Image, + params: SpectrogramParams, + device: str = "cuda", +) -> pydub.AudioSegment: + converter = spectrogram_image_converter(params=params, device=device) + return converter.audio_from_spectrogram_image(image) + + @st.experimental_memo def audio_bytes_from_spectrogram_image( image: Image.Image, @@ -128,12 +147,10 @@ def audio_bytes_from_spectrogram_image( device: str = "cuda", output_format: str = "mp3", ) -> io.BytesIO: - converter = spectrogram_image_converter(params=params, device=device) - segment = converter.audio_from_spectrogram_image(image) + segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device) audio_bytes = io.BytesIO() segment.export(audio_bytes, format=output_format) - audio_bytes.seek(0) return audio_bytes @@ -165,3 +182,41 @@ def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment: @st.experimental_singleton def get_audio_splitter(device: str = "cuda"): return AudioSplitter(device=device) + + +@st.cache +def run_img2img( + prompt: str, + init_image: Image.Image, + denoising_strength: float, + num_inference_steps: int, + guidance_scale: float, + negative_prompt: str, + seed: int, + device: str = "cuda", + progress_callback: T.Optional[T.Callable[[float], T.Any]] = None, +) -> Image.Image: + pipeline = load_stable_diffusion_img2img_pipeline(device=device) + + generator = torch.Generator(device="cpu").manual_seed(seed) + + num_expected_steps = max(int(num_inference_steps * denoising_strength), 1) + + def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None: + if progress_callback is not None: + progress_callback(step / num_expected_steps) + + result = pipeline( + prompt=prompt, + image=init_image, + strength=denoising_strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt or None, + num_images_per_prompt=1, + generator=generator, + callback=callback, + callback_steps=1, + ) + + return result.images[0]