Make sub-quadratic the default for MPS
This commit is contained in:
parent
abfa4ad8bc
commit
87dd685224
|
@ -95,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||||
class SdOptimizationSubQuad(SdOptimization):
|
class SdOptimizationSubQuad(SdOptimization):
|
||||||
name = "sub-quadratic"
|
name = "sub-quadratic"
|
||||||
cmd_opt = "opt_sub_quad_attention"
|
cmd_opt = "opt_sub_quad_attention"
|
||||||
priority = 10
|
|
||||||
|
@property
|
||||||
|
def priority(self):
|
||||||
|
return 1000 if shared.device.type == 'mps' else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||||
|
@ -121,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def priority(self):
|
def priority(self):
|
||||||
return 1000 if not torch.cuda.is_available() else 10
|
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||||
|
|
Loading…
Reference in New Issue