diff --git a/riffusion/streamlit/pages/sample_clips.py b/riffusion/streamlit/pages/sample_clips.py index 43df1ff..9acf3c1 100644 --- a/riffusion/streamlit/pages/sample_clips.py +++ b/riffusion/streamlit/pages/sample_clips.py @@ -6,6 +6,7 @@ import numpy as np import pydub import streamlit as st +from riffusion.spectrogram_params import SpectrogramParams from riffusion.streamlit import util as streamlit_util @@ -51,9 +52,11 @@ def render_sample_clips() -> None: ) ) + device = streamlit_util.select_device(st.sidebar) extension = streamlit_util.select_audio_extension(st.sidebar) save_to_disk = st.sidebar.checkbox("Save to Disk", False) export_as_mono = st.sidebar.checkbox("Export as Mono", False) + compute_spectrograms = st.sidebar.checkbox("Compute Spectrograms", False) row = st.columns(4) num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3)) @@ -72,12 +75,19 @@ def render_sample_clips() -> None: output_path.mkdir(parents=True, exist_ok=True) st.info(f"Output directory: `{output_dir}`") + if compute_spectrograms: + images_dir = output_path / "images" + images_dir.mkdir(parents=True, exist_ok=True) + if seed >= 0: np.random.seed(seed) if export_as_mono and segment.channels > 1: segment = segment.set_channels(1) + if save_to_disk: + st.info(f"Writing {num_clips} clip(s) to `{str(output_path)}`") + # TODO(hayk): Share code with riffusion.cli.sample_clips. segment_duration_ms = int(segment.duration_seconds * 1000) for i in range(num_clips): @@ -95,9 +105,24 @@ def render_sample_clips() -> None: ) if save_to_disk: - clip_path = output_path / f"clip_name.{extension}" + clip_path = output_path / f"{clip_name}.{extension}" clip.export(clip_path, format=extension) + if compute_spectrograms: + params = SpectrogramParams() + + image = streamlit_util.spectrogram_image_from_audio( + clip, + params=params, + device=device, + ) + + st.image(image) + + if save_to_disk: + image_path = images_dir / f"{clip_name}.jpeg" + image.save(image_path) + if save_to_disk: st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")