From a8a58dbac7b205ae90664c3b249d60e4baa2855c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 5 Sep 2022 03:25:37 +0300
Subject: [PATCH] re-integrated tiling option as a UI element

---
 modules/img2img.py    |  3 ++-
 modules/processing.py |  6 +++++-
 modules/sd_hijack.py  | 20 ++++++++++++++++++++
 modules/shared.py     |  2 --
 modules/txt2img.py    |  5 +++--
 modules/ui.py         |  4 ++++
 webui.py              |  5 -----
 7 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/modules/img2img.py b/modules/img2img.py
index b1ef13267..e6707f960 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import modules.images as images
 import modules.scripts
 
-def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
+def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
     is_inpaint = mode == 1
     is_loopback = mode == 2
     is_upscale = mode == 3
@@ -37,6 +37,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
         width=width,
         height=height,
         use_GFPGAN=use_GFPGAN,
+        tiling=tiling,
         init_images=[image],
         mask=mask,
         mask_blur=mask_blur,
diff --git a/modules/processing.py b/modules/processing.py
index adc5d851b..a5b2afb97 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,6 +9,7 @@ import numpy as np
 from PIL import Image, ImageFilter, ImageOps
 import random
 
+import modules.sd_hijack
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.shared import opts, cmd_opts, state
@@ -28,7 +29,7 @@ def torch_gc():
 
 
 class StableDiffusionProcessing:
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples
         self.outpath_grids: str = outpath_grids
@@ -44,6 +45,7 @@ class StableDiffusionProcessing:
         self.width: int = width
         self.height: int = height
         self.use_GFPGAN: bool = use_GFPGAN
+        self.tiling: bool = tiling
         self.do_not_save_samples: bool = do_not_save_samples
         self.do_not_save_grid: bool = do_not_save_grid
         self.extra_generation_params: dict = extra_generation_params
@@ -110,6 +112,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     os.makedirs(p.outpath_samples, exist_ok=True)
     os.makedirs(p.outpath_grids, exist_ok=True)
 
+    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+
     comments = []
 
     if type(prompt) == list:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9779c30cc..2d26b5f71 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -49,6 +49,8 @@ class StableDiffusionModelHijack:
     fixes = None
     comments = []
     dir_mtime = None
+    layers = None
+    circular_enabled = False
 
     def load_textual_inversion_embeddings(self, dirname, model):
         mt = os.path.getmtime(dirname)
@@ -105,6 +107,24 @@ class StableDiffusionModelHijack:
         if cmd_opts.opt_split_attention:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
 
+        def flatten(el):
+            flattened = [flatten(children) for children in el.children()]
+            res = [el]
+            for c in flattened:
+                res += c
+            return res
+
+        self.layers = flatten(m)
+
+    def apply_circular(self, enable):
+        if self.circular_enabled == enable:
+            return
+
+        self.circular_enabled = enable
+
+        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
+            layer.padding_mode = 'circular' if enable else 'zeros'
+
 
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
diff --git a/modules/shared.py b/modules/shared.py
index 0722185d6..9e744f6ce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -30,8 +30,6 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
 parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
 parser.add_argument("--opt-split-attention", action='store_true', help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance")
-parser.add_argument("--tiling", action='store_true', help="causes the model to generate images that can be tiled")
-
 cmd_opts = parser.parse_args()
 
 cpu = torch.device("cpu")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index fb65a7f6d..dfce49fff 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
 from modules.ui import plaintext_to_html
 
 
-def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
+def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -21,7 +21,8 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
         cfg_scale=cfg_scale,
         width=width,
         height=height,
-        use_GFPGAN=use_GFPGAN
+        use_GFPGAN=use_GFPGAN,
+        tiling=tiling,
     )
 
     processed = modules.scripts.scripts_txt2img.run(p, *args)
diff --git a/modules/ui.py b/modules/ui.py
index 4119369e9..a2f1124ef 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -155,6 +155,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
 
                 with gr.Row():
                     use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+                    tiling = gr.Checkbox(label='Tiling', value=False)
 
                 with gr.Row():
                     batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
@@ -195,6 +196,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                 steps,
                 sampler_index,
                 use_gfpgan,
+                tiling,
                 batch_count,
                 batch_size,
                 cfg_scale,
@@ -256,6 +258,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
 
                 with gr.Row():
                     use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+                    tiling = gr.Checkbox(label='Tiling', value=False)
                     sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
 
                 with gr.Row():
@@ -339,6 +342,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                 mask_blur,
                 inpainting_fill,
                 use_gfpgan,
+                tiling,
                 switch_mode,
                 batch_count,
                 batch_size,
diff --git a/webui.py b/webui.py
index 6f4834821..dbc9dd541 100644
--- a/webui.py
+++ b/webui.py
@@ -140,11 +140,6 @@ try:
 except Exception:
     pass
 
-
-if cmd_opts.tiling:
-    # this has to be done before the model is loaded
-    modules.sd_hijack.add_circular_option_to_conv_2d()
-
 sd_config = OmegaConf.load(cmd_opts.config)
 shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
 shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
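
Note (editor's addition, not part of the patch): the Tiling checkbox works because PyTorch Conv2d layers carry a padding_mode attribute that can be flipped at runtime. Setting it to 'circular' makes every convolution wrap around the image borders, so the generated texture repeats seamlessly; setting it back to 'zeros' restores normal padding. Below is a minimal standalone sketch of that idea. The helper names and the toy network are illustrative assumptions, not code from the repository; only the padding_mode toggle mirrors the apply_circular() method added by the patch.

# Illustrative sketch only -- demonstrates the circular-padding toggle on a toy model.
import torch
import torch.nn as nn


def flatten_modules(module: nn.Module) -> list:
    # Recursively collect a module and all of its descendants,
    # similar in spirit to the flatten() helper added in the patch.
    result = [module]
    for child in module.children():
        result += flatten_modules(child)
    return result


def apply_circular(model: nn.Module, enable: bool) -> None:
    # Toggle seamless tiling by switching every Conv2d between
    # wrap-around ('circular') and ordinary zero padding ('zeros').
    for layer in flatten_modules(model):
        if isinstance(layer, nn.Conv2d):
            layer.padding_mode = 'circular' if enable else 'zeros'


if __name__ == '__main__':
    # Hypothetical stand-in for a diffusion UNet: any stack of Conv2d layers works.
    toy = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 3, 3, padding=1),
    )
    apply_circular(toy, True)
    x = torch.randn(1, 3, 16, 16)
    y = toy(x)
    print(y.shape)  # torch.Size([1, 3, 16, 16]); edges now wrap instead of being zero-padded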