Add option for float32 sampling with float16 UNet

This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half().
This commit is contained in:
brkirch 2023-01-24 23:51:45 -05:00
parent 48a15821de
commit 84d9ce30cb
8 changed files with 82 additions and 8 deletions

View File

@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
- (You)

View File

@ -2,6 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import devices
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
t_360 = self.n_Conv_0(t_359_padded)
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361)

View File

@ -79,6 +79,8 @@ cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False
def randn(seed, shape):

View File

@ -172,7 +172,8 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@ -203,7 +204,7 @@ class StableDiffusionProcessing:
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
@ -211,7 +212,7 @@ class StableDiffusionProcessing:
)
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@ -225,10 +226,10 @@ class StableDiffusionProcessing:
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image)
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
with devices.autocast(disable=devices.unet_needs_upcast):
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(shared.device)
image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

View File

@ -1,4 +1,8 @@
import torch
from packaging import version
from modules import devices
from modules.sd_hijack_utils import CondFunc
class TorchHijackForUnet:
@ -28,3 +32,28 @@ class TorchHijackForUnet:
th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
for y in cond.keys():
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
def forward(self, x):
if devices.unet_needs_upcast:
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
else:
return torch.nn.GELU.forward(self, x)
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

View File

@ -0,0 +1,28 @@
import importlib
class CondFunc:
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-2, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)

View File

@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
model.half()
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
devices.dtype_unet = model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@ -372,6 +380,8 @@ def load_model(checkpoint_info=None):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True
timer = Timer()

View File

@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")