Optimize Stable Diffusion (#371)

* initial commit

* make UNet stream capturable

* try to fix noise_pred value

* remove cuda graph and keep NB

* non blocking unet with PNDMScheduler

* make timesteps np arrays for pndm scheduler
because lists don't get formatted to tensors in `self.set_format`

* make max async in pndm

* use channel last format in unet

* avoid moving timesteps device in each unet call

* avoid memcpy op in `get_timestep_embedding`

* add `channels_last` kwarg to `DiffusionPipeline.from_pretrained`

* update TODO

* replace `channels_last` kwarg with `memory_format` for more generality

* revert the channels_last changes to leave it for another PR

* remove non_blocking when moving input ids to device

* remove blocking from all .to() operations at beginning of pipeline

* fix merging

* fix merging

* model can run in other precisions without autocast

* attn refactoring

* Revert "attn refactoring"

This reverts commit 0c70c0e189cd2c4d8768274c9fcf5b940ee310fb.

* remove restriction to run conv_norm in fp32

* use `baddbmm` instead of `matmul`for better in attention for better perf

* removing all reshapes to test perf

* Revert "removing all reshapes to test perf"

This reverts commit 006ccb8a8c6bc7eb7e512392e692a29d9b1553cd.

* add shapes comments

* hardcore whats needed for jitting

* Revert "hardcore whats needed for jitting"

This reverts commit 2fa9c698eae2890ac5f8e367ca80532ecf94df9a.

* Revert "remove restriction to run conv_norm in fp32"

This reverts commit cec592890c32da3d1b78d38b49e4307aedf459b9.

* revert using baddmm in attention's forward

* cleanup comment

* remove restriction to run conv_norm in fp32. no quality loss was noticed

This reverts commit cc9bc1339c998ebe9e7d733f910c6d72d9792213.

* add more optimizations techniques to docs

* Revert "add shapes comments"

This reverts commit 31c58eadb8892f95478cdf05229adf678678c5f4.

* apply suggestions

* make quality

* apply suggestions

* styling

* `scheduler.timesteps` are now arrays so we dont need .to()

* remove useless .type()

* use mean instead of max in `test_stable_diffusion_inpaint_pipeline_k_lms`

* move scheduler timestamps to correct device if tensors

* add device to `set_timesteps` in LMSD scheduler

* `self.scheduler.set_timesteps` now uses device arg for schedulers that accept it

* quick fix

* styling

* remove kwargs from schedulers `set_timesteps`

* revert to using max in K-LMS inpaint pipeline test

* Revert "`self.scheduler.set_timesteps` now uses device arg for schedulers that accept it"

This reverts commit 00d5a51e5c20d8d445c8664407ef29608106d899.

* move timesteps to correct device before loop in SD pipeline

* apply previous fix to other SD pipelines

* UNet now accepts tensor timesteps even on wrong device, to avoid errors
- it shouldnt affect performance if timesteps are alrdy on correct device
- it does slow down performance if they're on the wrong device

* fix pipeline when timesteps are arrays with strides
This commit is contained in:
Nouamane Tazi 2022-09-30 08:49:13 +01:00 committed by GitHub
parent a7058f42e1
commit 9ebaea545f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 244 additions and 23 deletions

View File

@ -14,7 +14,64 @@ specific language governing permissions and limitations under the License.
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.
## CUDA `autocast`
<table>
<tr>
<td>
<td>Latency
<td>Speedup
<tr>
<tr>
<td>original
<td>9.50s
<td>x1
<tr>
<tr>
<td>cuDNN auto-tuner
<td>9.37s
<td>x1.01
<tr>
<td>autocast (fp16)
<td>5.47s
<td>x1.91
<tr>
<td>fp16
<td>3.61s
<td>x2.91
<tr>
<td>channels last
<td>3.30s
<td>x2.87
<tr>
<tr>
<td>traced UNet
<td>3.21s
<td>x2.96
</table>
<em>obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.</em>
## Enable cuDNN auto-tuner
[NVIDIA cuDNN](https://developer.nvidia.com/cudnn) supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
Since were using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting:
```python
import torch
torch.backends.cudnn.benchmark = True
```
### Use tf32 instead of fp32 (on Ampere and later CUDA devices)
On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference:
```python
import torch
torch.backends.cuda.matmul.allow_tf32 = True
```
## Automatic mixed precision (AMP)
If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference roughly twice as fast at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do it using Stable Diffusion text-to-image generation as an example:
@ -47,7 +104,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
## Sliced attention for additional memory savings
For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
<Tip>
Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory.
@ -73,4 +130,139 @@ with torch.autocast("cuda"):
image = pipe(prompt).images[0]
```
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
## Using Channels Last memory format
Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model.
For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following:
```python
print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last) # in-place operation
print(
pipe.unet.conv_out.state_dict()["weight"].stride()
) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works
```
## Tracing
Tracing runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers so that an executable or `ScriptFunction` is returned that will be optimized using just-in-time compilation.
To trace our UNet model, we can use the following:
```python
import time
import torch
from diffusers import StableDiffusionPipeline
import functools
# torch disable grad
torch.set_grad_enabled(False)
# set variables
n_experiments = 2
unet_runs_per_experiment = 50
# load inputs
def generate_inputs():
sample = torch.randn(2, 4, 64, 64).half().cuda()
timestep = torch.rand(1).half().cuda() * 999
encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
return sample, timestep, encoder_hidden_states
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
# scheduler=scheduler,
use_auth_token=True,
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last) # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
# warmup
for _ in range(3):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet(*inputs)
# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")
# warmup and optimize graph
for _ in range(5):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet_traced(*inputs)
# benchmarking
with torch.inference_mode():
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet_traced(*inputs)
torch.cuda.synchronize()
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet(*inputs)
torch.cuda.synchronize()
print(f"unet inference took {time.time() - start_time:.2f} seconds")
# save the model
unet_traced.save("unet_traced.pt")
```
Then we can replace the `unet` attribute of the pipeline with the traced model like the following
```python
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass
@dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
# scheduler=scheduler,
use_auth_token=True,
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")
# del pipe.unet
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = pipe.unet.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
pipe.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```

View File

@ -72,8 +72,7 @@ class AttentionBlock(nn.Module):
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
@ -275,7 +274,13 @@ class CrossAttention(nn.Module):
return self.to_out(hidden_states)
def _attention(self, query, key, value):
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)
@ -292,7 +297,9 @@ class CrossAttention(nn.Module):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

View File

@ -37,10 +37,12 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent).to(device=timesteps.device)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings

View File

@ -331,7 +331,7 @@ class ResnetBlock2D(nn.Module):
# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
@ -349,7 +349,7 @@ class ResnetBlock2D(nn.Module):
# make sure hidden states is in float32
# when running in half-precision
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)

View File

@ -230,16 +230,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps.to(dtype=torch.float32)
timesteps = timesteps[None].to(device=sample.device)
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
emb = self.time_embedding(t_emb.to(self.dtype))
# 2. pre-process
sample = self.conv_in(sample)
@ -279,7 +279,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

View File

@ -225,15 +225,23 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents_shape,
generator=generator,
device=latents_device,
dtype=text_embeddings.dtype,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
latents = latents.to(latents_device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand
if torch.is_tensor(self.scheduler.timesteps):
timesteps_tensor = self.scheduler.timesteps.to(self.device)
else:
timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]
@ -247,7 +255,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
@ -278,7 +286,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
if output_type == "pil":
image = self.numpy_to_pil(image)

View File

@ -265,7 +265,11 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
# Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance

View File

@ -298,7 +298,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
# Some schedulers like PNDM have timesteps as arrays
# It's more optimzed to move all timesteps to correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
for i, t in tqdm(enumerate(timesteps_tensor)):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

View File

@ -131,13 +131,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return integrated_coeff
def set_timesteps(self, num_inference_steps: int):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
@ -145,8 +147,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.derivatives = []