diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index 109b4c2c5..bf6930e96 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule): super().__init__(net, weights) self.weight = weights.w.get("diff") + self.ex_bias = weights.w.get("diff_b") def calc_updown(self, orig_weight): output_shape = self.weight.shape updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + if self.ex_bias is not None: + ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None - return self.finalize_updown(updown, orig_weight, output_shape) + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)