# Schedulers

Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent of each other. This means that you can switch out parts of a pipeline to better customize it to your use case. The best example of this is the [Schedulers](../api/schedulers.mdx).

Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample, schedulers define the whole denoising process, *i.e.*:

- How many denoising steps?
- Stochastic or deterministic?
- What algorithm to use to find the denoised sample?

They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**. It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.

The following paragraphs show how to do so with the 🧨 Diffusers library.

## Load pipeline

Let's start by loading the Stable Diffusion pipeline. Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use Stable Diffusion.

```python
from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch

# First we need to log in with our access token
login()

# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```

Next, we move it to GPU:

```python
pipeline.to("cuda")
```

## Access the scheduler

The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`, so it can be accessed via the `"scheduler"` property.
```python
pipeline.scheduler
```

**Output**:
```
PNDMScheduler {
  "_class_name": "PNDMScheduler",
  "_diffusers_version": "0.8.0.dev0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "trained_betas": null
}
```

We can see that the scheduler is of type [`PNDMScheduler`]. Cool, now let's compare its performance to that of other schedulers. First, we define a prompt on which we will test all the different schedulers:

```python
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
```

Next, we create a generator from a fixed random seed, which ensures that we can generate similar images across runs, and run the pipeline:

```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```



## Changing the scheduler

Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`] which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.

```python
pipeline.scheduler.compatibles
```

**Output**:
```
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
 diffusers.schedulers.scheduling_ddim.DDIMScheduler,
 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
 diffusers.schedulers.scheduling_pndm.PNDMScheduler,
 diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
 diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
```

Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:

- [`LMSDiscreteScheduler`],
- [`DDIMScheduler`],
- [`DPMSolverMultistepScheduler`],
- [`EulerDiscreteScheduler`],
- [`PNDMScheduler`],
- [`DDPMScheduler`],
- [`EulerAncestralDiscreteScheduler`].

We will now run the same input prompt with the other schedulers and compare the results. To change the scheduler of the pipeline, you can make use of the convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function.

```python
pipeline.scheduler.config
```

returns a dictionary of the configuration of the scheduler:

**Output**:
```
FrozenDict([('num_train_timesteps', 1000),
            ('beta_start', 0.00085),
            ('beta_end', 0.012),
            ('beta_schedule', 'scaled_linear'),
            ('trained_betas', None),
            ('skip_prk_steps', True),
            ('set_alpha_to_one', False),
            ('steps_offset', 1),
            ('_class_name', 'PNDMScheduler'),
            ('_diffusers_version', '0.8.0.dev0'),
            ('clip_sample', False)])
```

This configuration can then be used to instantiate a scheduler of a different class that is compatible with the pipeline. Here, we change the scheduler to the [`DDIMScheduler`].
```python
from diffusers import DDIMScheduler

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
```

Cool, now we can run the pipeline again to compare the generation quality.

```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
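The swap to `DDIMScheduler` works because that class is listed in `pipeline.scheduler.compatibles`. As a minimal sketch, the check can be made explicit before swapping. Note that `swap_scheduler` is a hypothetical helper name and not part of 🧨 Diffusers; it relies only on the `compatibles`, `config`, and `from_config` members used on this page.

```python
def swap_scheduler(pipeline, scheduler_cls):
    """Replace pipeline.scheduler with scheduler_cls if it is listed as compatible.

    Hypothetical helper (not part of diffusers). Returns the previous scheduler
    so that the caller can restore it later.
    """
    current = pipeline.scheduler
    if scheduler_cls not in current.compatibles:
        raise ValueError(
            f"{scheduler_cls.__name__} is not listed in "
            f"{type(current).__name__}.compatibles"
        )
    # Build the new scheduler from the configuration of the current one.
    pipeline.scheduler = scheduler_cls.from_config(current.config)
    return current
```

With the pipeline loaded earlier, `swap_scheduler(pipeline, DDIMScheduler)` should have the same effect as the direct assignment, while a class that is not in the list raises a `ValueError` up front.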



## Compare schedulers

So far we have tried running the Stable Diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`]. A number of better schedulers have been released that can be run with far fewer steps. Let's compare them here.

[`LMSDiscreteScheduler`] usually leads to better results:

```python
from diffusers import LMSDiscreteScheduler

pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```



[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high-quality results with as few as 30 steps.

```python
from diffusers import EulerDiscreteScheduler

pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```



and:

```python
from diffusers import EulerAncestralDiscreteScheduler

pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```



At the time of writing this doc, [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as few as 20 steps.

```python
from diffusers import DPMSolverMultistepScheduler

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image
```
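Since the trade-off is about speed as well as quality, it can help to time each run. The following is a minimal sketch using only the Python standard library. `timed` is a hypothetical helper, not part of 🧨 Diffusers, and the wall-clock time of a single call is only a rough indicator (a first call may include one-off warm-up costs).

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start
```

For example, `output, seconds = timed(pipeline, prompt, generator=generator, num_inference_steps=20)` would return the pipeline output together with the elapsed time, so that runs with different schedulers and step counts can be compared.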



As you can see, most images look very similar and are arguably of very similar quality. Which scheduler to choose often depends on the specific use case. A good approach is always to run multiple different schedulers and compare the results.
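The manual comparison on this page can also be wrapped in a loop. The sketch below is one possible way to do it, not an official utility: `compare_schedulers` and its parameters are hypothetical names, and it assumes only that the pipeline is callable with `generator` and `num_inference_steps` and returns an object with an `images` list, as in the examples here.

```python
def compare_schedulers(pipeline, prompt, schedulers, make_generator):
    """Run `prompt` once per scheduler and return {scheduler class name: image}.

    Hypothetical helper (not part of diffusers).
    `schedulers` maps a scheduler class to the number of inference steps to use.
    `make_generator` is called before every run so each run starts from the same seed.
    """
    original = pipeline.scheduler
    base_config = original.config  # configuration shared by all runs
    images = {}
    try:
        for scheduler_cls, steps in schedulers.items():
            pipeline.scheduler = scheduler_cls.from_config(base_config)
            output = pipeline(
                prompt,
                generator=make_generator(),
                num_inference_steps=steps,
            )
            images[scheduler_cls.__name__] = output.images[0]
    finally:
        # Restore the scheduler the pipeline started with.
        pipeline.scheduler = original
    return images
```

With the objects defined earlier, `schedulers` could be `{EulerDiscreteScheduler: 30, DPMSolverMultistepScheduler: 20}` and `make_generator` could be `lambda: torch.Generator(device="cuda").manual_seed(8)`. Creating a fresh generator per run matters because a generator that has already been used no longer starts from the original seed state.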