alternate implementation for unet forward replacement that does not depend on hijack being applied
This commit is contained in:
parent
af5f0734c9
commit
ac02216e54
|
@ -38,8 +38,11 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
|
||||||
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
|
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
|
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||||
|
|
||||||
|
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||||
|
|
||||||
|
|
||||||
def list_optimizers():
|
def list_optimizers():
|
||||||
|
|
|
@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
|
||||||
unet_options = []
|
unet_options = []
|
||||||
current_unet_option = None
|
current_unet_option = None
|
||||||
current_unet = None
|
current_unet = None
|
||||||
original_forward = None
|
original_forward = None # not used, only left temporarily for compatibility
|
||||||
|
|
||||||
|
|
||||||
def list_unets():
|
def list_unets():
|
||||||
new_unets = script_callbacks.list_unets_callback()
|
new_unets = script_callbacks.list_unets_callback()
|
||||||
|
@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
def create_unet_forward(original_forward):
|
||||||
if current_unet is not None:
|
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||||
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
if current_unet is not None:
|
||||||
|
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
return original_forward(self, x, timesteps, context, *args, **kwargs)
|
return original_forward(self, x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
|
return UNetModel_forward
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue