Fix typo in AttnProcessor2_0 symbol (#2404)
Fix typo in AttnProcessor2_0 symbol.
This commit is contained in:
parent
07547dfacd
commit
780b3a4f8c
|
@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
|
||||||
```Python
|
```Python
|
||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
from diffusers.models.cross_attention import AttnProccesor2_0
|
from diffusers.models.cross_attention import AttnProcessor2_0
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
|
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
|
||||||
pipe.unet.set_attn_processor(AttnProccesor2_0())
|
pipe.unet.set_attn_processor(AttnProcessor2_0())
|
||||||
|
|
||||||
prompt = "a photo of an astronaut riding a horse on mars"
|
prompt = "a photo of an astronaut riding a horse on mars"
|
||||||
image = pipe(prompt).images[0]
|
image = pipe(prompt).images[0]
|
||||||
|
|
|
@ -99,10 +99,10 @@ class CrossAttention(nn.Module):
|
||||||
self.to_out.append(nn.Dropout(dropout))
|
self.to_out.append(nn.Dropout(dropout))
|
||||||
|
|
||||||
# set attention processor
|
# set attention processor
|
||||||
# We use the AttnProccesor2_0 by default when torch2.x is used which uses
|
# We use the AttnProcessor2_0 by default when torch2.x is used which uses
|
||||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||||
if processor is None:
|
if processor is None:
|
||||||
processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
|
processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
|
||||||
self.set_processor(processor)
|
self.set_processor(processor)
|
||||||
|
|
||||||
def set_use_memory_efficient_attention_xformers(
|
def set_use_memory_efficient_attention_xformers(
|
||||||
|
@ -466,10 +466,10 @@ class XFormersCrossAttnProcessor:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class AttnProccesor2_0:
|
class AttnProcessor2_0:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
batch_size, sequence_length, inner_dim = hidden_states.shape
|
batch_size, sequence_length, inner_dim = hidden_states.shape
|
||||||
|
|
Loading…
Reference in New Issue