diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index bbcbbe7d6..eda42fa7d 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -2,7 +2,7 @@ name: Feature request about: Suggest an idea for this project title: '' -labels: '' +labels: 'suggestion' assignees: '' --- diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..935fedcf2 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @AUTOMATIC1111 diff --git a/README.md b/README.md index 561eb03d1..859a91b6f 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - CodeFormer, face restoration tool as an alternative to GFPGAN - RealESRGAN, neural network upscaler - ESRGAN, neural network upscaler with a lot of third party models - - SwinIR, neural network upscaler + - SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers - LDSR, Latent diffusion super resolution upscaling - Resizing aspect ratio options - Sampling method selection + - Adjust sampler eta values (noise multiplier) + - More advanced noise setting options - Interrupt processing at any time - 4GB video card support (also reports of 2GB working) - Correct seeds for batches @@ -67,6 +69,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) +- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. @@ -116,13 +119,17 @@ The documentation was moved from this README over to the project's [wiki](https: - CodeFormer - https://github.com/sczhou/CodeFormer - ESRGAN - https://github.com/xinntao/ESRGAN - SwinIR - https://github.com/JingyunLiang/SwinIR +- Swin2SR - https://github.com/mv-lab/swin2sr - LDSR - https://github.com/Hafiidz/latent-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. +- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) - Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator +- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch +- xformers - https://github.com/facebookresearch/xformers +- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. -- DeepDanbooru - interrogator for anime diffusors https://github.com/KichangKim/DeepDanbooru - (You) diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 2d82269fc..7636c4b33 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -16,7 +16,7 @@ contextMenuInit = function(){ oldMenu.remove() } - let tabButton = gradioApp().querySelector('button') + let tabButton = uiCurrentTab let baseStyle = window.getComputedStyle(tabButton) const contextMenu = document.createElement('nav') @@ -123,48 +123,53 @@ contextMenuInit = function(){ return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] } -initResponse = contextMenuInit() -appendContextMenuOption = initResponse[0] -removeContextMenuOption = initResponse[1] -addContextMenuEventListener = initResponse[2] +initResponse = contextMenuInit(); +appendContextMenuOption = initResponse[0]; +removeContextMenuOption = initResponse[1]; +addContextMenuEventListener = initResponse[2]; - -//Start example Context Menu Items -generateOnRepeatId = appendContextMenuOption('#txt2img_generate','Generate forever',function(){ - let genbutton = gradioApp().querySelector('#txt2img_generate'); - let interruptbutton = gradioApp().querySelector('#txt2img_interrupt'); - if(!interruptbutton.offsetParent){ - genbutton.click(); - } - clearInterval(window.generateOnRepeatInterval) - window.generateOnRepeatInterval = setInterval(function(){ +(function(){ + //Start example Context Menu Items + let generateOnRepeat = function(genbuttonid,interruptbuttonid){ + let genbutton = gradioApp().querySelector(genbuttonid); + let interruptbutton = gradioApp().querySelector(interruptbuttonid); if(!interruptbutton.offsetParent){ genbutton.click(); } - }, - 500)} -) - -cancelGenerateForever = function(){ - clearInterval(window.generateOnRepeatInterval) - let interruptbutton = gradioApp().querySelector('#txt2img_interrupt'); - if(interruptbutton.offsetParent){ - interruptbutton.click(); + clearInterval(window.generateOnRepeatInterval) + window.generateOnRepeatInterval = setInterval(function(){ + if(!interruptbutton.offsetParent){ + genbutton.click(); + } + }, + 500) } -} -appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) -appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#txt2img_generate','Generate forever',function(){ + generateOnRepeat('#txt2img_generate','#txt2img_interrupt'); + }) + appendContextMenuOption('#img2img_generate','Generate forever',function(){ + generateOnRepeat('#img2img_generate','#img2img_interrupt'); + }) - -appendContextMenuOption('#roll','Roll three', - function(){ - let rollbutton = gradioApp().querySelector('#roll'); - setTimeout(function(){rollbutton.click()},100) - setTimeout(function(){rollbutton.click()},200) - setTimeout(function(){rollbutton.click()},300) + let cancelGenerateForever = function(){ + clearInterval(window.generateOnRepeatInterval) } -) + + appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) + + appendContextMenuOption('#roll','Roll three', + function(){ + let rollbutton = get_uiCurrentTabContent().querySelector('#roll'); + setTimeout(function(){rollbutton.click()},100) + setTimeout(function(){rollbutton.click()},200) + setTimeout(function(){rollbutton.click()},300) + } + ) +})(); //End example Context Menu Items onUiUpdate(function(){ diff --git a/javascript/hints.js b/javascript/hints.js index 8e352e94a..045f2d3c1 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -79,6 +79,8 @@ titles = { "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", + "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", + "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.", } diff --git a/launch.py b/launch.py index f42f557de..16627a032 100644 --- a/launch.py +++ b/launch.py @@ -104,6 +104,7 @@ def prepare_enviroment(): args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') xformers = '--xformers' in args deepdanbooru = '--deepdanbooru' in args + ngrok = '--ngrok' in args try: commit = run(f"{git} rev-parse HEAD").strip() @@ -127,13 +128,16 @@ def prepare_enviroment(): if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): if platform.system() == "Windows": - run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") + run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") elif platform.system() == "Linux": run_pip("install xformers", "xformers") if not is_installed("deepdanbooru") and deepdanbooru: run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") + if not is_installed("pyngrok") and ngrok: + run_pip("install pyngrok", "ngrok") + os.makedirs(dir_repos, exist_ok=True) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py deleted file mode 100644 index 498bc9d8f..000000000 --- a/modules/hypernetwork.py +++ /dev/null @@ -1,98 +0,0 @@ -import glob -import os -import sys -import traceback - -import torch - -from ldm.util import default -from modules import devices, shared -import torch -from torch import einsum -from einops import rearrange, repeat - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - self.load_state_dict(state_dict, strict=True) - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, filename): - self.filename = filename - self.name = os.path.splitext(os.path.basename(filename))[0] - self.layers = {} - - state_dict = torch.load(filename, map_location='cpu') - for size, sd in state_dict.items(): - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork(path) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py new file mode 100644 index 000000000..5608e7995 --- /dev/null +++ b/modules/hypernetworks/hypernetwork.py @@ -0,0 +1,294 @@ +import datetime +import glob +import html +import os +import sys +import traceback +import tqdm + +import torch + +from ldm.util import default +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict=None): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + if state_dict is not None: + self.load_state_dict(state_dict, strict=True) + else: + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() + + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, name=None, enable_sizes=None): + self.filename = None + self.name = name + self.layers = {} + self.step = 0 + self.sd_checkpoint = None + self.sd_checkpoint_name = None + + for size in enable_sizes or []: + self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + + def weights(self): + res = [] + + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + + return res + + def save(self, filename): + state_dict = {} + + for k, v in self.layers.items(): + state_dict[k] = (v[0].state_dict(), v[1].state_dict()) + + state_dict['step'] = self.step + state_dict['name'] = self.name + state_dict['sd_checkpoint'] = self.sd_checkpoint + state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + + torch.save(state_dict, filename) + + def load(self, filename): + self.filename = filename + if self.name is None: + self.name = os.path.splitext(os.path.basename(filename))[0] + + state_dict = torch.load(filename, map_location='cpu') + + for size, sd in state_dict.items(): + if type(size) == int: + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + self.name = state_dict.get('name', self.name) + self.step = state_dict.get('step', 0) + self.sd_checkpoint = state_dict.get('sd_checkpoint', None) + self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + + +def list_hypernetworks(path): + res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") + try: + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + + shared.loaded_hypernetwork = None + + +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + k = self.to_k(context_k) + v = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): + assert hypernetwork_name, 'embedding not selected' + + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + unload = shared.opts.unload_models_when_training + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + return hypernetwork, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + for i, (x, text, cond) in pbar: + hypernetwork.step = i + ititial_step + + if hypernetwork.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + cond = cond.to(devices.device) + x = x.to(devices.device) + loss = shared.sd_model(x.unsqueeze(0), cond)[0] + del x + del cond + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + hypernetwork.save(last_saved_file) + + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + + preview_text = text if preview_image_prompt == "" else preview_image_prompt + + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=preview_text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + checkpoint = sd_models.select_checkpoint() + + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.save(filename) + + return hypernetwork, filename + + diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py new file mode 100644 index 000000000..c67facbb7 --- /dev/null +++ b/modules/hypernetworks/ui.py @@ -0,0 +1,47 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared, devices +from modules.hypernetworks import hypernetwork + + +def create_hypernetwork(name, enable_sizes): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes]) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + sd_hijack.apply_optimizations() + diff --git a/modules/ngrok.py b/modules/ngrok.py new file mode 100644 index 000000000..7d03a6df5 --- /dev/null +++ b/modules/ngrok.py @@ -0,0 +1,15 @@ +from pyngrok import ngrok, conf, exception + + +def connect(token, port): + if token == None: + token = 'None' + conf.get_default().auth_token = token + try: + public_url = ngrok.connect(port).public_url + except exception.PyngrokNgrokError: + print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' + f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') + else: + print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' + 'You can use this link after the launch is complete.') diff --git a/modules/processing.py b/modules/processing.py index 50ba4fc5f..698b3069e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: + if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None: cnt = p.sampler.number_of_needed_noises(p) + if opts.eta_noise_seed_delta > 0: + torch.manual_seed(seed + opts.eta_noise_seed_delta) + for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) @@ -301,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, + "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, } generation_params.update(p.extra_generation_params) diff --git a/modules/safe.py b/modules/safe.py index 059174632..20be16a50 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -10,6 +10,7 @@ import torch import numpy import _codecs import zipfile +import re # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage @@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler): raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") +allowed_zip_names = ["archive/data.pkl", "archive/version"] +allowed_zip_names_re = re.compile(r"^archive/data/\d+$") + + +def check_zip_filenames(filename, names): + for name in names: + if name in allowed_zip_names: + continue + if allowed_zip_names_re.match(name): + continue + + raise Exception(f"bad file inside {filename}: {name}") + + def check_pt(filename): try: # new pytorch format is a zip file with zipfile.ZipFile(filename) as z: + check_zip_filenames(filename, z.namelist()) + with z.open('archive/data.pkl') as file: unpickler = RestrictedUnpickler(file) unpickler.load() diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 437acce4c..ac70f8767 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,8 +8,9 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts +from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -23,30 +24,37 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + if not invokeAI_mps_available and shared.device.type == 'mps': + print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print("Applying v1 cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + else: + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization.") + print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward def undo_optimizations(): + from modules.hypernetworks import hypernetwork + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward def get_target_prompt_token_count(token_count): - if token_count < 75: - return 75 - - return math.ceil(token_count / 10) * 10 + return math.ceil(max(token_count, 1) / 75) * 75 class StableDiffusionModelHijack: @@ -110,6 +118,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -127,7 +137,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.token_mults[ident] = mult def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id if opts.enable_emphasis: @@ -140,6 +149,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): fixes = [] remade_tokens = [] multipliers = [] + last_comma = -1 for tokens, (text, weight) in zip(tokenized, parsed): i = 0 @@ -148,13 +158,33 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + if embedding is None: remade_tokens.append(token) multipliers.append(weight) i += 1 else: emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) + iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [id_end] * rem + multipliers += [1.0] * rem + iteration += 1 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) @@ -162,10 +192,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): token_count = len(remade_tokens) prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) + 1 + tokens_to_add = prompt_target_length - len(remade_tokens) - remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add - multipliers = [1.0] + multipliers + [1.0] * tokens_to_add + remade_tokens = remade_tokens + [id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count @@ -260,29 +290,55 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): - - if opts.use_old_emphasis_implementation: + use_old = opts.use_old_emphasis_implementation + if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = [x[75:] for x in batch_multipliers] - target_token_count = get_target_prompt_token_count(token_count) + 2 + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers]) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(device) + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] - position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - - remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] - tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) @@ -290,7 +346,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 634fb4b24..79405525e 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ import math import sys import traceback +import importlib import torch from torch import einsum @@ -9,12 +10,12 @@ from ldm.util import default from einops import rearrange from modules import shared +from modules.hypernetworks import hypernetwork + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: import xformers.ops - import functorch - xformers._is_functorch_available = True shared.xformers_available = True except Exception: print("Cannot import xformers", file=sys.stderr) @@ -28,16 +29,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) - del context, x + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + del context, context_k, context_v, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in @@ -61,22 +56,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): return self.to_out(r2) -# taken from https://github.com/Doggettx/stable-diffusion +# taken from https://github.com/Doggettx/stable-diffusion and modified def split_cross_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) k_in *= self.scale @@ -128,18 +117,111 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) + +def check_for_psutil(): + try: + spec = importlib.util.find_spec('psutil') + return spec is not None + except ModuleNotFoundError: + return False + +invokeAI_mps_available = check_for_psutil() + +# -- Taken from https://github.com/invoke-ai/InvokeAI -- +if invokeAI_mps_available: + import psutil + mem_total_gb = psutil.virtual_memory().total // (1 << 30) + +def einsum_op_compvis(q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + +def einsum_op_slice_0(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + +def einsum_op_slice_1(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) + return r + +def einsum_op_mps_v1(q, k, v): + if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + return einsum_op_compvis(q, k, v) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + return einsum_op_slice_1(q, k, v, slice_size) + +def einsum_op_mps_v2(q, k, v): + if mem_total_gb > 8 and q.shape[1] <= 4096: + return einsum_op_compvis(q, k, v) + else: + return einsum_op_slice_0(q, k, v, 1) + +def einsum_op_tensor_mem(q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return einsum_op_slice_0(q, k, v, q.shape[0] // div) + return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + +def einsum_op_cuda(q, k, v): + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + # Divide factor of safety as there's copying and fragmentation + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + +def einsum_op(q, k, v): + if q.device.type == 'cuda': + return einsum_op_cuda(q, k, v) + + if q.device.type == 'mps': + if mem_total_gb >= 32: + return einsum_op_mps_v1(q, k, v) + return einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return einsum_op_tensor_mem(q, k, v, 32) + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) * self.scale + v = self.to_v(context_v) + del context, context_k, context_v, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + +# -- End of code from https://github.com/invoke-ai/InvokeAI -- + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2cdcd84f3..0a55b4c32 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -152,6 +152,10 @@ def load_model_weights(model, checkpoint_info): devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: + vae_file = shared.cmd_opts.vae_path + if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d168b938f..20309e06b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -57,7 +57,7 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive']) + hidden_img2img = set(opts.hide_samplers + ['PLMS']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] @@ -365,16 +365,26 @@ class KDiffusionSampler: else: sigmas = self.model_wrap.get_sigmas(steps) - noise = noise * sigmas[steps - t_enc - 1] - xi = x + noise - - extra_params_kwargs = self.initialize(p) - sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] + + extra_params_kwargs = self.initialize(p) + if 'sigma_min' in inspect.signature(self.func).parameters: + ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last + extra_params_kwargs['sigma_min'] = sigma_sched[-2] + if 'sigma_max' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigma_max'] = sigma_sched[0] + if 'n' in inspect.signature(self.func).parameters: + extra_params_kwargs['n'] = len(sigma_sched) - 1 + if 'sigma_sched' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigma_sched'] = sigma_sched + if 'sigmas' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigmas'] = sigma_sched self.model_wrap_cfg.init_latent = x - return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) + return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): steps = steps or p.steps diff --git a/modules/shared.py b/modules/shared.py index 5dfc344cc..c1092ff79 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,8 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, hypernetwork +from modules import sd_samplers +from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -36,6 +38,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) @@ -47,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") -parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") -parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) @@ -66,6 +70,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) @@ -81,10 +86,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +def reload_hypernetworks(): + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + + class State: skipped = False interrupted = False @@ -172,6 +184,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { @@ -215,6 +228,10 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), +})) + options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), @@ -225,6 +242,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), @@ -237,6 +255,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), })) options_templates.update(options_section(('ui', "User interface"), { @@ -260,6 +279,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) diff --git a/modules/swinir_model.py b/modules/swinir_model.py index fbd11f843..baa02e3d1 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -10,6 +10,7 @@ from tqdm import tqdm from modules import modelloader from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net +from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData precision_scope = ( @@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler): filename = path if filename is None or not os.path.exists(filename): return None - model = net( + if filename.endswith(".v2.pth"): + model = net2( upscale=scale, in_chans=3, img_size=64, window_size=8, img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler="nearest+conv", - resi_connection="3conv", - ) + resi_connection="1conv", + ) + params = None + else: + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + params = "params_ema" pretrained_model = torch.load(filename) - model.load_state_dict(pretrained_model["params_ema"], strict=True) + if params is not None: + model.load_state_dict(pretrained_model[params], strict=True) + else: + model.load_state_dict(pretrained_model, strict=True) if not cmd_opts.no_half: model = model.half() return model diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py new file mode 100644 index 000000000..0e28ae6ee --- /dev/null +++ b/modules/swinir_model_arch_v2.py @@ -0,0 +1,1017 @@ +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ +# Written by Conde and Choi et al. +# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + #assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(Swin2SR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_hf: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x_before = self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] + + else: + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcf772d2f..f61f40d30 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,14 +8,14 @@ from torchvision import transforms import random import tqdm -from modules import devices +from modules import devices, shared import re re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): self.placeholder_token = placeholder_token @@ -32,12 +32,15 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' + cond_model = shared.sd_model.cond_stage_model + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): - image = Image.open(path) - image = image.convert('RGB') - image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + try: + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + except Exception: + continue filename = os.path.basename(path) filename_tokens = os.path.splitext(filename)[0] @@ -52,7 +55,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) - self.dataset.append((init_latent, filename_tokens)) + if include_cond: + text = self.create_text(filename_tokens) + cond = cond_model([text]).to(devices.cpu) + else: + cond = None + + self.dataset.append((init_latent, filename_tokens, cond)) self.length = len(self.dataset) * repeats @@ -63,6 +72,12 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + def create_text(self, filename_tokens): + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + return text + def __len__(self): return self.length @@ -71,10 +86,7 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens = self.dataset[index] + x, filename_tokens, cond = self.dataset[index] - text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) - - return x, text + text = self.create_text(filename_tokens) + return x, text, cond diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d7efdef29..1a6727255 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -46,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) - img = Image.open(filename).convert("RGB") + try: + img = Image.open(filename).convert("RGB") + except Exception: + continue if shared.state.interrupted: break diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c64a45987..47a27faf2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -208,7 +208,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text) in pbar: + for i, (x, text, _) in pbar: embedding.step = i + ititial_step if embedding.step > end_step: @@ -236,10 +236,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward() optimizer.step() - epoch_num = embedding.step // epoch_len - epoch_step = embedding.step - (epoch_num * epoch_len) + 1 + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step - (epoch_num * len(ds)) + 1 - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') @@ -248,12 +248,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + preview_text = text if preview_image_prompt == "" else preview_image_prompt + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=text, + prompt=preview_text, steps=20, - height=training_height, - width=training_width, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) @@ -264,7 +266,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image image.save(last_saved_image) - last_saved_image += f", prompt: {text}" + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index f19ac5e02..70f47343e 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -23,6 +23,8 @@ def preprocess(*args): def train_embedding(*args): + assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + try: sd_hijack.undo_optimizations() diff --git a/modules/ui.py b/modules/ui.py index c9e8355bf..2b688e325 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,6 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui +import modules.hypernetworks.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -50,6 +51,11 @@ if not cmd_opts.share and not cmd_opts.listen: gradio.utils.version_check = lambda: None gradio.utils.get_local_ip_address = lambda: '127.0.0.1' +if cmd_opts.ngrok != None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860) + def gr_show(visible=True): return {"visible": visible, "__type__": "update"} @@ -311,7 +317,7 @@ def interrogate(image): def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image) + prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) return gr_show(True) if prompt is None else prompt @@ -428,7 +434,10 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=8): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) + with gr.Column(scale=1, elem_id="roll_col"): + sh = gr.Button(elem_id="sh", visible=True) with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -524,7 +533,7 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) with gr.Row(): - batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) @@ -549,15 +558,15 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -710,7 +719,7 @@ def create_ui(wrap_gradio_gpu_call): tiling = gr.Checkbox(label='Tiling', value=False) with gr.Row(): - batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) with gr.Group(): @@ -737,15 +746,15 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -961,7 +970,7 @@ def create_ui(wrap_gradio_gpu_call): extras_send_to_inpaint.click( fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_img2img", + _js="extract_image_from_gallery_inpaint", inputs=[result_images], outputs=[init_img_with_mask], ) @@ -1022,7 +1031,20 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create", variant='primary') + create_embedding = gr.Button(value="Create embedding", variant='primary') + + with gr.Group(): + gr.HTML(value="

