Merge branch 'hypertile-in-sample' into dev
commit f85b74763d
@@ -174,5 +174,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
 - LyCORIS - KohakuBlueleaf
 - Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
+- Hypertile - tfernd - https://github.com/tfernd/HyperTile
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
@@ -1,10 +1,13 @@
 """
 Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
-Warn : The patch works well only if the input image has a width and height that are multiples of 128
-Author : @tfernd Github : https://github.com/tfernd/HyperTile
+Warn: The patch works well only if the input image has a width and height that are multiples of 128
+Original author: @tfernd Github: https://github.com/tfernd/HyperTile
 """

 from __future__ import annotations

+import functools
+from dataclasses import dataclass
+
 from typing import Callable
 from typing_extensions import Literal
@@ -18,6 +21,19 @@ import random

 from einops import rearrange


+@dataclass
+class HypertileParams:
+    depth = 0
+    layer_name = ""
+    tile_size: int = 0
+    swap_size: int = 0
+    aspect_ratio: float = 1.0
+    forward = None
+    enabled = False
+
+
 # TODO add SD-XL layers
 DEPTH_LAYERS = {
     0: [
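The dataclass above is the mutable per-layer state the new hook design relies on: the wrapped forward closes over one HypertileParams instance, so later option changes take effect without re-wrapping any layer. A minimal standalone illustration of that closure pattern (toy names only, not webui code):

class Params:                      # stand-in for HypertileParams
    enabled = False

def make_wrapper(params, forward):
    def wrapper(x):
        if not params.enabled:     # cheap pass-through when the feature is off
            return forward(x)
        return forward(x) * 2      # stand-in for the tiled attention path
    return wrapper

p = Params()
f = make_wrapper(p, lambda x: x)
print(f(3))        # 3 -- disabled, falls through to the original forward
p.enabled = True   # mutating the shared object reconfigures the already-installed wrapper
print(f(3))        # 6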
@@ -176,6 +192,7 @@ DEPTH_LAYERS_XL = {

 RNG_INSTANCE = random.Random()

+
 def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
     """
     Returns a random divisor of value that
@@ -193,10 +210,13 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:

     return ns[idx]

+
 def set_hypertile_seed(seed: int) -> None:
     RNG_INSTANCE.seed(seed)

-def largest_tile_size_available(width:int, height:int) -> int:
+
+@functools.cache
+def largest_tile_size_available(width: int, height: int) -> int:
     """
     Calculates the largest tile size available for a given width and height
     Tile size is always a power of 2
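random_divisor's body sits mostly outside these hunks; the sketch below is an assumed reconstruction from its docstring and the seeded RNG_INSTANCE above (not necessarily the committed code), shown only to make the contract concrete: the result divides value, and value // result never drops below min_value, so each tile edge keeps at least min_value latent pixels.

import random

RNG_INSTANCE = random.Random()

def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    min_value = min(min_value, value)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # tile edges >= min_value
    ns = [value // i for i in divisors[:max_options]]                      # candidate tile counts
    idx = RNG_INSTANCE.randint(0, len(ns) - 1) if len(ns) > 1 else 0
    return ns[idx]

RNG_INSTANCE.seed(0)
n = random_divisor(96, 16, 3)      # latent edge of 96, tiles at least 16 latent pixels wide
assert 96 % n == 0 and 96 // n >= 16
print(n)                           # one of 6, 4 or 3 (tile edges 16, 24 or 32)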
@@ -207,6 +227,7 @@ def largest_tile_size_available(width:int, height:int) -> int:
         largest_tile_size_available *= 2
     return largest_tile_size_available

+
 def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
     """
     Finds h and w such that h*w = hw and h/w = aspect_ratio
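The two lines above are the tail of the power-of-two search; a hedged sketch of the whole helper, inferred from its docstring (the loop header itself is not part of this hunk, so treat this as an assumption rather than the committed body):

import functools

@functools.cache
def largest_tile_size_available(width: int, height: int) -> int:
    tile = 1
    while width % (tile * 2) == 0 and height % (tile * 2) == 0:
        tile *= 2                  # keep doubling while both dimensions stay divisible
    return tile

print(largest_tile_size_available(1024, 768))   # 256 -- largest power of 2 dividing both
print(largest_tile_size_available(1000, 500))   # 4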
@@ -219,6 +240,7 @@ def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
     closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
     return closest_pair

+
 @cache
 def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
     """
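For orientation, the factoring these cached helpers perform, worked through for a common SD-1.5 resolution (illustrative numbers only):

# A 1024x768 image has a 128x96 latent, so the outermost (depth-0) attention layers see
# hw = 128 * 96 = 12288 tokens at aspect_ratio = 1024 / 768.
hw, aspect_ratio = 12288, 1024 / 768
h, w = 128, 96                      # the candidate pair the helpers should recover
assert h * w == hw
print(round(h / w, 4), round(aspect_ratio, 4))   # 1.3333 1.3333 -- exact match at this resolution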
@@ -240,44 +262,28 @@ def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
     w = int(w_candidate)
     return h, w

-@contextmanager
-def split_attention(
-    layer: nn.Module,
-    /,
-    aspect_ratio: float, # width/height
-    tile_size: int = 128, # 128 for VAE
-    swap_size: int = 1, # 1 for VAE
-    *,
-    disable: bool = False,
-    max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1
-    scale_depth: bool = True, # scale the tile-size depending on the depth
-    is_sdxl: bool = False, # is the model SD-XL
-):
-    # Hijacks AttnBlock from ldm and Attention from diffusers
-
-    if disable:
-        logging.info(f"Attention for {layer.__class__.__qualname__} not splitted")
-        yield
-        return
-
-    latent_tile_size = max(128, tile_size) // 8
-
-    def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
-        @wraps(forward)
+def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
+    @wraps(params.forward)
     def wrapper(*args, **kwargs):
+        if not params.enabled:
+            return params.forward(*args, **kwargs)
+
+        latent_tile_size = max(128, params.tile_size) // 8
         x = args[0]

         # VAE
         if x.ndim == 4:
             b, c, h, w = x.shape

-            nh = random_divisor(h, latent_tile_size, swap_size)
-            nw = random_divisor(w, latent_tile_size, swap_size)
+            nh = random_divisor(h, latent_tile_size, params.swap_size)
+            nw = random_divisor(w, latent_tile_size, params.swap_size)

             if nh * nw > 1:
                 x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles

-            out = forward(x, *args[1:], **kwargs)
+            out = params.forward(x, *args[1:], **kwargs)

             if nh * nw > 1:
                 out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
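The wrapper's whole trick is that pair of rearrange calls: split the batch into independent tiles before attention and stitch them back afterwards. A self-contained round-trip check of the 4-D (VAE) pattern used above, with toy shapes:

import torch
from einops import rearrange

x = torch.randn(1, 4, 64, 96)     # (batch, channels, latent height, latent width)
nh, nw = 2, 3                     # values random_divisor() might return for this shape
tiles = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)
print(tiles.shape)                # torch.Size([6, 4, 32, 32]) -- attention now runs per tile
merged = rearrange(tiles, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
print(torch.equal(x, merged))     # True -- the split/merge round-trips exactly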
@@ -285,19 +291,17 @@ def split_attention(
         # U-Net
         else:
             hw: int = x.size(1)
-            h, w = find_hw_candidates(hw, aspect_ratio)
-            assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
+            h, w = find_hw_candidates(hw, params.aspect_ratio)
+            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

-            factor = 2**depth if scale_depth else 1
-            nh = random_divisor(h, latent_tile_size * factor, swap_size)
-            nw = random_divisor(w, latent_tile_size * factor, swap_size)
+            factor = 2 ** params.depth if scale_depth else 1
+            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
+            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)

-            module._split_sizes_hypertile.append((nh, nw)) # type: ignore
-
             if nh * nw > 1:
                 x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

-            out = forward(x, *args[1:], **kwargs)
+            out = params.forward(x, *args[1:], **kwargs)

             if nh * nw > 1:
                 out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
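The U-Net branch does the same thing on flattened token sequences of shape (batch, h*w, channels). A quick standalone check of that split and its direct inverse (the committed merge path goes through an intermediate rearrange step that falls outside this hunk):

import torch
from einops import rearrange

h, w, c = 64, 96, 320
x = torch.randn(2, h * w, c)      # (batch, tokens, channels) as seen by U-Net self-attention
nh, nw = 2, 3
tiles = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
print(tiles.shape)                # torch.Size([12, 1024, 320])
back = rearrange(tiles, "(b nh nw) (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
print(torch.equal(x, back))       # True -- the tiling is lossless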
@@ -307,65 +311,38 @@ def split_attention(

     return wrapper

-    # Handle hijacking the forward method and recovering afterwards
-    try:
-        if is_sdxl:
-            layers = DEPTH_LAYERS_XL
-        else:
-            layers = DEPTH_LAYERS
-        for depth in range(max_depth + 1):
-            for layer_name, module in layer.named_modules():
-                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
-                    # print input shape for debugging
-                    logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
-                    # hijack
-                    module._original_forward_hypertile = module.forward
-                    module.forward = self_attn_forward(module.forward, depth, layer_name, module)
-                    module._split_sizes_hypertile = []
-        yield
-    finally:
-        for layer_name, module in layer.named_modules():
-            # remove hijack
-            if hasattr(module, "_original_forward_hypertile"):
-                if module._split_sizes_hypertile:
-                    logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})")
-                # recover
-                module.forward = module._original_forward_hypertile
-                del module._original_forward_hypertile
-                del module._split_sizes_hypertile
-
-
-def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
-    """
-    Returns context manager for VAE
-    """
-    enabled = opts.hypertile_split_vae_attn
-    swap_size = opts.hypertile_swap_size_vae
-    max_depth = opts.hypertile_max_depth_vae
-    tile_size_max = opts.hypertile_max_tile_vae
-    return split_attention(
-        model,
-        aspect_ratio=aspect_ratio,
-        tile_size=min(tile_size, tile_size_max),
-        swap_size=swap_size,
-        disable=not enabled,
-        max_depth=max_depth,
-        is_sdxl=False,
-    )
-
-
-def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool):
-    """
-    Returns context manager for U-Net
-    """
-    enabled = opts.hypertile_split_unet_attn
-    swap_size = opts.hypertile_swap_size_unet
-    max_depth = opts.hypertile_max_depth_unet
-    tile_size_max = opts.hypertile_max_tile_unet
-    return split_attention(
-        model,
-        aspect_ratio=aspect_ratio,
-        tile_size=min(tile_size, tile_size_max),
-        swap_size=swap_size,
-        disable=not enabled,
-        max_depth=max_depth,
-        is_sdxl=is_sdxl,
-    )
+
+def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
+    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
+
+    if hypertile_layers is None:
+        if not enable:
+            return
+
+        hypertile_layers = {}
+        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
+
+        for depth in range(4):
+            for layer_name, module in model.named_modules():
+                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
+                    params = HypertileParams()
+                    module.__webui_hypertile_params = params
+                    params.forward = module.forward
+                    params.depth = depth
+                    params.layer_name = layer_name
+                    module.forward = self_attn_forward(params)
+
+                    hypertile_layers[layer_name] = 1
+
+        model.__webui_hypertile_layers = hypertile_layers
+
+    aspect_ratio = width / height
+    tile_size = min(largest_tile_size_available(width, height), tile_size_max)
+
+    for layer_name, module in model.named_modules():
+        if layer_name in hypertile_layers:
+            params = module.__webui_hypertile_params
+
+            params.tile_size = tile_size
+            params.swap_size = swap_size
+            params.aspect_ratio = aspect_ratio
+            params.enabled = enable and params.depth <= max_depth
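For orientation, a hedged sketch of how this hook is driven; it assumes a running webui process where `hypertile` and a loaded `shared.sd_model` are importable, and the numeric values are purely illustrative:

import hypertile
from modules import shared

# seed the tile-count RNG so a given image seed always produces the same splits
hypertile.set_hypertile_seed(1234)

# the first call wraps every matching attention layer once and records them on the model;
# later calls only refresh the per-generation parameters (tile size, aspect ratio, enabled)
hypertile.hypertile_hook_model(
    shared.sd_model.model,        # the U-Net
    1024, 768,                    # generation width/height in pixels
    enable=True,
    tile_size_max=256,
    swap_size=3,
    max_depth=3,
    is_sdxl=shared.sd_model.is_sdxl,
)

# calling again with enable=False leaves the wrappers installed but sets params.enabled = False,
# so every wrapped forward falls straight through to the original attention
hypertile.hypertile_hook_model(shared.sd_model.model, 1024, 768, enable=False)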
@@ -0,0 +1,73 @@
+import hypertile
+from modules import scripts, script_callbacks, shared
+
+
+class ScriptHypertile(scripts.Script):
+    name = "Hypertile"
+
+    def title(self):
+        return self.name
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def process(self, p, *args):
+        hypertile.set_hypertile_seed(p.all_seeds[0])
+
+        configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
+
+    def before_hr(self, p, *args):
+        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
+
+
+def configure_hypertile(width, height, enable_unet=True):
+    hypertile.hypertile_hook_model(
+        shared.sd_model.first_stage_model,
+        width,
+        height,
+        swap_size=shared.opts.hypertile_swap_size_vae,
+        max_depth=shared.opts.hypertile_max_depth_vae,
+        tile_size_max=shared.opts.hypertile_max_tile_vae,
+        enable=shared.opts.hypertile_enable_vae,
+    )
+
+    hypertile.hypertile_hook_model(
+        shared.sd_model.model,
+        width,
+        height,
+        swap_size=shared.opts.hypertile_swap_size_unet,
+        max_depth=shared.opts.hypertile_max_depth_unet,
+        tile_size_max=shared.opts.hypertile_max_tile_unet,
+        enable=enable_unet,
+        is_sdxl=shared.sd_model.is_sdxl
+    )
+
+
+def on_ui_settings():
+    import gradio as gr
+
+    options = {
+        "hypertile_explanation": shared.OptionHTML("""
+    <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
+    resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
+    benefit.
+"""),
+
+        "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
+        "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
+        "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
+        "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+        "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+
+        "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
+        "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
+        "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+        "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+    }
+
+    for name, opt in options.items():
+        opt.section = ('hypertile', "Hypertile")
+        shared.opts.add_option(name, opt)
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
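process() reseeds the module RNG with the image seed before every job, which is what keeps the random tile splits reproducible from run to run. A standalone illustration of why that matters (a toy stand-in for the random_divisor draws made during sampling, not webui code):

import random

rng = random.Random()

def splits_for(seed):
    rng.seed(seed)
    return [rng.randint(1, 4) for _ in range(6)]   # stands in for per-layer tile-count draws

print(splits_for(1234) == splits_for(1234))   # True  -- same seed, same splits, same image
print(splits_for(1234) == splits_for(4321))   # False (almost surely) -- different seed, different splits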
@@ -24,7 +24,6 @@ from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.paths as paths
 import modules.face_restoration
-from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae
 import modules.images as images
 import modules.styles
 import modules.sd_models as sd_models
@@ -861,8 +860,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 p.comment(comment)

             p.extra_generation_params.update(model_hijack.extra_generation_params)
-            set_hypertile_seed(p.seed)
-            # add batch size + hypertile status to information to reproduce the run
+
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"

@@ -874,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             else:
                 if opts.sd_vae_decode_method != 'Full':
                     p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-                with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts):
                 x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

             x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1141,25 +1138,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-        aspect_ratio = self.width / self.height
         x = self.rng.next()
-        tile_size = largest_tile_size_available(self.width, self.height)
-        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
-            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
         del x

         if not self.enable_hr:
             return samples
         devices.torch_gc()

         if self.latent_scale_mode is None:
-            with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
             decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
         else:
             decoded_samples = None

         with sd_models.SkipWritingToConfig():
             sd_models.reload_model_weights(info=self.hr_checkpoint_info)

         return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)

     def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1244,17 +1239,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

         if self.scripts is not None:
             self.scripts.before_hr(self)
-        tile_size = largest_tile_size_available(target_width, target_height)
-        aspect_ratio = self.width / self.height
-        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
-            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
+
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

         self.sampler = None
         devices.torch_gc()
-        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+
         decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

         self.is_hr_pass = False
@@ -1532,10 +1524,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         if self.initial_noise_multiplier != 1.0:
             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
             x *= self.initial_noise_multiplier
-        aspect_ratio = self.width / self.height
-        tile_size = largest_tile_size_available(self.width, self.height)
-        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
-            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
+
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

         if self.mask is not None:
@@ -201,14 +201,6 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
-    "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
-    "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
-    "hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
-    "hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
-    "hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
-    "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
-    "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
-    "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
 }))

 options_templates.update(options_section(('compatibility', "Compatibility"), {