From c938679de7b87b4f14894d9f57fe0f40dd6e3c06 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Wed, 28 Sep 2022 22:14:13 -0300 Subject: [PATCH 1/3] Fix memory leak and reduce memory usage --- modules/codeformer_model.py | 6 ++++-- modules/devices.py | 3 ++- modules/extras.py | 2 ++ modules/gfpgan_model.py | 11 +++++------ modules/processing.py | 33 ++++++++++++++++++++++++++------- webui.py | 3 +++ 6 files changed, 42 insertions(+), 16 deletions(-) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 8fbdea249..2177291a7 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -89,7 +89,7 @@ def setup_codeformer(): output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output - torch.cuda.empty_cache() + devices.torch_gc() except Exception as error: print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) @@ -106,7 +106,9 @@ def setup_codeformer(): restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) if shared.opts.face_restoration_unload: - self.net.to(devices.cpu) + self.net = None + self.face_helper = None + devices.torch_gc() return restored_img diff --git a/modules/devices.py b/modules/devices.py index 07bb23397..df63dd88e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,4 +1,5 @@ import torch +import gc # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility from modules import errors @@ -17,8 +18,8 @@ def get_optimal_device(): return cpu - def torch_gc(): + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/modules/extras.py b/modules/extras.py index 9a825530f..38b861675 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -98,6 +98,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs.append(image) + devices.torch_gc() + return outputs, plaintext_to_html(info), '' diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 44c5dc6ca..b1288f0ca 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -49,6 +49,7 @@ def gfpgan(): def gfpgan_fix_faces(np_image): + global loaded_gfpgan_model model = gfpgan() np_image_bgr = np_image[:, :, ::-1] @@ -56,7 +57,9 @@ def gfpgan_fix_faces(np_image): np_image = gfpgan_output_bgr[:, :, ::-1] if shared.opts.face_restoration_unload: - model.gfpgan.to(devices.cpu) + del model + loaded_gfpgan_model = None + devices.torch_gc() return np_image @@ -83,11 +86,7 @@ def setup_gfpgan(): return "GFPGAN" def restore(self, np_image): - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - return np_image + return gfpgan_fix_faces(np_image) shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: diff --git a/modules/processing.py b/modules/processing.py index 4ecdfcd2d..de5cda793 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, lowvram from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state @@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): + for n in range(p.n_iter): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if state.interrupted: break @@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + del samples_ddim + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + + devices.torch_gc() + if opts.filter_nsfw: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): + for i, x_sample in enumerate(x_samples_ddim): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: + if p.restore_faces: + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - devices.torch_gc() - x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - state.nextjob() + del x_samples_ddim + devices.torch_gc() + + state.nextjob() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 @@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask + del x + devices.torch_gc() + return samples diff --git a/webui.py b/webui.py index c70a11c7c..b61a318db 100644 --- a/webui.py +++ b/webui.py @@ -22,7 +22,10 @@ import modules.txt2img import modules.img2img import modules.swinir as swinir import modules.sd_models +from torch.nn.functional import silu +import ldm +ldm.modules.diffusionmodules.model.nonlinearity = silu modules.codeformer_model.setup_codeformer() modules.gfpgan_model.setup_gfpgan() From c2d5b29040132c171bc4d77f1f63da972306f22c Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Thu, 29 Sep 2022 01:14:54 -0300 Subject: [PATCH 2/3] Move silu to sd_hijack --- modules/sd_hijack.py | 12 +++--------- webui.py | 3 --- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bfbd07f9a..4bc58fa2b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -12,6 +12,7 @@ from ldm.util import default from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +from torch.nn.functional import silu # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion @@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -245,11 +238,12 @@ class StableDiffusionModelHijack: m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model + ldm.modules.diffusionmodules.model.nonlinearity = silu + if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def flatten(el): diff --git a/webui.py b/webui.py index b61a318db..c70a11c7c 100644 --- a/webui.py +++ b/webui.py @@ -22,10 +22,7 @@ import modules.txt2img import modules.img2img import modules.swinir as swinir import modules.sd_models -from torch.nn.functional import silu -import ldm -ldm.modules.diffusionmodules.model.nonlinearity = silu modules.codeformer_model.setup_codeformer() modules.gfpgan_model.setup_gfpgan() From 82380d9ac18614c87bebba1b4cfd4b147cc76a18 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 4 Oct 2022 22:28:50 -0300 Subject: [PATCH 3/3] Removing parts no longer needed to fix vram --- modules/devices.py | 3 +-- modules/processing.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 6db4e57c9..0158b11fc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ import contextlib import torch -import gc from modules import errors @@ -20,8 +19,8 @@ def get_optimal_device(): return cpu + def torch_gc(): - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/modules/processing.py b/modules/processing.py index e7f9c85e1..f666ba811 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -345,8 +345,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for n in range(p.n_iter): if state.interrupted: break @@ -395,22 +394,19 @@ def process_images(p: StableDiffusionProcessing) -> Processed: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") + devices.torch_gc() + x_sample = modules.face_restoration.restore_faces(x_sample) devices.torch_gc() - devices.torch_gc() - - with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -438,13 +434,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - del x_samples_ddim + del x_samples_ddim - devices.torch_gc() + devices.torch_gc() - state.nextjob() + state.nextjob() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0