Implement Hypertile
Co-Authored-By: Kieran Hunt <kph@hotmail.ca>
This commit is contained in:
parent
294f8a514f
commit
b29fc6d4de
|
@ -0,0 +1,333 @@
|
||||||
|
"""
|
||||||
|
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
|
||||||
|
Warn : The patch works well only if the input image has a width and height that are multiples of 128
|
||||||
|
Author : @tfernd Github : https://github.com/tfernd/HyperTile
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Callable
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import wraps, cache
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch.nn as nn
|
||||||
|
import random
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
# TODO add SD-XL layers
|
||||||
|
DEPTH_LAYERS = {
|
||||||
|
0: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.1.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.2.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.9.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.10.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.11.1.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 VAE
|
||||||
|
"decoder.mid_block.attentions.0",
|
||||||
|
],
|
||||||
|
1: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.6.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
2: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
3: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"middle_block.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
# XL layers, thanks for GitHub@gel-crabs for the help
|
||||||
|
DEPTH_LAYERS_XL = {
|
||||||
|
0: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 VAE
|
||||||
|
"decoder.mid_block.attentions.0",
|
||||||
|
"decoder.mid.attn_1",
|
||||||
|
],
|
||||||
|
1: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.2.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.2.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.3.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.3.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.4.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.4.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.5.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.5.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.6.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.6.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.7.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.7.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.8.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.8.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.9.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.9.attn1",
|
||||||
|
],
|
||||||
|
2: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"middle_block.1.transformer_blocks.0.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.1.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.2.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.3.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.4.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.5.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.6.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.7.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.8.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.9.attn1",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
RNG_INSTANCE = random.Random()
|
||||||
|
|
||||||
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
|
"""
|
||||||
|
Returns a random divisor of value that
|
||||||
|
x * min_value <= value
|
||||||
|
if max_options is 1, the behavior is deterministic
|
||||||
|
"""
|
||||||
|
min_value = min(min_value, value)
|
||||||
|
|
||||||
|
# All big divisors of value (inclusive)
|
||||||
|
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
|
||||||
|
|
||||||
|
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
|
||||||
|
|
||||||
|
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
|
||||||
|
|
||||||
|
return ns[idx]
|
||||||
|
|
||||||
|
def set_hypertile_seed(seed: int) -> None:
|
||||||
|
RNG_INSTANCE.seed(seed)
|
||||||
|
|
||||||
|
def largest_tile_size_available(width:int, height:int) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the largest tile size available for a given width and height
|
||||||
|
Tile size is always a power of 2
|
||||||
|
"""
|
||||||
|
gcd = math.gcd(width, height)
|
||||||
|
largest_tile_size_available = 1
|
||||||
|
while gcd % (largest_tile_size_available * 2) == 0:
|
||||||
|
largest_tile_size_available *= 2
|
||||||
|
return largest_tile_size_available
|
||||||
|
|
||||||
|
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
We check all possible divisors of hw and return the closest to the aspect ratio
|
||||||
|
"""
|
||||||
|
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
|
||||||
|
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
|
||||||
|
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
|
||||||
|
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
|
||||||
|
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
|
||||||
|
return closest_pair
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
"""
|
||||||
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||||
|
# find h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
if h * w != hw:
|
||||||
|
w_candidate = hw / h
|
||||||
|
# check if w is an integer
|
||||||
|
if not w_candidate.is_integer():
|
||||||
|
h_candidate = hw / w
|
||||||
|
# check if h is an integer
|
||||||
|
if not h_candidate.is_integer():
|
||||||
|
return iterative_closest_divisors(hw, aspect_ratio)
|
||||||
|
else:
|
||||||
|
h = int(h_candidate)
|
||||||
|
else:
|
||||||
|
w = int(w_candidate)
|
||||||
|
return h, w
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def split_attention(
|
||||||
|
layer: nn.Module,
|
||||||
|
/,
|
||||||
|
aspect_ratio: float, # width/height
|
||||||
|
tile_size: int = 128, # 128 for VAE
|
||||||
|
swap_size: int = 1, # 1 for VAE
|
||||||
|
*,
|
||||||
|
disable: bool = False,
|
||||||
|
max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1
|
||||||
|
scale_depth: bool = True, # scale the tile-size depending on the depth
|
||||||
|
is_sdxl: bool = False, # is the model SD-XL
|
||||||
|
):
|
||||||
|
# Hijacks AttnBlock from ldm and Attention from diffusers
|
||||||
|
|
||||||
|
if disable:
|
||||||
|
logging.info(f"Attention for {layer.__class__.__qualname__} not splitted")
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
latent_tile_size = max(128, tile_size) // 8
|
||||||
|
|
||||||
|
def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
|
||||||
|
@wraps(forward)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
x = args[0]
|
||||||
|
|
||||||
|
# VAE
|
||||||
|
if x.ndim == 4:
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
nh = random_divisor(h, latent_tile_size, swap_size)
|
||||||
|
nw = random_divisor(w, latent_tile_size, swap_size)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
|
||||||
|
|
||||||
|
out = forward(x, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
|
||||||
|
|
||||||
|
# U-Net
|
||||||
|
else:
|
||||||
|
hw: int = x.size(1)
|
||||||
|
h, w = find_hw_candidates(hw, aspect_ratio)
|
||||||
|
assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
|
||||||
|
|
||||||
|
factor = 2**depth if scale_depth else 1
|
||||||
|
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
||||||
|
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
||||||
|
|
||||||
|
module._split_sizes_hypertile.append((nh, nw)) # type: ignore
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
|
|
||||||
|
out = forward(x, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||||
|
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
# Handle hijacking the forward method and recovering afterwards
|
||||||
|
try:
|
||||||
|
if is_sdxl:
|
||||||
|
layers = DEPTH_LAYERS_XL
|
||||||
|
else:
|
||||||
|
layers = DEPTH_LAYERS
|
||||||
|
for depth in range(max_depth + 1):
|
||||||
|
for layer_name, module in layer.named_modules():
|
||||||
|
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
|
||||||
|
# print input shape for debugging
|
||||||
|
logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
|
||||||
|
# hijack
|
||||||
|
module._original_forward_hypertile = module.forward
|
||||||
|
module.forward = self_attn_forward(module.forward, depth, layer_name, module)
|
||||||
|
module._split_sizes_hypertile = []
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for layer_name, module in layer.named_modules():
|
||||||
|
# remove hijack
|
||||||
|
if hasattr(module, "_original_forward_hypertile"):
|
||||||
|
if module._split_sizes_hypertile:
|
||||||
|
logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})")
|
||||||
|
# recover
|
||||||
|
module.forward = module._original_forward_hypertile
|
||||||
|
del module._original_forward_hypertile
|
||||||
|
del module._split_sizes_hypertile
|
|
@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.paths as paths
|
import modules.paths as paths
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
|
from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available
|
||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.sd_models as sd_models
|
import modules.sd_models as sd_models
|
||||||
|
@ -799,17 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
unet_object = p.sd_model.model
|
|
||||||
vae_model = p.sd_model.first_stage_model
|
|
||||||
try:
|
|
||||||
from hyper_tile import split_attention, flush
|
|
||||||
except (ImportError, ModuleNotFoundError): # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df
|
|
||||||
split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager
|
|
||||||
flush = lambda: None
|
|
||||||
import random
|
|
||||||
saved_rng_state = random.getstate()
|
|
||||||
random.seed(p.seed) # hyper_tile uses random, so we need to seed it
|
|
||||||
|
|
||||||
with torch.no_grad(), p.sd_model.ema_scope():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||||
|
@ -871,29 +861,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
p.comment(comment)
|
p.comment(comment)
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
set_hypertile_seed(p.seed)
|
||||||
|
# add batch size + hypertile status to information to reproduce the run
|
||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||||
# get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height
|
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||||
gcd = math.gcd(p.width, p.height)
|
|
||||||
largest_tile_size_available = 1
|
|
||||||
while gcd % (largest_tile_size_available * 2) == 0:
|
|
||||||
largest_tile_size_available *= 2
|
|
||||||
aspect_ratio = p.width / p.height
|
|
||||||
with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
|
|
||||||
with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn):
|
|
||||||
flush()
|
|
||||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
|
||||||
|
|
||||||
if getattr(samples_ddim, 'already_decoded', False):
|
if getattr(samples_ddim, 'already_decoded', False):
|
||||||
x_samples_ddim = samples_ddim
|
x_samples_ddim = samples_ddim
|
||||||
else:
|
else:
|
||||||
if opts.sd_vae_decode_method != 'Full':
|
if opts.sd_vae_decode_method != 'Full':
|
||||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||||
with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
|
with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
flush()
|
|
||||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||||
|
@ -1000,7 +981,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
random.setstate(saved_rng_state)
|
|
||||||
if not p.disable_extra_networks and p.extra_network_data:
|
if not p.disable_extra_networks and p.extra_network_data:
|
||||||
extra_networks.deactivate(p, p.extra_network_data)
|
extra_networks.deactivate(p, p.extra_network_data)
|
||||||
|
|
||||||
|
@ -1161,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
aspect_ratio = self.width / self.height
|
||||||
x = self.rng.next()
|
x = self.rng.next()
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
tile_size = largest_tile_size_available(self.width, self.height)
|
||||||
|
with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
devices.torch_gc()
|
||||||
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||||
del x
|
del x
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
if self.latent_scale_mode is None:
|
if self.latent_scale_mode is None:
|
||||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
decoded_samples = None
|
decoded_samples = None
|
||||||
|
|
||||||
with sd_models.SkipWritingToConfig():
|
with sd_models.SkipWritingToConfig():
|
||||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||||
|
|
||||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||||
|
@ -1186,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
self.is_hr_pass = True
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = self.hr_upscale_to_y
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
|
@ -1264,18 +1244,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
if self.scripts is not None:
|
if self.scripts is not None:
|
||||||
self.scripts.before_hr(self)
|
self.scripts.before_hr(self)
|
||||||
|
tile_size = largest_tile_size_available(target_width, target_height)
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return decoded_samples
|
return decoded_samples
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -1550,8 +1531,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
if self.initial_noise_multiplier != 1.0:
|
if self.initial_noise_multiplier != 1.0:
|
||||||
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
||||||
x *= self.initial_noise_multiplier
|
x *= self.initial_noise_multiplier
|
||||||
|
aspect_ratio = self.width / self.height
|
||||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
tile_size = largest_tile_size_available(self.width, self.height)
|
||||||
|
with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
|
||||||
|
devices.torch_gc()
|
||||||
|
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
samples = samples * self.nmask + self.init_latent * self.mask
|
samples = samples * self.nmask + self.init_latent * self.mask
|
||||||
|
|
Loading…
Reference in New Issue