* Fix regression introduced in #2448
* Style.
parent fa6d52d594
commit 2ea1da89ab
@@ -150,7 +150,7 @@ class UNet2DConditionLoadersMixin:
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            if is_safetensors_available():
+            if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
                 if weight_name is None:
                     weight_name = LORA_WEIGHT_NAME_SAFE
                 try:
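The updated condition consults the caller's weight_name before picking the safetensors branch: the branch is taken when no weight name was given and safetensors is installed, or when a ".safetensors" filename is explicitly requested. A minimal sketch of that condition in isolation, with the hypothetical helper name prefers_safetensors and a plain boolean standing in for is_safetensors_available():

def prefers_safetensors(weight_name, safetensors_available):
    # Mirrors the condition added above: use safetensors when no weight name
    # was given and the backend is installed, or when a ".safetensors" file
    # is explicitly requested.
    return (safetensors_available and weight_name is None) or weight_name.endswith(".safetensors")

# With safetensors installed, an explicit torch filename now skips the safetensors branch:
print(prefers_safetensors("pytorch_lora_weights.bin", True))   # False
print(prefers_safetensors(None, True))                         # True
print(prefers_safetensors("custom_lora.safetensors", False))   # True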
@@ -445,6 +445,43 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_save_load_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch properly reloads with the filename passed directly
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
     def test_lora_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
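The hidden_size lookup in the new test keys off the attention-processor name alone, so it can be read in isolation. Below is a self-contained restating of just that mapping; hidden_size_for is a hypothetical helper, the block_out_channels value is illustrative, and the processor names only exemplify the naming pattern:

block_out_channels = (32, 64)  # illustrative config value, not taken from this diff

def hidden_size_for(name):
    # Same name-based lookup as the loop in the test above.
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected attention processor name: {name}")

print(hidden_size_for("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"))  # 32
print(hidden_size_for("up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"))    # 32
print(hidden_size_for("mid_block.attentions.0.transformer_blocks.0.attn1.processor"))      # 64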