medvram support for SD3
This commit is contained in:
parent
a65dd315ad
commit
a8fba9af35
|
@ -1,9 +1,12 @@
|
|||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from modules import devices, shared
|
||||
|
||||
module_in_gpu = None
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
|
||||
|
||||
def send_everything_to_cpu():
|
||||
global module_in_gpu
|
||||
|
@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
(sd_model, 'depth_model'),
|
||||
(sd_model, 'embedder'),
|
||||
(sd_model, 'model'),
|
||||
(sd_model, 'embedder'),
|
||||
]
|
||||
|
||||
is_sdxl = hasattr(sd_model, 'conditioner')
|
||||
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
||||
|
||||
if is_sdxl:
|
||||
if hasattr(sd_model, 'medvram_fields'):
|
||||
to_remain_in_cpu = sd_model.medvram_fields()
|
||||
elif is_sdxl:
|
||||
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
||||
elif is_sd2:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
||||
|
@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
setattr(obj, field, module)
|
||||
|
||||
# register hooks for those the first three models
|
||||
if is_sdxl:
|
||||
if hasattr(sd_model.cond_stage_model, "medvram_modules"):
|
||||
for module in sd_model.cond_stage_model.medvram_modules():
|
||||
if isinstance(module, ModuleWithParent):
|
||||
parent = module.parent
|
||||
module = module.module
|
||||
else:
|
||||
parent = None
|
||||
|
||||
if module:
|
||||
module.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
if parent:
|
||||
parents[module] = parent
|
||||
|
||||
elif is_sdxl:
|
||||
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
||||
elif is_sd2:
|
||||
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
if sd_model.depth_model:
|
||||
if hasattr(sd_model, 'depth_model'):
|
||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
if sd_model.embedder:
|
||||
if hasattr(sd_model, 'embedder'):
|
||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
if use_medvram:
|
||||
|
|
|
@ -492,7 +492,6 @@ class MMDiT(nn.Module):
|
|||
device = None,
|
||||
):
|
||||
super().__init__()
|
||||
print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
|
|
|
@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
|
|||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.tensor([[0]], device=devices.device) # XXX
|
||||
|
||||
def medvram_modules(self):
|
||||
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||
|
||||
|
||||
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
||||
def __init__(self, inner_model, sigmas):
|
||||
|
@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
|
|||
return self.cond_stage_model(batch)
|
||||
|
||||
def apply_model(self, x, t, cond):
|
||||
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
||||
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
||||
|
||||
def decode_first_stage(self, latent):
|
||||
latent = self.latent_format.process_out(latent)
|
||||
|
@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
|
|||
|
||||
def create_denoiser(self):
|
||||
return SD3Denoiser(self, self.model.model_sampling.sigmas)
|
||||
|
||||
def medvram_fields(self):
|
||||
return [
|
||||
(self, 'first_stage_model'),
|
||||
(self, 'cond_stage_model'),
|
||||
(self, 'model'),
|
||||
]
|
||||
|
|
|
@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
|
|||
else:
|
||||
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
|
||||
try:
|
||||
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
|
||||
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
|
||||
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
|
||||
timestep = torch.max(sigma).to(dtype=int)
|
||||
completed_ratio = (999 - timestep) / 1000
|
||||
|
|
Loading…
Reference in New Issue