diff --git a/modules/lowvram.py b/modules/lowvram.py
index 45701046b..00aad477b 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,12 @@
+from collections import namedtuple
+
 import torch
 from modules import devices, shared
 
 module_in_gpu = None
 cpu = torch.device("cpu")
+ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=[None])
 
 
 def send_everything_to_cpu():
     global module_in_gpu
@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
         (sd_model, 'depth_model'),
         (sd_model, 'embedder'),
         (sd_model, 'model'),
-        (sd_model, 'embedder'),
     ]
 
     is_sdxl = hasattr(sd_model, 'conditioner')
     is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
 
-    if is_sdxl:
+    if hasattr(sd_model, 'medvram_fields'):
+        to_remain_in_cpu = sd_model.medvram_fields()
+    elif is_sdxl:
         to_remain_in_cpu.append((sd_model, 'conditioner'))
     elif is_sd2:
         to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
         setattr(obj, field, module)
 
     # register hooks for the first three models
-    if is_sdxl:
+    if hasattr(sd_model.cond_stage_model, "medvram_modules"):
+        for module in sd_model.cond_stage_model.medvram_modules():
+            if isinstance(module, ModuleWithParent):
+                parent = module.parent
+                module = module.module
+            else:
+                parent = None
+
+            if module:
+                module.register_forward_pre_hook(send_me_to_gpu)
+
+            if parent:
+                parents[module] = parent
+
+    elif is_sdxl:
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
     elif is_sd2:
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
-    if sd_model.depth_model:
+    if hasattr(sd_model, 'depth_model'):
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
-    if sd_model.embedder:
+    if hasattr(sd_model, 'embedder'):
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
 
     if use_medvram:
diff --git a/modules/models/sd3/mmdit.py b/modules/models/sd3/mmdit.py
index 5ec73c054..4d2b85551 100644
--- a/modules/models/sd3/mmdit.py
+++ b/modules/models/sd3/mmdit.py
@@ -492,7 +492,6 @@ class MMDiT(nn.Module):
         device = None,
     ):
         super().__init__()
-        print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
         self.dtype = dtype
         self.learn_sigma = learn_sigma
         self.in_channels = in_channels
diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py
index 146ddf2e2..309a7f863 100644
--- a/modules/models/sd3/sd3_model.py
+++ b/modules/models/sd3/sd3_model.py
@@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
     def encode_embedding_init_text(self, init_text, nvpt):
         return torch.tensor([[0]], device=devices.device)  # XXX
 
+    def medvram_modules(self):
+        return [self.clip_g, self.clip_l, self.t5xxl]
+
 
 class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
     def __init__(self, inner_model, sigmas):
@@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
         return self.cond_stage_model(batch)
 
     def apply_model(self, x, t, cond):
-        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+        return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
 
     def decode_first_stage(self, latent):
         latent = self.latent_format.process_out(latent)
@@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
 
     def create_denoiser(self):
         return SD3Denoiser(self, self.model.model_sampling.sigmas)
+
+    def medvram_fields(self):
+        return [
+            (self, 'first_stage_model'),
+            (self, 'cond_stage_model'),
+            (self, 'model'),
+        ]
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index b584b68a9..c060cccb2 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
     else:
         # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
         try:
-            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
+            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
         except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
             timestep = torch.max(sigma).to(dtype=int)
         completed_ratio = (999 - timestep) / 1000
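The lowvram changes above replace the hard-coded SDXL/SD2 branches with a duck-typed opt-in: a model exposes medvram_fields() to name the (object, attribute) pairs that setup_for_low_vram temporarily detaches while the rest of the model is sent to the GPU, and its cond-stage model exposes medvram_modules() to name the submodules that receive the send_me_to_gpu forward pre-hook, optionally wrapped in ModuleWithParent when the hook should move a parent module instead of the child alone. A minimal sketch of how another model class could opt in, mirroring SD3Inferencer/SD3Cond from this patch; the class and attribute names (MyCond, MyInferencer, text_encoder, token_embedder) are illustrative assumptions, not part of the patch:

import torch.nn as nn

from modules.lowvram import ModuleWithParent


class MyCond(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_encoder = nn.Linear(8, 8)       # stand-in for a real encoder
        self.token_embedder = nn.Embedding(16, 8)

    def medvram_modules(self):
        # Each returned module gets the send_me_to_gpu forward pre-hook.
        # Wrapping one in ModuleWithParent records a parent in lowvram's
        # parents map, so the hook moves the parent to the GPU whenever
        # the child runs.
        return [
            self.text_encoder,
            ModuleWithParent(self.token_embedder, parent=self),
        ]


class MyInferencer(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_stage_model = nn.Identity()
        self.cond_stage_model = MyCond()
        self.model = nn.Identity()

    def medvram_fields(self):
        # (object, attribute) pairs that stay on the CPU while the rest of
        # the model is moved to the GPU by setup_for_low_vram.
        return [
            (self, 'first_stage_model'),
            (self, 'cond_stage_model'),
            (self, 'model'),
        ]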
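The sd_samplers_common.py change fixes a device mismatch: under medvram/lowvram the denoiser's sigma schedule can remain on the CPU while the incoming sigma tensor lives on the GPU, and subtracting tensors on different devices raises a RuntimeError; moving the schedule to sigma's device before the argmin lookup avoids this. A standalone illustration of the pattern, with synthetic tensors standing in for a real schedule:

import torch

sigmas = torch.linspace(0.03, 14.6, 1000)        # schedule resident on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sigma = torch.tensor([7.1, 7.1], device=device)  # per-sample sigmas from the sampler

# Without .to(sigma.device), `sigmas - torch.max(sigma)` fails with
# "Expected all tensors to be on the same device" whenever the devices differ.
timestep = torch.argmin(torch.abs(sigmas.to(sigma.device) - torch.max(sigma)))
completed_ratio = (999 - timestep) / 1000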