Create a new hypernetwork

") + + new_hypernetwork_name = gr.Textbox(label="Name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') with gr.Group(): gr.HTML(value="

Preprocess images

") @@ -1047,6 +1069,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -1057,15 +1080,12 @@ def create_ui(wrap_gradio_gpu_call): num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + preview_image_prompt = gr.Textbox(label='Preview prompt', value="") with gr.Row(): - with gr.Column(scale=2): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_embedding = gr.Button(value="Train", variant='primary') + interrupt_training = gr.Button(value="Interrupt") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') + train_embedding = gr.Button(value="Train Embedding", variant='primary') with gr.Column(): progressbar = gr.HTML(elem_id="ti_progressbar") @@ -1091,6 +1111,19 @@ def create_ui(wrap_gradio_gpu_call): ] ) + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + run_preprocess.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", @@ -1124,6 +1157,27 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + preview_image_prompt, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + preview_image_prompt, ], outputs=[ ti_output, @@ -1137,6 +1191,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[], ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1290,6 +1345,7 @@ Requested path was: {f} shared.state.interrupt() settings_interface.gradio_ref.do_restart = True + restart_gradio.click( fn=request_restart, inputs=[], @@ -1331,7 +1387,7 @@ Requested path was: {f} with gr.Tabs() as tabs: for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid): + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() if os.path.exists(os.path.join(script_path, "notification.mp3")): diff --git a/requirements.txt b/requirements.txt index 81641d68f..a0d985ce7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ fairscale==0.4.4 fonts font-roboto gfpgan -gradio==3.4b3 +gradio==3.4.1 invisible-watermark numpy omegaconf @@ -23,4 +23,3 @@ resize-right torchdiffeq kornia lark -functorch diff --git a/requirements_versions.txt b/requirements_versions.txt index fec3e9d5b..2bbea40b4 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -2,7 +2,7 @@ transformers==4.19.2 diffusers==0.3.0 basicsr==1.4.2 gfpgan==1.3.8 -gradio==3.4b3 +gradio==3.4.1 numpy==1.23.3 Pillow==9.2.0 realesrgan==0.3.0 @@ -22,4 +22,3 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 -functorch==0.2.1 diff --git a/script.js b/script.js index cf9896053..9543cbe68 100644 --- a/script.js +++ b/script.js @@ -6,6 +6,10 @@ function get_uiCurrentTab() { return gradioApp().querySelector('.tabs button:not(.border-transparent)') } +function get_uiCurrentTabContent() { + return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])') +} + uiUpdateCallbacks = [] uiTabChangeCallbacks = [] let uiCurrentTab = null @@ -40,6 +44,25 @@ document.addEventListener("DOMContentLoaded", function() { mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) }); +/** + * Add a ctrl+enter as a shortcut to start a generation + */ + document.addEventListener('keydown', function(e) { + var handled = false; + if (e.key !== undefined) { + if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true; + } else if (e.keyCode !== undefined) { + if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true; + } + if (handled) { + button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + if (button) { + button.click(); + } + e.preventDefault(); + } +}) + /** * checks that a UI element is not in another hidden element or tab content */ diff --git a/scripts/loopback.py b/scripts/loopback.py index e90b58d46..d8c68af89 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -38,6 +38,7 @@ class Script(scripts.Script): grids = [] all_images = [] + original_init_image = p.init_images state.job_count = loops * batch_count initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] @@ -45,6 +46,9 @@ class Script(scripts.Script): for n in range(batch_count): history = [] + # Reset to original init image at the start of each batch + p.init_images = original_init_image + for i in range(loops): p.n_iter = 1 p.batch_size = 1 diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 771eb8e49..ef4311054 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,8 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, hypernetwork +from modules import images +from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -27,6 +28,9 @@ def apply_field(field): def apply_prompt(p, x, xs): + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") + p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) @@ -193,7 +197,7 @@ class Script(scripts.Script): x_values = gr.Textbox(label="X values", visible=False, lines=1) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type") + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type") y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) @@ -205,7 +209,10 @@ class Script(scripts.Script): if not no_fixed_seeds: modules.processing.fix_seed(p) - p.batch_size = 1 + if not opts.return_grid: + p.batch_size = 1 + + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): diff --git a/style.css b/style.css index 04bb9576e..e6fa10b4f 100644 --- a/style.css +++ b/style.css @@ -2,6 +2,27 @@ max-width: 100%; } +#txt2img_token_counter { + height: 0px; +} + +#img2img_token_counter { + height: 0px; +} + +#sh{ + min-width: 2em; + min-height: 2em; + max-width: 2em; + max-height: 2em; + flex-grow: 0; + padding-left: 0.25em; + padding-right: 0.25em; + margin: 0.1em 0; + opacity: 0%; + cursor: default; +} + .output-html p {margin: 0 0.5em;} .row > *, @@ -219,6 +240,7 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ position: relative; border: none; + margin-right: 8em; } .gr-panel div.flex-col div.justify-between label span{ @@ -467,3 +489,20 @@ input[type="range"]{ max-width: 32em; padding: 0; } + +canvas[key="mask"] { + z-index: 12 !important; + filter: invert(); + mix-blend-mode: multiply; + pointer-events: none; +} + + +/* gradio 3.4.1 stuff for editable scrollbar values */ +.gr-box > div > div > input.gr-text-input{ + position: absolute; + right: 0.5em; + top: -0.6em; + z-index: 200; + width: 8em; +} diff --git a/textual_inversion_templates/hypernetwork.txt b/textual_inversion_templates/hypernetwork.txt new file mode 100644 index 000000000..91e068905 --- /dev/null +++ b/textual_inversion_templates/hypernetwork.txt @@ -0,0 +1,27 @@ +a photo of a [filewords] +a rendering of a [filewords] +a cropped photo of the [filewords] +the photo of a [filewords] +a photo of a clean [filewords] +a photo of a dirty [filewords] +a dark photo of the [filewords] +a photo of my [filewords] +a photo of the cool [filewords] +a close-up photo of a [filewords] +a bright photo of the [filewords] +a cropped photo of a [filewords] +a photo of the [filewords] +a good photo of the [filewords] +a photo of one [filewords] +a close-up photo of the [filewords] +a rendition of the [filewords] +a photo of the clean [filewords] +a rendition of a [filewords] +a photo of a nice [filewords] +a good photo of a [filewords] +a photo of the nice [filewords] +a photo of the small [filewords] +a photo of the weird [filewords] +a photo of the large [filewords] +a photo of a cool [filewords] +a photo of a small [filewords] diff --git a/textual_inversion_templates/none.txt b/textual_inversion_templates/none.txt new file mode 100644 index 000000000..f77af4612 --- /dev/null +++ b/textual_inversion_templates/none.txt @@ -0,0 +1 @@ +picture diff --git a/webui.py b/webui.py index 270584f77..ca278e940 100644 --- a/webui.py +++ b/webui.py @@ -29,6 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts +import modules.hypernetworks.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() @@ -82,8 +83,7 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) -shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) def webui(): @@ -108,7 +108,7 @@ def webui(): prevent_thread_lock=True ) - app.add_middleware(GZipMiddleware,minimum_size=1000) + app.add_middleware(GZipMiddleware, minimum_size=1000) while 1: time.sleep(0.5) @@ -124,6 +124,8 @@ def webui(): modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) print('Reloading modules: modules.ui') importlib.reload(modules.ui) + print('Refreshing Model List') + modules.sd_models.list_models() print('Restarting Gradio')