diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx index b79e0fda..665ac6ce 100644 --- a/docs/source/en/optimization/torch2.0.mdx +++ b/docs/source/en/optimization/torch2.0.mdx @@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl ```Python import torch from diffusers import StableDiffusionPipeline - from diffusers.models.cross_attention import AttnProccesor2_0 + from diffusers.models.cross_attention import AttnProcessor2_0 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") - pipe.unet.set_attn_processor(AttnProccesor2_0()) + pipe.unet.set_attn_processor(AttnProcessor2_0()) prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 0580ee2c..822405b8 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -99,10 +99,10 @@ class CrossAttention(nn.Module): self.to_out.append(nn.Dropout(dropout)) # set attention processor - # We use the AttnProccesor2_0 by default when torch2.x is used which uses + # We use the AttnProcessor2_0 by default when torch2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention if processor is None: - processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() self.set_processor(processor) def set_use_memory_efficient_attention_xformers( @@ -466,10 +466,10 @@ class XFormersCrossAttnProcessor: return hidden_states -class AttnProccesor2_0: +class AttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, inner_dim = hidden_states.shape