re-integrated tiling option as a UI element
This commit is contained in:
parent
f91d0c3d19
commit
a8a58dbac7
|
@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
|
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
|
||||||
is_inpaint = mode == 1
|
is_inpaint = mode == 1
|
||||||
is_loopback = mode == 2
|
is_loopback = mode == 2
|
||||||
is_upscale = mode == 3
|
is_upscale = mode == 3
|
||||||
|
@ -37,6 +37,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
use_GFPGAN=use_GFPGAN,
|
use_GFPGAN=use_GFPGAN,
|
||||||
|
tiling=tiling,
|
||||||
init_images=[image],
|
init_images=[image],
|
||||||
mask=mask,
|
mask=mask,
|
||||||
mask_blur=mask_blur,
|
mask_blur=mask_blur,
|
||||||
|
|
|
@ -9,6 +9,7 @@ import numpy as np
|
||||||
from PIL import Image, ImageFilter, ImageOps
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import modules.sd_hijack
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
@ -28,7 +29,7 @@ def torch_gc():
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
self.outpath_samples: str = outpath_samples
|
self.outpath_samples: str = outpath_samples
|
||||||
self.outpath_grids: str = outpath_grids
|
self.outpath_grids: str = outpath_grids
|
||||||
|
@ -44,6 +45,7 @@ class StableDiffusionProcessing:
|
||||||
self.width: int = width
|
self.width: int = width
|
||||||
self.height: int = height
|
self.height: int = height
|
||||||
self.use_GFPGAN: bool = use_GFPGAN
|
self.use_GFPGAN: bool = use_GFPGAN
|
||||||
|
self.tiling: bool = tiling
|
||||||
self.do_not_save_samples: bool = do_not_save_samples
|
self.do_not_save_samples: bool = do_not_save_samples
|
||||||
self.do_not_save_grid: bool = do_not_save_grid
|
self.do_not_save_grid: bool = do_not_save_grid
|
||||||
self.extra_generation_params: dict = extra_generation_params
|
self.extra_generation_params: dict = extra_generation_params
|
||||||
|
@ -110,6 +112,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||||
|
|
||||||
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
|
|
||||||
comments = []
|
comments = []
|
||||||
|
|
||||||
if type(prompt) == list:
|
if type(prompt) == list:
|
||||||
|
|
|
@ -49,6 +49,8 @@ class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
comments = []
|
||||||
dir_mtime = None
|
dir_mtime = None
|
||||||
|
layers = None
|
||||||
|
circular_enabled = False
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dirname, model):
|
def load_textual_inversion_embeddings(self, dirname, model):
|
||||||
mt = os.path.getmtime(dirname)
|
mt = os.path.getmtime(dirname)
|
||||||
|
@ -105,6 +107,24 @@ class StableDiffusionModelHijack:
|
||||||
if cmd_opts.opt_split_attention:
|
if cmd_opts.opt_split_attention:
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||||
|
|
||||||
|
def flatten(el):
|
||||||
|
flattened = [flatten(children) for children in el.children()]
|
||||||
|
res = [el]
|
||||||
|
for c in flattened:
|
||||||
|
res += c
|
||||||
|
return res
|
||||||
|
|
||||||
|
self.layers = flatten(m)
|
||||||
|
|
||||||
|
def apply_circular(self, enable):
|
||||||
|
if self.circular_enabled == enable:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.circular_enabled = enable
|
||||||
|
|
||||||
|
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
||||||
|
layer.padding_mode = 'circular' if enable else 'zeros'
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
|
|
|
@ -30,8 +30,6 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
|
||||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||||
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
|
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance")
|
parser.add_argument("--opt-split-attention", action='store_true', help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance")
|
||||||
parser.add_argument("--tiling", action='store_true', help="causes the model to generate images that can be tiled")
|
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
|
|
|
@ -6,7 +6,7 @@ import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
|
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
|
||||||
p = StableDiffusionProcessingTxt2Img(
|
p = StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||||
|
@ -21,7 +21,8 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
use_GFPGAN=use_GFPGAN
|
use_GFPGAN=use_GFPGAN,
|
||||||
|
tiling=tiling,
|
||||||
)
|
)
|
||||||
|
|
||||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||||
|
|
|
@ -155,6 +155,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
|
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
|
||||||
|
tiling = gr.Checkbox(label='Tiling', value=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
||||||
|
@ -195,6 +196,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||||
steps,
|
steps,
|
||||||
sampler_index,
|
sampler_index,
|
||||||
use_gfpgan,
|
use_gfpgan,
|
||||||
|
tiling,
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
@ -256,6 +258,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
|
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
|
||||||
|
tiling = gr.Checkbox(label='Tiling', value=False)
|
||||||
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
|
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -339,6 +342,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||||
mask_blur,
|
mask_blur,
|
||||||
inpainting_fill,
|
inpainting_fill,
|
||||||
use_gfpgan,
|
use_gfpgan,
|
||||||
|
tiling,
|
||||||
switch_mode,
|
switch_mode,
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
|
5
webui.py
5
webui.py
|
@ -140,11 +140,6 @@ try:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if cmd_opts.tiling:
|
|
||||||
# this has to be done before the model is loaded
|
|
||||||
modules.sd_hijack.add_circular_option_to_conv_2d()
|
|
||||||
|
|
||||||
sd_config = OmegaConf.load(cmd_opts.config)
|
sd_config = OmegaConf.load(cmd_opts.config)
|
||||||
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
||||||
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
|
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
|
||||||
|
|
Loading…
Reference in New Issue