medvram support for SD3
This commit is contained in:
parent
a65dd315ad
commit
a8fba9af35
|
@ -1,9 +1,12 @@
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import devices, shared
|
from modules import devices, shared
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
|
|
||||||
|
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
|
||||||
|
|
||||||
def send_everything_to_cpu():
|
def send_everything_to_cpu():
|
||||||
global module_in_gpu
|
global module_in_gpu
|
||||||
|
@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||||
(sd_model, 'depth_model'),
|
(sd_model, 'depth_model'),
|
||||||
(sd_model, 'embedder'),
|
(sd_model, 'embedder'),
|
||||||
(sd_model, 'model'),
|
(sd_model, 'model'),
|
||||||
(sd_model, 'embedder'),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
is_sdxl = hasattr(sd_model, 'conditioner')
|
is_sdxl = hasattr(sd_model, 'conditioner')
|
||||||
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
||||||
|
|
||||||
if is_sdxl:
|
if hasattr(sd_model, 'medvram_fields'):
|
||||||
|
to_remain_in_cpu = sd_model.medvram_fields()
|
||||||
|
elif is_sdxl:
|
||||||
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
||||||
elif is_sd2:
|
elif is_sd2:
|
||||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
||||||
|
@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||||
setattr(obj, field, module)
|
setattr(obj, field, module)
|
||||||
|
|
||||||
# register hooks for those the first three models
|
# register hooks for those the first three models
|
||||||
if is_sdxl:
|
if hasattr(sd_model.cond_stage_model, "medvram_modules"):
|
||||||
|
for module in sd_model.cond_stage_model.medvram_modules():
|
||||||
|
if isinstance(module, ModuleWithParent):
|
||||||
|
parent = module.parent
|
||||||
|
module = module.module
|
||||||
|
else:
|
||||||
|
parent = None
|
||||||
|
|
||||||
|
if module:
|
||||||
|
module.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
|
||||||
|
if parent:
|
||||||
|
parents[module] = parent
|
||||||
|
|
||||||
|
elif is_sdxl:
|
||||||
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
||||||
elif is_sd2:
|
elif is_sd2:
|
||||||
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||||
if sd_model.depth_model:
|
if hasattr(sd_model, 'depth_model'):
|
||||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
if sd_model.embedder:
|
if hasattr(sd_model, 'embedder'):
|
||||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
|
||||||
if use_medvram:
|
if use_medvram:
|
||||||
|
|
|
@ -492,7 +492,6 @@ class MMDiT(nn.Module):
|
||||||
device = None,
|
device = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.learn_sigma = learn_sigma
|
self.learn_sigma = learn_sigma
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
|
@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
|
||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
return torch.tensor([[0]], device=devices.device) # XXX
|
return torch.tensor([[0]], device=devices.device) # XXX
|
||||||
|
|
||||||
|
def medvram_modules(self):
|
||||||
|
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||||
|
|
||||||
|
|
||||||
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
||||||
def __init__(self, inner_model, sigmas):
|
def __init__(self, inner_model, sigmas):
|
||||||
|
@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
|
||||||
return self.cond_stage_model(batch)
|
return self.cond_stage_model(batch)
|
||||||
|
|
||||||
def apply_model(self, x, t, cond):
|
def apply_model(self, x, t, cond):
|
||||||
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
||||||
|
|
||||||
def decode_first_stage(self, latent):
|
def decode_first_stage(self, latent):
|
||||||
latent = self.latent_format.process_out(latent)
|
latent = self.latent_format.process_out(latent)
|
||||||
|
@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
|
||||||
|
|
||||||
def create_denoiser(self):
|
def create_denoiser(self):
|
||||||
return SD3Denoiser(self, self.model.model_sampling.sigmas)
|
return SD3Denoiser(self, self.model.model_sampling.sigmas)
|
||||||
|
|
||||||
|
def medvram_fields(self):
|
||||||
|
return [
|
||||||
|
(self, 'first_stage_model'),
|
||||||
|
(self, 'cond_stage_model'),
|
||||||
|
(self, 'model'),
|
||||||
|
]
|
||||||
|
|
|
@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
|
||||||
else:
|
else:
|
||||||
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
|
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
|
||||||
try:
|
try:
|
||||||
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
|
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
|
||||||
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
|
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
|
||||||
timestep = torch.max(sigma).to(dtype=int)
|
timestep = torch.max(sigma).to(dtype=int)
|
||||||
completed_ratio = (999 - timestep) / 1000
|
completed_ratio = (999 - timestep) / 1000
|
||||||
|
|
Loading…
Reference in New Issue