diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py
index 67a9a7e..9630ee8 100644
--- a/riffusion/streamlit/pages/text_to_audio.py
+++ b/riffusion/streamlit/pages/text_to_audio.py
@@ -28,6 +28,7 @@ def render_text_to_audio() -> None:
 
     device = streamlit_util.select_device(st.sidebar)
     extension = streamlit_util.select_audio_extension(st.sidebar)
+    checkpoint = streamlit_util.select_checkpoint(st.sidebar)
 
     with st.form("Inputs"):
         prompt = st.text_input("Prompt")
@@ -69,15 +70,25 @@ def render_text_to_audio() -> None:
         )
         assert scheduler is not None
 
+    use_20k = st.checkbox("Use 20kHz", value=False)
+
     if not prompt:
         st.info("Enter a prompt")
         return
 
-    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
-    params = SpectrogramParams(
-        min_frequency=0,
-        max_frequency=10000,
-    )
+    if use_20k:
+        params = SpectrogramParams(
+            min_frequency=10,
+            max_frequency=20000,
+            sample_rate=44100,
+            stereo=True,
+        )
+    else:
+        params = SpectrogramParams(
+            min_frequency=0,
+            max_frequency=10000,
+            stereo=False,
+        )
 
     seed = starting_seed
     for i in range(1, num_clips + 1):
@@ -91,6 +102,7 @@ def render_text_to_audio() -> None:
             seed=seed,
             width=width,
             height=512,
+            checkpoint=checkpoint,
             device=device,
             scheduler=scheduler,
         )
diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py
index ae33de9..bb51035 100644
--- a/riffusion/streamlit/util.py
+++ b/riffusion/streamlit/util.py
@@ -18,6 +18,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 
 # TODO(hayk): Add URL params
 
+DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"
+
 AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
 IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
 
@@ -33,7 +35,7 @@ SCHEDULER_OPTIONS = [
 
 @st.experimental_singleton
 def load_riffusion_checkpoint(
-    checkpoint: str = "riffusion/riffusion-model-v1",
+    checkpoint: str = DEFAULT_CHECKPOINT,
     no_traced_unet: bool = False,
     device: str = "cuda",
 ) -> RiffusionPipeline:
@@ -49,7 +51,7 @@ def load_riffusion_checkpoint(
 
 @st.experimental_singleton
 def load_stable_diffusion_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
     scheduler: str = SCHEDULER_OPTIONS[0],
@@ -117,7 +119,7 @@ def pipeline_lock() -> threading.Lock:
 
 @st.experimental_singleton
 def load_stable_diffusion_img2img_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
     scheduler: str = SCHEDULER_OPTIONS[0],
@@ -152,6 +154,7 @@ def run_txt2img(
     seed: int,
     width: int,
     height: int,
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     scheduler: str = SCHEDULER_OPTIONS[0],
 ) -> Image.Image:
@@ -159,7 +162,11 @@ def run_txt2img(
     Run the text to image pipeline with caching.
     """
     with pipeline_lock():
-        pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
+        pipeline = load_stable_diffusion_pipeline(
+            checkpoint=checkpoint,
+            device=device,
+            scheduler=scheduler,
+        )
 
         generator_device = "cpu" if device.lower().startswith("mps") else device
         generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -270,6 +277,18 @@ def select_scheduler(container: T.Any = st.sidebar) -> str:
     return scheduler
 
 
+def select_checkpoint(container: T.Any = st.sidebar) -> str:
+    """
+    Provide a custom model checkpoint.
+    """
+    custom_checkpoint = container.text_input(
+        "Custom Checkpoint",
+        value="",
+        help="Provide a custom model checkpoint",
+    )
+    return custom_checkpoint or DEFAULT_CHECKPOINT
+
+
 @st.experimental_memo
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)
@@ -281,9 +300,13 @@ def get_audio_splitter(device: str = "cuda"):
 
 
 @st.experimental_singleton
-def load_magic_mix_pipeline(device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0]):
+def load_magic_mix_pipeline(
+    checkpoint: str = DEFAULT_CHECKPOINT,
+    device: str = "cuda",
+    scheduler: str = SCHEDULER_OPTIONS[0],
+):
     pipeline = DiffusionPipeline.from_pretrained(
-        "riffusion/riffusion-model-v1",
+        checkpoint,
         custom_pipeline="magic_mix",
     ).to(device)
 
@@ -302,6 +325,7 @@ def run_img2img_magic_mix(
     kmin: float,
     kmax: float,
     mix_factor: float,
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     scheduler: str = SCHEDULER_OPTIONS[0],
 ):
@@ -310,6 +334,7 @@ def run_img2img_magic_mix(
     """
     with pipeline_lock():
         pipeline = load_magic_mix_pipeline(
+            checkpoint=checkpoint,
            device=device,
             scheduler=scheduler,
         )
@@ -335,12 +360,17 @@ def run_img2img(
     guidance_scale: float,
     seed: int,
     negative_prompt: T.Optional[str] = None,
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     scheduler: str = SCHEDULER_OPTIONS[0],
     progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
 ) -> Image.Image:
     with pipeline_lock():
-        pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
+        pipeline = load_stable_diffusion_img2img_pipeline(
+            checkpoint=checkpoint,
+            device=device,
+            scheduler=scheduler,
+        )
 
         generator_device = "cpu" if device.lower().startswith("mps") else device
         generator = torch.Generator(device=generator_device).manual_seed(seed)