[LoRA] Make sure LoRA can be disabled after it's run (#2128)
parent e92d43feb0 · commit f653ded7ed
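What the change does, in brief: LoRA cross-attention processors are nn.Modules, so once they have been attached via set_attn_processor they get registered under the attention module's `_modules`, and (as far as I can tell) assigning the plain, non-Module CrossAttnProcessor over them then failed with a TypeError. This commit makes CrossAttention.set_processor drop the stale entry first and adds a test exercising the two ways to turn LoRA off again. A rough usage sketch follows; the tiny UNet config, tensor shapes, and variable names are illustrative placeholders, only the diffusers calls that appear in the diff are assumed:

# Hedged usage sketch of what this PR enables (hypothetical tiny config and shapes).
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor

torch.manual_seed(0)
unet = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

# Attach one LoRA processor per attention layer, mirroring the test helper's key logic.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        hidden_size = list(reversed(unet.config.block_out_channels))[int(name[len("up_blocks.")])]
    else:  # down_blocks
        hidden_size = unet.config.block_out_channels[int(name[len("down_blocks.")])]
    lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)

sample = torch.randn(1, 4, 16, 16)
encoder_hidden_states = torch.randn(1, 8, 32)

with torch.no_grad():
    # Option 1: keep the LoRA processors but scale their contribution to zero for this call.
    unet(sample, 1, encoder_hidden_states, cross_attention_kwargs={"scale": 0.0})
    # Option 2: remove them entirely; this is what the set_processor fix makes possible.
    unet.set_attn_processor(CrossAttnProcessor())
    unet(sample, 1, encoder_hidden_states)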
diffusers/models/cross_attention.py:

@@ -17,9 +17,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import logging
 from ..utils.import_utils import is_xformers_available
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 if is_xformers_available():
     import xformers
     import xformers.ops
@@ -151,6 +155,16 @@ class CrossAttention(nn.Module):
         self.set_processor(processor)
 
     def set_processor(self, processor: "AttnProcessor"):
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
         self.processor = processor
 
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
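Why the pop is needed: once a LoRA processor (an nn.Module) has been assigned to self.processor, PyTorch registers it under self._modules, and nn.Module.__setattr__ then refuses to overwrite that name with a plain, non-Module processor. A minimal standalone sketch of that PyTorch behavior; the Holder class and the stand-in objects are illustrative, not diffusers code:

from torch import nn


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = nn.Linear(4, 4)  # stands in for a LoRA processor (an nn.Module)


holder = Holder()
print(list(holder._modules))  # ['processor'] -- registered as a child module

try:
    holder.processor = object()  # stands in for a plain, stateless attention processor
except TypeError as err:
    print(err)  # "cannot assign ... as child module 'processor' ..."

# What set_processor now does first: drop the registered (possibly trained) module ...
holder._modules.pop("processor")
# ... after which the plain processor can be assigned normally.
holder.processor = object()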
UNet2DConditionModel tests:

@@ -20,7 +20,7 @@ import unittest
 import torch
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -40,6 +40,34 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+def create_lora_layers(model):
+    lora_attn_procs = {}
+    for name in model.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        # add 1 to weights to mock trained weights
+        with torch.no_grad():
+            lora_attn_procs[name].to_q_lora.up.weight += 1
+            lora_attn_procs[name].to_k_lora.up.weight += 1
+            lora_attn_procs[name].to_v_lora.up.weight += 1
+            lora_attn_procs[name].to_out_lora.up.weight += 1
+
+    return lora_attn_procs
+
+
 class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
 
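A note on the `+= 1` lines in create_lora_layers: in the usual LoRA parameterization (which diffusers' LoRA layers follow, as far as I can tell), the up-projection is zero-initialized, so a freshly attached LoRA layer is a no-op. Bumping the up weights mocks "trained" weights so the LoRA path actually perturbs the output. A self-contained sketch with a hypothetical TinyLoRA layer, not the diffusers class:

import torch
from torch import nn


class TinyLoRA(nn.Module):
    # Hypothetical stand-in for a single LoRA linear layer.
    def __init__(self, features=8, rank=4):
        super().__init__()
        self.down = nn.Linear(features, rank, bias=False)
        self.up = nn.Linear(rank, features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # zero init: contributes nothing until trained

    def forward(self, x):
        return self.up(self.down(x))


lora = TinyLoRA()
x = torch.randn(2, 8)
print(lora(x).abs().max().item())  # 0.0 -- an untouched LoRA layer is a no-op

with torch.no_grad():
    lora.up.weight += 1  # mock "trained" weights, as the test helper does
print(lora(x).abs().max().item())  # now non-zero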
@@ -336,30 +364,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRACrossAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-            )
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)
 
         with torch.no_grad():
@@ -380,6 +385,33 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(CrossAttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
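On the cross_attention_kwargs={"scale": 0.0} call in test_lora_on_off: my understanding is that the LoRA processor blends its low-rank contribution into each projection as base + scale * lora, so a scale of 0.0 reproduces the base model exactly (the first assert), and restoring CrossAttnProcessor then gives the same result as never having attached LoRA (the second assert). A small sketch of that blending; the layer names are illustrative, not the diffusers implementation:

import torch
from torch import nn

torch.manual_seed(0)
features, rank = 16, 4

to_q = nn.Linear(features, features)                  # frozen base projection
to_q_lora_down = nn.Linear(features, rank, bias=False)
to_q_lora_up = nn.Linear(rank, features, bias=False)  # non-zero here, i.e. "trained"

x = torch.randn(2, features)
for scale in (1.0, 0.5, 0.0):
    query = to_q(x) + scale * to_q_lora_up(to_q_lora_down(x))
    print(scale, (query - to_q(x)).abs().max().item())  # difference shrinks to 0 at scale=0.0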