Add initial scheduler support

Add the ability to choose schedulers, at least for a subset of the pipeline
configurations. Allows configuring for text to audio and audio to audio
in the sidebar. Currently not used for interpolation, aka the riffusion
pipeline.

Topic: schedulers_v0
This commit is contained in:
Hayk Martiros 2023-01-14 22:52:02 +00:00
parent c771ab0d23
commit b45910709c
3 changed files with 86 additions and 15 deletions

View File

@ -46,18 +46,27 @@ def render_audio_to_audio() -> None:
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)
num_inference_steps = T.cast(
int,
st.sidebar.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
with st.sidebar:
num_inference_steps = T.cast(
int,
st.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
guidance = st.sidebar.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)
guidance = st.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)
scheduler = st.selectbox(
"Scheduler",
options=streamlit_util.SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None
audio_file = st.file_uploader(
"Upload audio",
@ -207,6 +216,7 @@ def render_audio_to_audio() -> None:
seed=prompt_input_a.seed,
progress_callback=progress_callback,
device=device,
scheduler=scheduler,
)
# Resize back to original size

View File

@ -61,6 +61,13 @@ def render_text_to_audio() -> None:
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
)
scheduler = st.selectbox(
"Scheduler",
options=streamlit_util.SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None
if not prompt:
st.info("Enter a prompt")
@ -85,6 +92,7 @@ def render_text_to_audio() -> None:
width=width,
height=512,
device=device,
scheduler=scheduler,
)
st.image(image)

View File

@ -20,6 +20,15 @@ from riffusion.spectrogram_params import SpectrogramParams
# Audio file extensions offered by the sidebar extension selector.
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
# Image file extensions (presumably for spectrogram image upload/download —
# usage not visible in this chunk, confirm against callers).
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
# Names of diffusers scheduler classes accepted by get_scheduler().
# The first entry is the default (used as the default `scheduler` argument
# of the pipeline loaders and as index=0 of the Streamlit selectbox).
SCHEDULER_OPTIONS = [
    "PNDMScheduler",
    "DDIMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "DPMSolverMultistepScheduler",
]
@st.experimental_singleton
def load_riffusion_checkpoint(
@ -42,6 +51,7 @@ def load_stable_diffusion_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionPipeline:
"""
Load the riffusion pipeline.
@ -52,19 +62,56 @@ def load_stable_diffusion_pipeline(
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
return StableDiffusionPipeline.from_pretrained(
pipeline = StableDiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
return pipeline
def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
    """
    Construct a denoising scheduler from a string name.

    Args:
        scheduler: Class name of a diffusers scheduler (e.g. "DDIMScheduler").
        config: Scheduler config to instantiate from, typically taken from an
            existing pipeline's ``pipeline.scheduler.config``.

    Returns:
        A scheduler instance built via ``<SchedulerClass>.from_config(config)``.

    Raises:
        ValueError: If ``scheduler`` is not one of the supported class names.
    """
    # Supported scheduler class names; kept local so validation happens
    # before any diffusers import.
    supported = (
        "PNDMScheduler",
        "DDIMScheduler",
        "LMSDiscreteScheduler",
        "EulerDiscreteScheduler",
        "EulerAncestralDiscreteScheduler",
        "DPMSolverMultistepScheduler",
    )
    if scheduler not in supported:
        raise ValueError(f"Unknown scheduler {scheduler}")

    # Import lazily (as the original per-branch imports did) and resolve the
    # class dynamically instead of a six-way if/elif chain.
    import diffusers

    scheduler_class = getattr(diffusers, scheduler)
    return scheduler_class.from_config(config)
@st.experimental_singleton
def load_stable_diffusion_img2img_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionImg2ImgPipeline:
"""
Load the image to image pipeline.
@ -75,13 +122,17 @@ def load_stable_diffusion_img2img_pipeline(
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32
return StableDiffusionImg2ImgPipeline.from_pretrained(
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
return pipeline
@st.experimental_memo
def run_txt2img(
@ -93,11 +144,12 @@ def run_txt2img(
width: int,
height: int,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
) -> Image.Image:
"""
Run the text to image pipeline with caching.
"""
pipeline = load_stable_diffusion_pipeline(device=device)
pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)
@ -214,9 +266,10 @@ def run_img2img(
seed: int,
negative_prompt: T.Optional[str] = None,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
) -> Image.Image:
pipeline = load_stable_diffusion_img2img_pipeline(device=device)
pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)