Add general forward method for all modules.
This commit is contained in:
parent
a06dab8d7a
commit
18ca987c92
|
@ -3,6 +3,10 @@ import os
|
|||
from collections import namedtuple
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import sd_models, cache, errors, hashes, shared
|
||||
|
||||
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
|
||||
|
@ -115,6 +119,29 @@ class NetworkModule:
|
|||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.ops = None
|
||||
self.extra_kwargs = {}
|
||||
if isinstance(self.sd_module, nn.Conv2d):
|
||||
self.ops = F.conv2d
|
||||
self.extra_kwargs = {
|
||||
'stride': self.sd_module.stride,
|
||||
'padding': self.sd_module.padding
|
||||
}
|
||||
elif isinstance(self.sd_module, nn.Linear):
|
||||
self.ops = F.linear
|
||||
elif isinstance(self.sd_module, nn.LayerNorm):
|
||||
self.ops = F.layer_norm
|
||||
self.extra_kwargs = {
|
||||
'normalized_shape': self.sd_module.normalized_shape,
|
||||
'eps': self.sd_module.eps
|
||||
}
|
||||
elif isinstance(self.sd_module, nn.GroupNorm):
|
||||
self.ops = F.group_norm
|
||||
self.extra_kwargs = {
|
||||
'num_groups': self.sd_module.num_groups,
|
||||
'eps': self.sd_module.eps
|
||||
}
|
||||
|
||||
self.dim = None
|
||||
self.bias = weights.w.get("bias")
|
||||
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
||||
|
@ -155,5 +182,10 @@ class NetworkModule:
|
|||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, y):
|
||||
"""A general forward implementation for all modules"""
|
||||
if self.ops is None:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
updown, ex_bias = self.calc_updown(self.sd_module.weight)
|
||||
return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
|
||||
|
||||
|
|
|
@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||
self.network_current_names = wanted_names
|
||||
|
||||
|
||||
def network_forward(module, input, original_forward):
|
||||
def network_forward(org_module, input, original_forward):
|
||||
"""
|
||||
Old way of applying Lora by executing operations during layer's forward.
|
||||
Stacking many loras this way results in big performance degradation.
|
||||
"""
|
||||
|
||||
if len(loaded_networks) == 0:
|
||||
return original_forward(module, input)
|
||||
return original_forward(org_module, input)
|
||||
|
||||
input = devices.cond_cast_unet(input)
|
||||
|
||||
network_restore_weights_from_backup(module)
|
||||
network_reset_cached_weight(module)
|
||||
network_restore_weights_from_backup(org_module)
|
||||
network_reset_cached_weight(org_module)
|
||||
|
||||
y = original_forward(module, input)
|
||||
y = original_forward(org_module, input)
|
||||
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
network_layer_name = getattr(org_module, 'network_layer_name', None)
|
||||
for lora in loaded_networks:
|
||||
module = lora.modules.get(network_layer_name, None)
|
||||
if module is None:
|
||||
|
|
Loading…
Reference in New Issue