Add initial scheduler support
Add the ability to choose a scheduler for a subset of the pipeline configurations. The scheduler is selectable in the sidebar for text-to-audio and audio-to-audio. It is not yet used for interpolation, i.e. the riffusion pipeline. Topic: schedulers_v0
This commit is contained in:
parent
c771ab0d23
commit
b45910709c
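The change applies the standard diffusers pattern of constructing a new scheduler from the config of the pipeline's current one and assigning it back. A minimal sketch of that pattern on its own, outside Streamlit (the checkpoint and scheduler choice here are illustrative):

    import torch
    from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

    # Load a pipeline, then swap its default scheduler for another one built
    # from the same config (beta schedule, number of train timesteps, etc.).
    pipe = StableDiffusionPipeline.from_pretrained(
        "riffusion/riffusion-model-v1", torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
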
@@ -46,18 +46,27 @@ def render_audio_to_audio() -> None:
     device = streamlit_util.select_device(st.sidebar)
     extension = streamlit_util.select_audio_extension(st.sidebar)
 
-    num_inference_steps = T.cast(
-        int,
-        st.sidebar.number_input(
-            "Steps per sample", value=50, help="Number of denoising steps per model run"
-        ),
-    )
+    with st.sidebar:
+        num_inference_steps = T.cast(
+            int,
+            st.number_input(
+                "Steps per sample", value=50, help="Number of denoising steps per model run"
+            ),
+        )
 
-    guidance = st.sidebar.number_input(
-        "Guidance",
-        value=7.0,
-        help="How much the model listens to the text prompt",
-    )
+        guidance = st.number_input(
+            "Guidance",
+            value=7.0,
+            help="How much the model listens to the text prompt",
+        )
+
+        scheduler = st.selectbox(
+            "Scheduler",
+            options=streamlit_util.SCHEDULER_OPTIONS,
+            index=0,
+            help="Which diffusion scheduler to use",
+        )
+        assert scheduler is not None
 
     audio_file = st.file_uploader(
         "Upload audio",

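Two things are worth noting in the hunk above. First, the widgets move from the st.sidebar.number_input(...) prefix form into a with st.sidebar: block; both render in the sidebar, the context-manager form simply avoids repeating the prefix on every widget. Second, st.selectbox is typed as returning an Optional value, so the assert narrows the type for the checker. A minimal sketch of the sidebar equivalence (labels are illustrative):

    import streamlit as st

    # Prefix form: the widget names the sidebar explicitly.
    steps = st.sidebar.number_input("Steps", value=50)

    # Context-manager form: plain st.* calls inside the block render in the sidebar.
    with st.sidebar:
        steps = st.number_input("Steps", value=50)
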
@@ -207,6 +216,7 @@ def render_audio_to_audio() -> None:
             seed=prompt_input_a.seed,
             progress_callback=progress_callback,
             device=device,
+            scheduler=scheduler,
         )
 
         # Resize back to original size

@@ -61,6 +61,13 @@ def render_text_to_audio() -> None:
         guidance = st.number_input(
             "Guidance", value=7.0, help="How much the model listens to the text prompt"
         )
+        scheduler = st.selectbox(
+            "Scheduler",
+            options=streamlit_util.SCHEDULER_OPTIONS,
+            index=0,
+            help="Which diffusion scheduler to use",
+        )
+        assert scheduler is not None
 
     if not prompt:
         st.info("Enter a prompt")

@@ -85,6 +92,7 @@ def render_text_to_audio() -> None:
         width=width,
         height=512,
         device=device,
+        scheduler=scheduler,
     )
     st.image(image)
 

@@ -20,6 +20,15 @@ from riffusion.spectrogram_params import SpectrogramParams
 AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
 IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
 
+SCHEDULER_OPTIONS = [
+    "PNDMScheduler",
+    "DDIMScheduler",
+    "LMSDiscreteScheduler",
+    "EulerDiscreteScheduler",
+    "EulerAncestralDiscreteScheduler",
+    "DPMSolverMultistepScheduler",
+]
+
 
 @st.experimental_singleton
 def load_riffusion_checkpoint(

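Each entry in SCHEDULER_OPTIONS is the name of a class exported at the top level of the diffusers package, which is what allows get_scheduler below to resolve the string. A quick sanity check, assuming a diffusers version that ships all six schedulers:

    import diffusers

    for name in SCHEDULER_OPTIONS:
        assert hasattr(diffusers, name), f"diffusers does not export {name}"
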
@@ -42,6 +51,7 @@ def load_stable_diffusion_pipeline(
     checkpoint: str = "riffusion/riffusion-model-v1",
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
+    scheduler: str = SCHEDULER_OPTIONS[0],
 ) -> StableDiffusionPipeline:
     """
     Load the riffusion pipeline.

@@ -52,19 +62,56 @@ def load_stable_diffusion_pipeline(
         print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
         dtype = torch.float32
 
-    return StableDiffusionPipeline.from_pretrained(
+    pipeline = StableDiffusionPipeline.from_pretrained(
         checkpoint,
         revision="main",
         torch_dtype=dtype,
         safety_checker=lambda images, **kwargs: (images, False),
     ).to(device)
 
+    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
+
+    return pipeline
+
+
+def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
+    """
+    Construct a denoising scheduler from a string.
+    """
+    if scheduler == "PNDMScheduler":
+        from diffusers import PNDMScheduler
+
+        return PNDMScheduler.from_config(config)
+    elif scheduler == "DPMSolverMultistepScheduler":
+        from diffusers import DPMSolverMultistepScheduler
+
+        return DPMSolverMultistepScheduler.from_config(config)
+    elif scheduler == "DDIMScheduler":
+        from diffusers import DDIMScheduler
+
+        return DDIMScheduler.from_config(config)
+    elif scheduler == "LMSDiscreteScheduler":
+        from diffusers import LMSDiscreteScheduler
+
+        return LMSDiscreteScheduler.from_config(config)
+    elif scheduler == "EulerDiscreteScheduler":
+        from diffusers import EulerDiscreteScheduler
+
+        return EulerDiscreteScheduler.from_config(config)
+    elif scheduler == "EulerAncestralDiscreteScheduler":
+        from diffusers import EulerAncestralDiscreteScheduler
+
+        return EulerAncestralDiscreteScheduler.from_config(config)
+    else:
+        raise ValueError(f"Unknown scheduler {scheduler}")
+
 
 @st.experimental_singleton
 def load_stable_diffusion_img2img_pipeline(
     checkpoint: str = "riffusion/riffusion-model-v1",
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
+    scheduler: str = SCHEDULER_OPTIONS[0],
 ) -> StableDiffusionImg2ImgPipeline:
     """
     Load the image to image pipeline.

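Because every option names a top-level diffusers class, the if/elif chain in get_scheduler could also be collapsed with getattr. A hedged alternative sketch, not what this commit does; the explicit branches keep imports lazy and fail loudly on unsupported names:

    import typing as T

    import diffusers

    def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
        # Resolve a scheduler class by name and build it from a pipeline config.
        if scheduler not in SCHEDULER_OPTIONS:
            raise ValueError(f"Unknown scheduler {scheduler}")
        return getattr(diffusers, scheduler).from_config(config)
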
@@ -75,13 +122,17 @@ def load_stable_diffusion_img2img_pipeline(
         print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
         dtype = torch.float32
 
-    return StableDiffusionImg2ImgPipeline.from_pretrained(
+    pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
         checkpoint,
         revision="main",
         torch_dtype=dtype,
         safety_checker=lambda images, **kwargs: (images, False),
     ).to(device)
 
+    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
+
+    return pipeline
+
 
 @st.experimental_memo
 def run_txt2img(

@@ -93,11 +144,12 @@ def run_txt2img(
     width: int,
     height: int,
     device: str = "cuda",
+    scheduler: str = SCHEDULER_OPTIONS[0],
 ) -> Image.Image:
     """
     Run the text to image pipeline with caching.
     """
-    pipeline = load_stable_diffusion_pipeline(device=device)
+    pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
 
     generator_device = "cpu" if device.lower().startswith("mps") else device
     generator = torch.Generator(device=generator_device).manual_seed(seed)

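Threading scheduler through run_txt2img into the loader matters for caching: st.experimental_singleton keys its cache on the argument values, so each (device, scheduler) combination gets its own pipeline rather than mutating one shared object. Illustrative only:

    # Distinct cache entries: switching the selectbox builds (or reuses) the
    # pipeline for that scheduler instead of altering the one already loaded.
    pipe_a = load_stable_diffusion_pipeline(device="cuda", scheduler="PNDMScheduler")
    pipe_b = load_stable_diffusion_pipeline(device="cuda", scheduler="DDIMScheduler")
    assert pipe_a is not pipe_b
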
@@ -214,9 +266,10 @@ def run_img2img(
     seed: int,
     negative_prompt: T.Optional[str] = None,
     device: str = "cuda",
+    scheduler: str = SCHEDULER_OPTIONS[0],
     progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
 ) -> Image.Image:
-    pipeline = load_stable_diffusion_img2img_pipeline(device=device)
+    pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
 
     generator_device = "cpu" if device.lower().startswith("mps") else device
     generator = torch.Generator(device=generator_device).manual_seed(seed)