Efficient Attention (#366)

* up

* add tests

* correct

* up

* finish

* better naming

* Update README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen 2022-09-06 18:06:47 +02:00 committed by GitHub
parent 56c003705f
commit 5c4ea00de7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 221 additions and 20 deletions

View File

@ -104,7 +104,9 @@ with autocast("cuda"):
image = pipe(prompt).images[0] image = pipe(prompt).images[0]
``` ```
If you are limited by GPU memory, you might want to consider using the model in `fp16`. If you are limited by GPU memory, you might want to consider using the model in `fp16` as
well as chunking the attention computation.
The following snippet should result in less than 4GB VRAM.
```python ```python
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
@ -116,6 +118,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
pipe = pipe.to("cuda") pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars" prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
with autocast("cuda"): with autocast("cuda"):
image = pipe(prompt).images[0] image = pipe(prompt).images[0]
``` ```

View File

@ -63,18 +63,19 @@ class AttentionBlock(nn.Module):
# get scores # get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output # compute attention output
context_states = torch.matmul(attention_probs, value_states) hidden_states = torch.matmul(attention_probs, value_states)
context_states = context_states.permute(0, 2, 1, 3).contiguous() hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_context_states_shape = context_states.size()[:-2] + (self.channels,) new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
context_states = context_states.view(new_context_states_shape) hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states # compute next hidden_states
hidden_states = self.proj_attn(context_states) hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale # res connect and rescale
@ -107,6 +108,10 @@ class SpatialTransformer(nn.Module):
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, x, context=None): def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape b, c, h, w = x.shape
@ -136,6 +141,10 @@ class BasicTransformerBlock(nn.Module):
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def forward(self, x, context=None): def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x x = self.attn2(self.norm2(x), context=context) + x
@ -151,6 +160,10 @@ class CrossAttention(nn.Module):
self.scale = dim_head**-0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@ -175,8 +188,6 @@ class CrossAttention(nn.Module):
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape batch_size, sequence_length, dim = x.shape
h = self.heads
q = self.to_q(x) q = self.to_q(x)
context = context if context is not None else x context = context if context is not None else x
k = self.to_k(context) k = self.to_k(context)
@ -186,20 +197,33 @@ class CrossAttention(nn.Module):
k = self.reshape_heads_to_batch_dim(k) k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v) v = self.reshape_heads_to_batch_dim(v)
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale # TODO(PVP) - mask is currently never used. Remember to re-implement when used
if mask is not None:
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of # attention, what we cannot get enough of
attn = sim.softmax(dim=-1) hidden_states = self._attention(q, k, v, sequence_length, dim)
out = torch.einsum("b i j, b j d -> b i d", attn, v) return self.to_out(hidden_states)
out = self.reshape_batch_dim_to_heads(out)
return self.to_out(out) def _attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
)
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
class FeedForward(nn.Module): class FeedForward(nn.Module):

View File

@ -133,6 +133,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
if slice_size is not None and slice_size > self.config.attention_head_dim:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
self.mid_block.set_attention_slice(slice_size)
for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,

View File

@ -295,6 +295,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
super().__init__() super().__init__()
self.attention_type = attention_type self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet # there is always at least one resnet
@ -342,6 +343,21 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None): def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]): for attn, resnet in zip(self.attentions, self.resnets[1:]):
@ -457,6 +473,7 @@ class CrossAttnDownBlock2D(nn.Module):
attentions = [] attentions = []
self.attention_type = attention_type self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers): for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels in_channels = in_channels if i == 0 else out_channels
@ -497,6 +514,21 @@ class CrossAttnDownBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None): def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = () output_states = ()
@ -989,6 +1021,7 @@ class CrossAttnUpBlock2D(nn.Module):
attentions = [] attentions = []
self.attention_type = attention_type self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers): for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@ -1025,6 +1058,21 @@ class CrossAttnUpBlock2D(nn.Module):
else: else:
self.upsamplers = None self.upsamplers = None
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):

View File

@ -36,6 +36,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slice(None)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,

View File

@ -47,6 +47,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slice(None)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,

View File

@ -61,6 +61,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slice(None)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,

View File

@ -153,7 +153,6 @@ class PipelineFastTests(unittest.TestCase):
torch.manual_seed(0) torch.manual_seed(0)
config = CLIPTextConfig( config = CLIPTextConfig(
bos_token_id=0, bos_token_id=0,
chunk_size_feed_forward=0,
eos_token_id=2, eos_token_id=2,
hidden_size=32, hidden_size=32,
intermediate_size=37, intermediate_size=37,
@ -410,6 +409,38 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_attention_chunk(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
# make sure chunking the attention yields the same result
sd_pipe.enable_attention_slicing(slice_size=1)
generator = torch.Generator(device=device).manual_seed(0)
output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
def test_score_sde_ve_pipeline(self): def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet unet = self.dummy_uncond_unet
scheduler = ScoreSdeVeScheduler(tensor_format="pt") scheduler = ScoreSdeVeScheduler(tensor_format="pt")
@ -1045,6 +1076,46 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024]) expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_memory_chunking(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
).to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
# make attention efficient
pipe.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output_chunked = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image_chunked = output_chunked.images
mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 3.75 GB is allocated
assert mem_bytes < 3.75 * 10**9
# disable chunking
pipe.disable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image = output.images
# make sure that more than 3.75 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 3.75 * 10**9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
@slow @slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline(self): def test_stable_diffusion_img2img_pipeline(self):