diff --git a/riffusion/streamlit/pages/audio_to_audio.py b/riffusion/streamlit/pages/audio_to_audio.py index 35067c6..b1271df 100644 --- a/riffusion/streamlit/pages/audio_to_audio.py +++ b/riffusion/streamlit/pages/audio_to_audio.py @@ -46,18 +46,27 @@ def render_audio_to_audio() -> None: device = streamlit_util.select_device(st.sidebar) extension = streamlit_util.select_audio_extension(st.sidebar) - num_inference_steps = T.cast( - int, - st.sidebar.number_input( - "Steps per sample", value=50, help="Number of denoising steps per model run" - ), - ) + with st.sidebar: + num_inference_steps = T.cast( + int, + st.number_input( + "Steps per sample", value=50, help="Number of denoising steps per model run" + ), + ) - guidance = st.sidebar.number_input( - "Guidance", - value=7.0, - help="How much the model listens to the text prompt", - ) + guidance = st.number_input( + "Guidance", + value=7.0, + help="How much the model listens to the text prompt", + ) + + scheduler = st.selectbox( + "Scheduler", + options=streamlit_util.SCHEDULER_OPTIONS, + index=0, + help="Which diffusion scheduler to use", + ) + assert scheduler is not None audio_file = st.file_uploader( "Upload audio", @@ -207,6 +216,7 @@ def render_audio_to_audio() -> None: seed=prompt_input_a.seed, progress_callback=progress_callback, device=device, + scheduler=scheduler, ) # Resize back to original size diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py index 48037e1..3d41534 100644 --- a/riffusion/streamlit/pages/text_to_audio.py +++ b/riffusion/streamlit/pages/text_to_audio.py @@ -61,6 +61,13 @@ def render_text_to_audio() -> None: guidance = st.number_input( "Guidance", value=7.0, help="How much the model listens to the text prompt" ) + scheduler = st.selectbox( + "Scheduler", + options=streamlit_util.SCHEDULER_OPTIONS, + index=0, + help="Which diffusion scheduler to use", + ) + assert scheduler is not None if not prompt: st.info("Enter a prompt") @@ -85,6 +92,7 @@ def render_text_to_audio() -> None: width=width, height=512, device=device, + scheduler=scheduler, ) st.image(image) diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py index 74e276b..c1fce73 100644 --- a/riffusion/streamlit/util.py +++ b/riffusion/streamlit/util.py @@ -20,6 +20,15 @@ from riffusion.spectrogram_params import SpectrogramParams AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"] IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"] +SCHEDULER_OPTIONS = [ + "PNDMScheduler", + "DDIMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", +] + @st.experimental_singleton def load_riffusion_checkpoint( @@ -42,6 +51,7 @@ def load_stable_diffusion_pipeline( checkpoint: str = "riffusion/riffusion-model-v1", device: str = "cuda", dtype: torch.dtype = torch.float16, + scheduler: str = SCHEDULER_OPTIONS[0], ) -> StableDiffusionPipeline: """ Load the riffusion pipeline. @@ -52,19 +62,56 @@ def load_stable_diffusion_pipeline( print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported") dtype = torch.float32 - return StableDiffusionPipeline.from_pretrained( + pipeline = StableDiffusionPipeline.from_pretrained( checkpoint, revision="main", torch_dtype=dtype, safety_checker=lambda images, **kwargs: (images, False), ).to(device) + pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config) + + return pipeline + + +def get_scheduler(scheduler: str, config: T.Any) -> T.Any: + """ + Construct a denoising scheduler from a string. + """ + if scheduler == "PNDMScheduler": + from diffusers import PNDMScheduler + + return PNDMScheduler.from_config(config) + elif scheduler == "DPMSolverMultistepScheduler": + from diffusers import DPMSolverMultistepScheduler + + return DPMSolverMultistepScheduler.from_config(config) + elif scheduler == "DDIMScheduler": + from diffusers import DDIMScheduler + + return DDIMScheduler.from_config(config) + elif scheduler == "LMSDiscreteScheduler": + from diffusers import LMSDiscreteScheduler + + return LMSDiscreteScheduler.from_config(config) + elif scheduler == "EulerDiscreteScheduler": + from diffusers import EulerDiscreteScheduler + + return EulerDiscreteScheduler.from_config(config) + elif scheduler == "EulerAncestralDiscreteScheduler": + from diffusers import EulerAncestralDiscreteScheduler + + return EulerAncestralDiscreteScheduler.from_config(config) + else: + raise ValueError(f"Unknown scheduler {scheduler}") + @st.experimental_singleton def load_stable_diffusion_img2img_pipeline( checkpoint: str = "riffusion/riffusion-model-v1", device: str = "cuda", dtype: torch.dtype = torch.float16, + scheduler: str = SCHEDULER_OPTIONS[0], ) -> StableDiffusionImg2ImgPipeline: """ Load the image to image pipeline. @@ -75,13 +122,17 @@ def load_stable_diffusion_img2img_pipeline( print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported") dtype = torch.float32 - return StableDiffusionImg2ImgPipeline.from_pretrained( + pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( checkpoint, revision="main", torch_dtype=dtype, safety_checker=lambda images, **kwargs: (images, False), ).to(device) + pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config) + + return pipeline + @st.experimental_memo def run_txt2img( @@ -93,11 +144,12 @@ def run_txt2img( width: int, height: int, device: str = "cuda", + scheduler: str = SCHEDULER_OPTIONS[0], ) -> Image.Image: """ Run the text to image pipeline with caching. """ - pipeline = load_stable_diffusion_pipeline(device=device) + pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler) generator_device = "cpu" if device.lower().startswith("mps") else device generator = torch.Generator(device=generator_device).manual_seed(seed) @@ -214,9 +266,10 @@ def run_img2img( seed: int, negative_prompt: T.Optional[str] = None, device: str = "cuda", + scheduler: str = SCHEDULER_OPTIONS[0], progress_callback: T.Optional[T.Callable[[float], T.Any]] = None, ) -> Image.Image: - pipeline = load_stable_diffusion_img2img_pipeline(device=device) + pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler) generator_device = "cpu" if device.lower().startswith("mps") else device generator = torch.Generator(device=generator_device).manual_seed(seed)