Fix typo in AttnProcessor2_0 symbol (#2404)
Fix typo in AttnProcessor2_0 symbol.
This commit is contained in:
parent
07547dfacd
commit
780b3a4f8c
|
@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
|
|||
```Python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models.cross_attention import AttnProccesor2_0
|
||||
from diffusers.models.cross_attention import AttnProcessor2_0
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.set_attn_processor(AttnProccesor2_0())
|
||||
pipe.unet.set_attn_processor(AttnProcessor2_0())
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
image = pipe(prompt).images[0]
|
||||
|
|
|
@ -99,10 +99,10 @@ class CrossAttention(nn.Module):
|
|||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
# set attention processor
|
||||
# We use the AttnProccesor2_0 by default when torch2.x is used which uses
|
||||
# We use the AttnProcessor2_0 by default when torch2.x is used which uses
|
||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||
if processor is None:
|
||||
processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
|
||||
processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
|
@ -466,10 +466,10 @@ class XFormersCrossAttnProcessor:
|
|||
return hidden_states
|
||||
|
||||
|
||||
class AttnProccesor2_0:
|
||||
class AttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, inner_dim = hidden_states.shape
|
||||
|
|
Loading…
Reference in New Issue