Fix typo in AttnProcessor2_0 symbol (#2404)

Fix typo in AttnProcessor2_0 symbol.
Pedro Cuenca 2023-02-17 21:21:18 +01:00 committed by GitHub
parent 07547dfacd
commit 780b3a4f8c
2 changed files with 6 additions and 6 deletions


@@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ```Python
 import torch
 from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProccesor2_0
+from diffusers.models.cross_attention import AttnProcessor2_0
 
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
-pipe.unet.set_attn_processor(AttnProccesor2_0())
+pipe.unet.set_attn_processor(AttnProcessor2_0())
 
 prompt = "a photo of an astronaut riding a horse on mars"
 image = pipe(prompt).images[0]
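As a usage note beyond the diff: `AttnProcessor2_0` hard-requires PyTorch 2.0 (see the `ImportError` guard in the second file below), so a caller can mirror the library's own capability check before opting in. A minimal sketch, assuming only the `hasattr` test this commit itself relies on:

```Python
import torch.nn.functional as F

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Same capability check the library uses when picking its default processor:
# F.scaled_dot_product_attention only exists on PyTorch 2.0+.
if hasattr(F, "scaled_dot_product_attention"):
    pipe.unet.set_attn_processor(AttnProcessor2_0())
```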


@@ -99,10 +99,10 @@ class CrossAttention(nn.Module):
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        # We use the AttnProccesor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         if processor is None:
-            processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
@@ -466,10 +466,10 @@ class XFormersCrossAttnProcessor:
         return hidden_states
 
 
-class AttnProccesor2_0:
+class AttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, inner_dim = hidden_states.shape
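For context on the comment in the hunk above: `AttnProcessor2_0` delegates the attention math to `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a fused Flash / memory-efficient kernel when one is available. A standalone sketch of that call; the tensor shapes are illustrative assumptions, not values from this diff:

```Python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, tokens, head_dim) chosen arbitrarily for the sketch.
query = torch.randn(2, 8, 77, 40)
key = torch.randn(2, 8, 77, 40)
value = torch.randn(2, 8, 77, 40)

# On PyTorch 2.0+ this dispatches to a native Flash / memory-efficient attention kernel.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 77, 40])
```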