From bd4da4474bef5c9c1f690c62b971704ee73d2860 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 13 Aug 2023 02:27:39 +0800 Subject: [PATCH] Add extra norm module into built-in lora ext refer to LyCORIS 1.9.0.dev6 add new option and module for training norm layer (Which is reported to be good for style) --- extensions-builtin/Lora/network.py | 7 +- extensions-builtin/Lora/network_norm.py | 29 +++++++++ extensions-builtin/Lora/networks.py | 64 ++++++++++++++++--- .../Lora/scripts/lora_script.py | 16 +++++ 4 files changed, 105 insertions(+), 11 deletions(-) create mode 100644 extensions-builtin/Lora/network_norm.py diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 0a18d69eb..b7b890618 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -133,7 +133,7 @@ class NetworkModule: return 1.0 - def finalize_updown(self, updown, orig_weight, output_shape): + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) @@ -145,7 +145,10 @@ class NetworkModule: if orig_weight.size().numel() == updown.size().numel(): updown = updown.reshape(orig_weight.shape) - return updown * self.calc_scale() * self.multiplier() + if ex_bias is None: + ex_bias = 0 + + return updown * self.calc_scale() * self.multiplier(), ex_bias * self.multiplier() def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py new file mode 100644 index 000000000..dab8b684d --- /dev/null +++ b/extensions-builtin/Lora/network_norm.py @@ -0,0 +1,29 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + print("NetworkModuleNorm") + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7e3415acb..74cefe435 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -7,6 +7,7 @@ import network_hada import network_ia3 import network_lokr import network_full +import network_norm import torch from typing import Union @@ -19,6 +20,7 @@ module_types = [ network_ia3.ModuleTypeIa3(), network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), + network_norm.ModuleTypeNorm(), ] @@ -31,6 +33,8 @@ suffix_conversion = { "resnets": { "conv1": "in_layers_2", "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", "time_emb_proj": "emb_layers_1", "conv_shortcut": "skip_connection", } @@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No purge_networks_from_memory() -def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): weights_backup = getattr(self, "network_weights_backup", None) + bias_backup = getattr(self, "network_bias_backup", None) - if weights_backup is None: + if weights_backup is None and bias_backup is None: return - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) - else: - self.weight.copy_(weights_backup) + if weights_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + if bias_backup is not None: + self.bias.copy_(bias_backup) -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): """ Applies the currently selected set of networks to the weights of torch layer self. If weights already have this particular set of networks applied, does nothing. @@ -294,6 +303,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_weights_backup = weights_backup + bias_backup = getattr(self, "network_bias_backup", None) + if bias_backup is None and getattr(self, 'bias', None) is not None: + bias_backup = self.bias.to(devices.cpu, copy=True) + self.network_bias_backup = bias_backup + if current_names != wanted_names: network_restore_weights_from_backup(self) @@ -301,13 +315,15 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn module = net.modules.get(network_layer_name, None) if module is not None and hasattr(self, 'weight'): with torch.no_grad(): - updown = module.calc_updown(self.weight) + updown, ex_bias = module.calc_updown(self.weight) if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) self.weight += updown + if getattr(self, 'bias', None) is not None: + self.bias += ex_bias continue module_q = net.modules.get(network_layer_name + "_q_proj", None) @@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs): return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) +def network_GroupNorm_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.GroupNorm_forward_before_network) + + network_apply_weights(self) + + return torch.nn.GroupNorm_forward_before_network(self, input) + + +def network_GroupNorm_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs) + + +def network_LayerNorm_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.LayerNorm_forward_before_network) + + network_apply_weights(self) + + return torch.nn.LayerNorm_forward_before_network(self, input) + + +def network_LayerNorm_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs) + + def network_MultiheadAttention_forward(self, *args, **kwargs): network_apply_weights(self) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 6ab8b6e7c..dc307f8cc 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'): if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict +if not hasattr(torch.nn, 'GroupNorm_forward_before_network'): + torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward + +if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'): + torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict + +if not hasattr(torch.nn, 'LayerNorm_forward_before_network'): + torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward + +if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'): + torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict + if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward @@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict torch.nn.Conv2d.forward = networks.network_Conv2d_forward torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict +torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward +torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict +torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward +torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict