thank you linter

commit dc39061856
parent 6c5f83b19b
@@ -229,9 +229,9 @@ def load_lora(name, lora_on_disk):
         elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
-            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
+            print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
-            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
+            raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")

         with torch.no_grad():
             module.weight.copy_(weight)
@@ -243,7 +243,7 @@ def load_lora(name, lora_on_disk):
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
+            raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")

     if keys_failed_to_match:
         print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
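
Both hunks make the same kind of fix: the messages in load_lora still interpolated key_diffusers, while key_lora is evidently the live name after a rename, so each f-string referenced an undefined variable that a linter would flag and that would raise NameError if the branch were ever reached. Below is a minimal sketch of that failure mode, using a hypothetical helper name (report_unsupported) rather than the webui code itself:

    import torch

    def report_unsupported(key_lora, sd_module):
        # Hypothetical helper, not part of the webui code: it only illustrates
        # why the rename matters inside the f-string.
        # Before the fix the message interpolated {key_diffusers}; with only
        # key_lora in scope, evaluating that f-string raises
        # NameError: name 'key_diffusers' is not defined.
        print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')

    # Illustrative key and a module type the Linear/Conv2d branches above do not handle:
    report_unsupported("example_lora_key", torch.nn.Embedding(8, 8))
    # -> Lora layer example_lora_key matched a layer with unsupported type: Embedding

As reconstructed above, the raise in the first hunk sits directly after continue and so appears unreachable; this commit only corrects the variable name in both messages and does not touch that dead code.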