<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Memory and speed

We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.

|                  | Latency | Speedup |
|------------------|---------|---------|
| original         | 9.50s   | x1      |
| cuDNN auto-tuner | 9.37s   | x1.01   |
| autocast (fp16)  | 5.47s   | x1.74   |
| fp16             | 3.61s   | x2.63   |
| channels last    | 3.30s   | x2.88   |
| traced UNet      | 3.21s   | x2.96   |

<em>Obtained on an NVIDIA TITAN RTX by generating a single 512x512 image from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.</em>

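Each speedup is the original latency divided by the optimized latency. As a quick sanity check, the speedup column can be recomputed from the latencies in the table:

```python
# Latencies (seconds) copied from the benchmark table.
latencies = {
    "original": 9.50,
    "cuDNN auto-tuner": 9.37,
    "autocast (fp16)": 5.47,
    "fp16": 3.61,
    "channels last": 3.30,
    "traced UNet": 3.21,
}

baseline = latencies["original"]

# Speedup relative to the unoptimized baseline, rounded to two decimals.
speedups = {name: round(baseline / latency, 2) for name, latency in latencies.items()}

for name, speedup in speedups.items():
    print(f"{name:<18} x{speedup}")
```
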
## Enable cuDNN auto-tuner

[NVIDIA cuDNN](https://developer.nvidia.com/cudnn) supports many algorithms to compute a convolution. The auto-tuner runs a short benchmark and selects the kernel with the best performance on the given hardware for the given input size. Because this choice is made per input configuration, the auto-tuner pays off when input sizes stay the same from one call to the next, as they do when repeatedly generating images at a single resolution.

Since we're using **convolutional networks** (other types are currently not supported), we can enable the cuDNN auto-tuner before launching inference by setting:

```python
import torch

torch.backends.cudnn.benchmark = True
```

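To make the mechanism concrete, here is a toy, pure-Python model of an auto-tuner: on the first call with a given input shape it times every candidate implementation, then it caches and reuses the fastest one for that shape. The `AutoTuner` class and the two candidate functions are hypothetical illustrations, not how cuDNN or PyTorch are implemented internally.

```python
import time
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
Kernel = Callable[[List[float]], float]


class AutoTuner:
    """Hypothetical toy auto-tuner: picks the fastest candidate once per input shape."""

    def __init__(self, candidates: Dict[str, Kernel]):
        self.candidates = candidates
        self.cache: Dict[Shape, str] = {}  # input shape -> name of the fastest candidate
        self.benchmarks_run = 0

    def _benchmark(self, data: List[float]) -> str:
        # Time every candidate on the actual input and return the fastest one's name.
        timings: Dict[str, float] = {}
        for name, kernel in self.candidates.items():
            start = time.perf_counter()
            kernel(data)
            timings[name] = time.perf_counter() - start
        self.benchmarks_run += 1
        return min(timings, key=timings.__getitem__)

    def __call__(self, data: List[float]) -> float:
        shape = (len(data),)
        # Benchmark only the first time a shape is seen; afterwards reuse the cached choice.
        if shape not in self.cache:
            self.cache[shape] = self._benchmark(data)
        return self.candidates[self.cache[shape]](data)


def sum_builtin(data: List[float]) -> float:
    return sum(data)


def sum_loop(data: List[float]) -> float:
    total = 0.0
    for x in data:
        total += x
    return total


tuner = AutoTuner({"builtin": sum_builtin, "loop": sum_loop})
for _ in range(10):
    tuner([1.0] * 1000)  # same shape every time: benchmarked once, then cached
tuner([1.0] * 10)  # a new shape triggers a second benchmark
print(tuner.cache, tuner.benchmarks_run)
```

Ten calls with the same shape cost a single benchmark, while every new shape costs another one. This is why varying input sizes can erase the benefit of auto-tuning.
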
### Use TF32 instead of FP32 (on Ampere and later CUDA devices)

On Ampere and later CUDA devices, matrix multiplications and convolutions can use the TensorFloat-32 (TF32) mode for faster, but slightly less accurate, computations. By default, PyTorch enables TF32 mode for convolutions but not for matrix multiplications. Unless a network requires full float32 precision, we recommend enabling this setting for matrix multiplications too: it can significantly speed up computations, typically with a negligible loss of numerical accuracy. You can read more about TF32 in the [🤗 Transformers performance guide](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is add this before running inference:

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True
```

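TF32 keeps float32's 8-bit exponent, and therefore its range, but only 10 explicit mantissa bits, the same as float16. The pure-Python sketch below emulates that loss of precision by zeroing the low 13 of float32's 23 mantissa bits. It is a simplification under stated assumptions: it truncates where hardware may round, and it ignores that tensor cores accumulate the products in float32.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32_truncated(x: float) -> float:
    """Emulate TF32 by keeping only the top 10 of float32's 23 mantissa bits.

    Simplification: the low mantissa bits are truncated rather than rounded.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= ~((1 << 13) - 1)  # clear the 13 least significant mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


x = 1.0 / 3.0
for name, value in (("float32", to_float32(x)), ("tf32 (truncated)", to_tf32_truncated(x))):
    print(f"{name:<17} {value:.10f}  relative error {abs(value - x) / x:.1e}")

# The exponent bits are untouched, so TF32 keeps float32's range: 1e30 survives, just less precisely.
print(to_tf32_truncated(1e30))
```
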
## Automatic mixed precision (AMP)

If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference considerably faster (5.47s instead of 9.50s in the benchmark above) at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do this for Stable Diffusion text-to-image generation:

```python
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt).images[0]
```

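To get a feel for what "slightly lower precision" means, the sketch below rounds a few values to IEEE 754 half precision using Python's `struct` module (format character `'e'`). It only emulates storing numbers in float16; it does not reproduce which operations `torch.autocast` actually runs in float16, which autocast decides per operator.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


for value in [0.1, 1.0 / 3.0, 1000.1, 65504.0]:
    fp16 = to_float16(value)
    print(f"{value!r:>20} -> {fp16!r:<16} relative error {abs(fp16 - value) / value:.1e}")

# float16 cannot represent magnitudes much above 65504; struct refuses to pack them.
try:
    to_float16(70000.0)
except OverflowError as err:
    print("overflow:", err)
```

The relative error stays around 1e-4 or below for values in range, which is consistent with images that look unchanged, while the narrow range is one reason autocast keeps some operations in float32.
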
Despite the precision loss, in our experience the final images look the same as the `float32` versions. Feel free to experiment and report back!

## Half precision weights

To save more GPU memory, you can load the model weights directly in half precision. This involves loading the float16 version of the weights, which were saved to a branch named `fp16`, and telling PyTorch to use the `float16` type when loading them:

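```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",  # branch that contains the float16 weights
    torch_dtype=torch.float16,  # load the weights as float16
    use_auth_token=True,
)
pipe = pipe.to("cuda")
```
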
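Halving the bytes per parameter halves the memory taken by the weights. The back-of-the-envelope estimate below assumes a round figure of about 860M parameters, roughly the size of the Stable Diffusion v1 UNet; treat it as an assumption, not a measurement. Activations, the text encoder, and the VAE come on top of this.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Memory needed to store `num_params` weights of the given dtype, in GiB."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1024**3


# Assumption: ~860M parameters, an approximate size for the Stable Diffusion v1 UNet.
unet_params = 860_000_000

for dtype in ("float32", "float16"):
    print(f"{dtype}: {weight_memory_gib(unet_params, dtype):.2f} GiB")
```
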
## Sliced attention for additional memory savings

For additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.

<Tip>
Attention slicing is useful even with a batch size of 1, as long as the model uses more than one attention head. With more than one attention head, the *QK^T* attention matrix can be computed sequentially for each head, which can save a significant amount of memory.
</Tip>

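The idea behind slicing can be sketched with NumPy: compute softmax(QK^T / sqrt(d))V one head at a time, so that only a single head's score matrix exists at any moment, and check that the result matches computing all heads at once. This is an illustrative re-implementation under simple assumptions (inputs already split into heads, no masking), not the code Diffusers runs internally.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention_all_heads(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """q, k, v: (heads, tokens, head_dim). Materializes every head's QK^T at once."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = q @ k.transpose(0, 2, 1) * scale  # (heads, tokens, tokens)
    return softmax(scores) @ v


def attention_sliced(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Same result, but only one (tokens, tokens) score matrix is alive at a time."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for h in range(q.shape[0]):
        scores = q[h] @ k[h].T * scale  # (tokens, tokens) for a single head
        out[h] = softmax(scores) @ v[h]
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 64, 40)) for _ in range(3))
print(np.allclose(attention_all_heads(q, k, v), attention_sliced(q, k, v)))
```

The Python loop over heads is also where the small speed penalty comes from: smaller operations launched one after another instead of one large batched operation.
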
To perform the attention computation sequentially over each head, you only need to call [`~StableDiffusionPipeline.enable_attention_slicing`] on your pipeline before running inference:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
with torch.autocast("cuda"):
    image = pipe(prompt).images[0]
```

There is a small performance penalty (inference is about 10% slower), but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!

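Where do the savings come from? The attention score tensor holds one entry per pair of tokens, per head, per batch element. The rough, per-layer estimate below is for the largest self-attention in Stable Diffusion at 512x512, under assumptions stated in the comments: a 64x64 latent (so 4,096 tokens), 8 heads, a batch of 2 (classifier-free guidance runs a conditional and an unconditional pass), and float16 scores.

```python
def attention_scores_bytes(batch: int, heads: int, tokens: int, bytes_per_element: int) -> int:
    """Size of the QK^T score tensor: one (tokens x tokens) matrix per batch element and head."""
    return batch * heads * tokens * tokens * bytes_per_element


MIB = 1024**2

# Assumptions: 64x64 latent -> 4096 tokens, 8 heads, batch 2, float16 (2 bytes per element).
full = attention_scores_bytes(batch=2, heads=8, tokens=64 * 64, bytes_per_element=2)
# Slicing one head at a time keeps a single (tokens x tokens) matrix alive.
one_slice = attention_scores_bytes(batch=1, heads=1, tokens=64 * 64, bytes_per_element=2)

print(f"all heads at once:   {full / MIB:.0f} MiB")
print(f"one slice at a time: {one_slice / MIB:.0f} MiB")
```

Because the size grows with the square of the number of tokens, these matrices dominate peak memory at high resolution, and computing them slice by slice shrinks that peak accordingly.
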
## Using Channels Last memory format

Channels last is an alternative way of ordering NCHW tensors in memory that preserves the logical order of the dimensions. In channels last tensors, the channels become the densest dimension (that is, images are stored pixel by pixel). Since not all operators currently support the channels last format, it may result in worse performance, so it's best to try it and see whether it helps for your model.

For example, to set the UNet model in our pipeline to use the channels last format, we can use the following:

```python
print(pipe.unet.conv_out.state_dict()["weight"].stride())  # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
# A stride of 1 for the 2nd (channel) dimension proves that the conversion worked.
print(pipe.unet.conv_out.state_dict()["weight"].stride())  # (2880, 1, 960, 320)
```

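The stride values in the comments follow directly from the tensor's shape and the physical order of its dimensions. The sketch below recomputes them in pure Python for the `conv_out` weight, whose shape `(4, 320, 3, 3)` is implied by those strides (4 output channels, 320 input channels, a 3x3 kernel). `strides_for` is a hypothetical helper, not a PyTorch function.

```python
from typing import Sequence, Tuple


def strides_for(shape: Sequence[int], physical_order: Sequence[int]) -> Tuple[int, ...]:
    """Element strides of a dense tensor, reported in logical (NCHW) dimension order.

    `physical_order` lists the logical dimensions from outermost to innermost in memory.
    """
    strides = [0] * len(shape)
    step = 1
    # Walk from the innermost (fastest-changing) dimension outwards.
    for dim in reversed(physical_order):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


conv_out_weight_shape = (4, 320, 3, 3)  # (N, C, H, W)

contiguous = strides_for(conv_out_weight_shape, physical_order=(0, 1, 2, 3))  # NCHW layout
channels_last = strides_for(conv_out_weight_shape, physical_order=(0, 2, 3, 1))  # NHWC layout

print(contiguous)
print(channels_last)
```

The logical shape is identical in both cases; only the mapping from indices to memory offsets changes, with the channel dimension becoming the one that is contiguous.
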
## Tracing

Tracing runs an example input tensor through your model and captures the operations that are invoked as that input makes its way through the model's layers. The result is an executable `ScriptModule` (or a `ScriptFunction`, when tracing a plain function) that is optimized using just-in-time compilation.

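A consequence of recording operations for one example input is that Python control flow is evaluated once, at trace time, and frozen into the result. The toy tracer below illustrates this in pure Python; `TracedValue`, `trace`, `model`, and `eager` are hypothetical names for illustration, not part of `torch.jit`. With real tensors, `torch.jit.trace` typically emits a `TracerWarning` in this situation.

```python
from typing import Callable, List, Tuple

Op = Tuple[str, float]


class TracedValue:
    """Hypothetical toy value that records each arithmetic op applied to it."""

    def __init__(self, value: float, tape: List[Op]):
        self.value = value
        self.tape = tape

    def __add__(self, other: float) -> "TracedValue":
        self.tape.append(("add", other))
        return TracedValue(self.value + other, self.tape)

    def __mul__(self, other: float) -> "TracedValue":
        self.tape.append(("mul", other))
        return TracedValue(self.value * other, self.tape)


def trace(fn: Callable[[TracedValue], TracedValue], example: float) -> Callable[[float], float]:
    """Run `fn` once on `example`, record its ops, and return a function replaying them."""
    tape: List[Op] = []
    fn(TracedValue(example, tape))

    def replay(x: float) -> float:
        for op, operand in tape:
            x = x + operand if op == "add" else x * operand
        return x

    return replay


def model(x: TracedValue) -> TracedValue:
    # Data-dependent branch: the tracer only sees the branch taken for the example input.
    if x.value > 0:
        return x * 2.0 + 1.0
    return x * -1.0


def eager(x: float) -> float:
    return model(TracedValue(x, [])).value


traced = trace(model, example=3.0)
print(traced(3.0), eager(3.0))  # agree: the same branch is taken
print(traced(-1.0), eager(-1.0))  # disagree: the trace froze the positive branch
```

The UNet's forward pass does not branch on the values of its inputs, which is what makes it a reasonable candidate for tracing.
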
To trace our UNet model and compare its speed with the original, we can use the following:

```python
import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# torch disable grad
torch.set_grad_enabled(False)

# set variables
n_experiments = 2
unet_runs_per_experiment = 50


# load inputs
def generate_inputs():
    sample = torch.randn(2, 4, 64, 64).half().cuda()
    timestep = torch.rand(1).half().cuda() * 999
    encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
    return sample, timestep, encoder_hidden_states


pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # scheduler=scheduler,
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # set return_dict=False as default

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")


# warmup and optimize graph
for _ in range(5):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)


# benchmarking
with torch.inference_mode():
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        print(f"unet inference took {time.time() - start_time:.2f} seconds")

# save the model
unet_traced.save("unet_traced.pt")
```

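The benchmarking part of the script follows a common pattern: run a few warmup iterations first, so that one-time costs such as graph optimization or kernel selection are excluded, then time many runs and report the result. A generic, framework-free version of that pattern is sketched below; `benchmark` and `workload` are hypothetical helpers. For GPU work you would additionally synchronize before reading the clock, as the script does with `torch.cuda.synchronize()`, because CUDA kernels are launched asynchronously.

```python
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 3, runs: int = 50) -> float:
    """Return the mean wall-clock time of `fn` in seconds, excluding warmup calls."""
    for _ in range(warmup):
        fn()  # one-time costs (caching, compilation) land here, not in the measurement
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs


calls = 0


def workload() -> int:
    global calls
    calls += 1
    return sum(i * i for i in range(10_000))


mean_seconds = benchmark(workload, warmup=3, runs=20)
print(f"mean: {mean_seconds * 1e3:.3f} ms per call")
```

`time.perf_counter` is used here because it is a monotonic, high-resolution clock intended for measuring short durations.
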
Then we can replace the `unet` attribute of the pipeline with the traced model, as follows:

```python
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass


@dataclass
class UNet2DConditionOutput:
    sample: torch.FloatTensor


pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # scheduler=scheduler,
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")


# del pipe.unet
class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # expose the attributes the pipeline reads from its UNet
        self.in_channels = pipe.unet.in_channels
        self.device = pipe.unet.device

    def forward(self, latent_model_input, t, encoder_hidden_states):
        # the traced UNet returns a tuple; wrap its first element like the original UNet output
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)


pipe.unet = TracedUNet()

prompt = "a photo of an astronaut riding a horse on mars"
with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```