diff --git a/modules/extras.py b/modules/extras.py index 6a0d5cb0f..1d9e64e55 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs.append(image) + devices.torch_gc() + return outputs, plaintext_to_html(info), '' diff --git a/modules/processing.py b/modules/processing.py index e567956ce..de818d5b9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,7 +11,7 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers +from modules import devices, prompt_parser, masking, sd_samplers, lowvram from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -382,6 +382,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + del samples_ddim + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + + devices.torch_gc() + if opts.filter_nsfw: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) @@ -426,6 +433,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) + del x_samples_ddim + + devices.torch_gc() + state.nextjob() p.color_corrections = None @@ -663,4 +674,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask + del x + devices.torch_gc() + return samples diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3fa062422..a6fa890c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,6 +5,7 @@ import traceback import torch import numpy as np from torch import einsum +from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared @@ -19,11 +20,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + ldm.modules.diffusionmodules.model.nonlinearity = silu + if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 9c079e578..ea4cfdfcd 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -92,14 +92,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_)