From cc92dc1f8d73dd4d574c4c8ccab78b7fc61e440b Mon Sep 17 00:00:00 2001 From: ssysm Date: Sun, 9 Oct 2022 23:17:29 -0400 Subject: [PATCH 01/28] add vae path args --- modules/sd_models.py | 2 +- modules/shared.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b16..b6979432d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") diff --git a/modules/shared.py b/modules/shared.py index 2dc092d68..52ccfa6e3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,7 +64,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) - +parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) cmd_opts = parser.parse_args() From 8acc901ba3a252dc6ab4fabcb41644cf64d1774c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 00:38:55 -0400 Subject: [PATCH 02/28] Newer versions of PyTorch use TypedStorage instead Pytorch 1.13 and later will rename _TypedStorage to TypedStorage, so check for TypedStorage and use _TypedStorage if it is not available. Currently this is needed so that nightly builds of PyTorch work correctly. 
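A rough sketch of the same fallback outside the webui (illustrative only; ExampleUnpickler is a made-up name, not the class in modules/safe.py): resolve whichever storage class name the installed PyTorch build exposes, and return an empty instance from persistent_load so a pickle can be scanned without allocating real storage.

    import pickle
    import torch

    # torch >= 1.13 exposes TypedStorage; older builds only have _TypedStorage
    TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage

    class ExampleUnpickler(pickle.Unpickler):
        def persistent_load(self, saved_id):
            # tensors reference external storages; hand back a stub that works on both torch versions
            assert saved_id[0] == 'storage'
            return TypedStorage()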
--- modules/safe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/safe.py b/modules/safe.py index 4d06f2a53..059174632 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -12,6 +12,10 @@ import _codecs import zipfile +# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage + + def encode(*args): out = _codecs.encode(*args) return out @@ -20,7 +24,7 @@ def encode(*args): class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return torch.storage._TypedStorage() + return TypedStorage() def find_class(self, module, name): if module == 'collections' and name == 'OrderedDict': From a3578233395e585e68c2118d3630cb2a961d4a36 Mon Sep 17 00:00:00 2001 From: Bepis <36346617+bbepis@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:12:29 +1100 Subject: [PATCH 03/28] Add a pull request template --- .../pull_request_template.md | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/pull_request_template.md diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md new file mode 100644 index 000000000..86009613e --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -0,0 +1,28 @@ +# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! + +If you have a large change, pay special attention to this paragraph: + +> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. + +Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. + +**Describe what this pull request is trying to achieve.** + +A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. + +**Additional notes and description of your changes** + +More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. + +**Environment this was tested in** + +List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box. + - OS: [e.g. Windows, Linux] + - Browser [e.g. chrome, safari] + - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] + +**Screenshots or videos of your changes** + +If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. + +This is **required** for anything that touches the user interface. 
\ No newline at end of file From 7349088d32b080f64058b6e5de5f0380a71ecd09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 16:11:14 +0300 Subject: [PATCH 04/28] --no-half-vae --- modules/devices.py | 6 +++++- modules/processing.py | 11 +++++++++-- modules/sd_models.py | 3 +++ modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 + 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0158b11fc..03ef58f19 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32") device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 +dtype_vae = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. @@ -59,9 +60,12 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) -def autocast(): +def autocast(disable=False): from modules import shared + if disable: + return contextlib.nullcontext() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 94d2dd62c..ec8651ae2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def decode_first_stage(model, x): + with devices.autocast(disable=x.dtype == devices.dtype_vae): + x = model.decode_first_stage(x) + + return x + + def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) @@ -400,7 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: samples_ddim = samples_ddim.to(devices.dtype) - x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) + x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) del samples_ddim @@ -533,7 +540,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scale_latent: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") else: - decoded_samples = self.sd_model.decode_first_stage(samples) + decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") diff --git a/modules/sd_models.py b/modules/sd_models.py index e63d3c292..2cdcd84f3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info): model.half() devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): @@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) + model.first_stage_model.to(devices.dtype_vae) + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6e743f7e9..d168b938f 100644 --- a/modules/sd_samplers.py 
+++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser +from modules import prompt_parser, devices, processing from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None): def sample_to_image(samples): - x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] + x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/modules/shared.py b/modules/shared.py index 1995a99a7..5dfc344cc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") +parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") From 8f1efdc130cf7ff47cb8d3722cdfc0dbeba3069e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 17:03:45 +0300 Subject: [PATCH 05/28] --no-half-vae pt2 --- modules/processing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index ec8651ae2..50ba4fc5f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -405,8 +405,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent - samples_ddim = samples_ddim.to(devices.dtype) - + samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) From ea00c1624bbb0dcb5be07f59c9509061baddf5b1 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 17:07:46 +0900 Subject: [PATCH 06/28] Textual Inversion: Added custom training image size and number of repeats per input image in a single epoch --- modules/textual_inversion/dataset.py | 6 +++--- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 15 ++++++++++++--- modules/ui.py | 8 +++++++- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7c44ea5be..acc4ce597 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,13 +15,13 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - 
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): self.placeholder_token = placeholder_token self.size = size - self.width = width - self.height = height + self.width = size + self.height = size self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index f1c002a2b..b3de6fd7e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,8 +7,8 @@ import tqdm from modules import shared, images -def preprocess(process_src, process_dst, process_flip, process_split, process_caption): - size = 512 +def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption): + size = process_size src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f34984..e34dc2e81 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,6 +6,7 @@ import torch import tqdm import html import datetime +import math from modules import shared, devices, sd_hijack, processing, sd_models @@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if ititial_step > steps: return embedding, filename + tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) + epoch_len = (tr_img_len * num_repeats) + tr_img_len + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, (x, text) in pbar: embedding.step = i + ititial_step @@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, loss.backward() optimizer.step() - pbar.set_description(f"loss: {losses.mean():.7f}") + epoch_num = math.floor(embedding.step / epoch_len) + epoch_step = embedding.step - (epoch_num * epoch_len) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') @@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, sd_model=shared.sd_model, prompt=text, steps=20, + height=training_size, + width=training_size, do_not_save_grid=True, do_not_save_samples=True, ) diff --git a/modules/ui.py b/modules/ui.py index 2231a8ed8..f821fd8db 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1029,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') + process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1043,13 +1044,15 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Group(): - gr.HTML(value="
Train an embedding; must specify a directory with a set of 512x512 images")
+            gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) + num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, + process_size, process_flip, process_split, process_caption, @@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, + training_size, steps, + num_repeats, create_image_every, save_embedding_every, template_file, From 6ad3a53e368d36535de1a4fca73b3bb78fd40654 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 17:31:33 +0900 Subject: [PATCH 07/28] Fixed progress bar output for epoch --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e34dc2e81..769682ea5 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -228,7 +228,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini optimizer.step() epoch_num = math.floor(embedding.step / epoch_len) - epoch_step = embedding.step - (epoch_num * epoch_len) + epoch_step = embedding.step - (epoch_num * epoch_len) + 1 pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") From 7a20f914eddfdf09c0ccced157ec108205bc3d0f Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 22:35:35 +0900 Subject: [PATCH 08/28] Custom Width and Height --- modules/textual_inversion/dataset.py | 7 +++---- modules/textual_inversion/preprocess.py | 19 ++++++++++--------- .../textual_inversion/textual_inversion.py | 11 +++++------ modules/ui.py | 12 ++++++++---- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index acc4ce597..bcf772d2f 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): self.placeholder_token = placeholder_token - self.size = size - self.width = size - self.height = size + self.width = width + self.height = height self.flip 
= transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index b3de6fd7e..d7efdef29 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,8 +7,9 @@ import tqdm from modules import shared, images -def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption): - size = process_size +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption): + width = process_width + height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) @@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_size, process_flip, process_spl is_wide = ratio < 1 / 1.35 if process_split and is_tall: - img = img.resize((size, size * img.height // img.width)) + img = img.resize((width, height * img.height // img.width)) - top = img.crop((0, 0, size, size)) + top = img.crop((0, 0, width, height)) save_pic(top, index) - bot = img.crop((0, img.height - size, size, img.height)) + bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) elif process_split and is_wide: - img = img.resize((size * img.width // img.height, size)) + img = img.resize((width * img.width // img.height, height)) - left = img.crop((0, 0, size, size)) + left = img.crop((0, 0, width, height)) save_pic(left, index) - right = img.crop((img.width - size, 0, img.width, size)) + right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) else: - img = images.resize_image(1, img, size, size) + img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 769682ea5..5965c5a06 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,7 +6,6 @@ import torch import tqdm import html import datetime -import math from modules import shared, devices, sd_hijack, processing, sd_models @@ -157,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -183,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -227,7 +226,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward() optimizer.step() - epoch_num = math.floor(embedding.step / epoch_len) + epoch_num = embedding.step // epoch_len epoch_step = embedding.step - (epoch_num * epoch_len) + 1 pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") @@ -243,8 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini sd_model=shared.sd_model, prompt=text, steps=20, - height=training_size, - width=training_size, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) diff --git a/modules/ui.py b/modules/ui.py index f821fd8db..8c06ad7cc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1029,7 +1029,8 @@ def create_ui(wrap_gradio_gpu_call): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') - process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) + process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1050,7 +1051,8 @@ def create_ui(wrap_gradio_gpu_call): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) + training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1095,7 +1097,8 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, - process_size, + process_width, + process_height, process_flip, process_split, process_caption, @@ -1114,7 +1117,8 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, - training_size, + training_width, + training_height, steps, num_repeats, create_image_every, From ce37fdd30e9fc0fe0bc5805a068ce8b11b42b5a3 Mon Sep 17 00:00:00 2001 From: Ben <110583491+TheLastBen@users.noreply.github.com> Date: Sat, 8 Oct 2022 22:03:00 +0100 Subject: [PATCH 09/28] maximize the view --- style.css | 4 ++++ 1 file changed, 4 insertions(+) 
diff --git a/style.css b/style.css index c0c3f2bb3..04bb9576e 100644 --- a/style.css +++ b/style.css @@ -1,3 +1,7 @@ +.container { + max-width: 100%; +} + .output-html p {margin: 0 0.5em;} .row > *, From f347ddfd808c56bb1bacdec0c4bedf826ff85cd8 Mon Sep 17 00:00:00 2001 From: RW21 Date: Mon, 10 Oct 2022 10:44:11 +0900 Subject: [PATCH 10/28] Remove max_batch_count from ui.py --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 8c06ad7cc..8ba84911b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -524,7 +524,7 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) with gr.Row(): - batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) @@ -710,7 +710,7 @@ def create_ui(wrap_gradio_gpu_call): tiling = gr.Checkbox(label='Tiling', value=False) with gr.Row(): - batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) with gr.Group(): From b340439586d844e76782149ca1857c8de35773ec Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 05:28:06 +0100 Subject: [PATCH 11/28] Unlimited Token Works Unlimited tokens actually work now. Works with textual inversion too. Replaces the previous not-so-much-working implementation. 
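The diff below achieves this by processing the prompt in 75-token chunks and concatenating the per-chunk CLIP embeddings along the token axis. A simplified sketch of that chunking idea (encode_unlimited and encode_chunk are made-up names standing in for the hijacked forward/process_tokens pair, not the actual webui functions):

    import math
    import torch

    def encode_unlimited(token_ids, encode_chunk, chunk_size=75):
        # pad up to a whole number of chunks, at least one (cf. get_target_prompt_token_count below)
        target = math.ceil(max(len(token_ids), 1) / chunk_size) * chunk_size
        token_ids = token_ids + [0] * (target - len(token_ids))
        # encode each 75-token window on its own and join the results along the token axis,
        # so conditioning grows with prompt length instead of being truncated at 75 tokens
        pieces = [encode_chunk(token_ids[i:i + chunk_size]) for i in range(0, target, chunk_size)]
        return torch.cat(pieces, dim=-2)

Embedding fixes are bucketed per chunk in the same way (an iteration index plus an offset within the chunk), which is what keeps textual inversion working across the 75-token boundaries.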
--- modules/sd_hijack.py | 69 +++++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 437acce4c..8d5c77d89 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -43,10 +43,7 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): - if token_count < 75: - return 75 - - return math.ceil(token_count / 10) * 10 + return math.ceil(max(token_count, 1) / 75) * 75 class StableDiffusionModelHijack: @@ -127,7 +124,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.token_mults[ident] = mult def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id if opts.enable_emphasis: @@ -154,7 +150,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): i += 1 else: emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) + iteration = len(remade_tokens) // 75 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) @@ -162,10 +159,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): token_count = len(remade_tokens) prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) + 1 + tokens_to_add = prompt_target_length - len(remade_tokens) - remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add - multipliers = [1.0] + multipliers + [1.0] * tokens_to_add + remade_tokens = remade_tokens + [id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count @@ -260,29 +257,55 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): - - if opts.use_old_emphasis_implementation: + use_old = opts.use_old_emphasis_implementation + if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = [x[75:] for x in batch_multipliers] + + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers]) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not 
opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(device) + outputs = self.wrapped.transformer(input_ids=tokens) - target_token_count = get_target_prompt_token_count(token_count) + 2 - - position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] - position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - - remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] - tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) @@ -290,7 +313,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) From 460bbae58726c177beddfcddf351f27e205d3fb2 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 16:09:06 +0100 Subject: [PATCH 12/28] Pad beginning of textual inversion embedding --- modules/sd_hijack.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 8d5c77d89..3a60cd635 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -151,6 +151,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: emb_len = int(embedding.vec.shape[0]) iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [id_end] * rem + multipliers += [1.0] * rem + iteration += 1 fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len From d5c14365fd468dbf89fa12a68bea5b217077273c Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 16:13:47 +0100 Subject: [PATCH 13/28] Add back in output hidden states parameter --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3a60cd635..3edc0e9d8 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -309,7 +309,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = 
outputs.hidden_states[-opts.CLIP_stop_at_last_layers] From 6c36fe5719a824fa18f6ad3e02727783f095bc5f Mon Sep 17 00:00:00 2001 From: Melan Date: Mon, 10 Oct 2022 18:16:04 +0200 Subject: [PATCH 14/28] Add ctrl+enter as a shortcut to quickly start a generation. --- script.js | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/script.js b/script.js index cf9896053..a92c0f77d 100644 --- a/script.js +++ b/script.js @@ -40,6 +40,22 @@ document.addEventListener("DOMContentLoaded", function() { mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) }); +/** + * Add a ctrl+enter as a shortcut to start a generation + */ + document.addEventListener('keydown', function(e) { + var handled = false; + if (e.key !== undefined) { + if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true; + } else if (e.keyCode !== undefined) { + if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true; + } + if (handled) { + gradioApp().querySelector("#txt2img_generate").click(); + e.preventDefault(); + } +}) + /** * checks that a UI element is not in another hidden element or tab content */ From 9d33baba587637815d818e5e641d8f8b74c4900d Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 10 Oct 2022 18:46:48 +0300 Subject: [PATCH 15/28] Always show previous mask and fix extras_send dest --- modules/ui.py | 2 +- style.css | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 8ba84911b..e8039d76c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call): extras_send_to_inpaint.click( fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_img2img", + _js="extract_image_from_gallery_inpaint", inputs=[result_images], outputs=[init_img_with_mask], ) diff --git a/style.css b/style.css index 04bb9576e..00a3d07fe 100644 --- a/style.css +++ b/style.css @@ -467,3 +467,10 @@ input[type="range"]{ max-width: 32em; padding: 0; } + +canvas[key="mask"] { + z-index: 12 !important; + filter: invert(); + mix-blend-mode: multiply; + pointer-events: none; +} From b8c38f2bbfa28904f67f0c4f9cabab4d85ebced2 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:44:58 +0300 Subject: [PATCH 16/28] change prebuilt wheel --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index f42f557de..e1000f559 100644 --- a/launch.py +++ b/launch.py @@ -127,7 +127,7 @@ def prepare_enviroment(): if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): if platform.system() == "Windows": - run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") + run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") elif platform.system() == "Linux": run_pip("install xformers", "xformers") From 623251ce2b8d152e242011f62984a8247a14a389 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:45:38 +0300 Subject: [PATCH 17/28] allow pascal onwards --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3edc0e9d8..827bf3045 100644 --- a/modules/sd_hijack.py +++ 
b/modules/sd_hijack.py @@ -23,7 +23,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward From 3e7a981194ed9c454e951365846e4eba66fa7095 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:51:05 +0300 Subject: [PATCH 18/28] remove functorch --- modules/sd_hijack_optimizations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 634fb4b24..18408e629 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -13,8 +13,6 @@ from modules import shared if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: import xformers.ops - import functorch - xformers._is_functorch_available = True shared.xformers_available = True except Exception: print("Cannot import xformers", file=sys.stderr) From 5c3254b3ee62ef46cb2e3a6ed14182efeb868f30 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:51:41 +0300 Subject: [PATCH 19/28] Update requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 81641d68f..631fe616a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,4 +23,3 @@ resize-right torchdiffeq kornia lark -functorch From e37d0cdd06772c8d6edb2272c0ef25c46c74cc6d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:51:53 +0300 Subject: [PATCH 20/28] Update requirements_versions.txt --- requirements_versions.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index fec3e9d5b..fdff26878 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,4 +22,3 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 -functorch==0.2.1 From ece27fe98933eb0eda8ea94dc496dd7554f3a08f Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 9 Oct 2022 18:55:33 +0300 Subject: [PATCH 21/28] Add files via upload --- modules/swinir_model_arch_v2.py | 1017 +++++++++++++++++++++++++++++++ 1 file changed, 1017 insertions(+) create mode 100644 modules/swinir_model_arch_v2.py diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py new file mode 100644 index 000000000..0e28ae6ee --- /dev/null +++ b/modules/swinir_model_arch_v2.py @@ -0,0 +1,1017 @@ +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ +# Written by Conde and Choi et al. 
+# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
+ """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + #assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. 
+ input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. 
Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(Swin2SR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, 
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_hf: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x_before = self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = 
self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] + + else: + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file From ed769977f0d0f201d8e361d365102f18775fc62c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 9 Oct 2022 18:56:59 +0300 Subject: [PATCH 22/28] add swinir v2 support --- modules/swinir_model.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/modules/swinir_model.py b/modules/swinir_model.py index fbd11f843..baa02e3d1 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -10,6 +10,7 @@ from tqdm import tqdm from modules import modelloader from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net +from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData precision_scope = ( @@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler): filename = path if filename is None or not os.path.exists(filename): return None - model = net( + if filename.endswith(".v2.pth"): + model = net2( upscale=scale, in_chans=3, img_size=64, window_size=8, img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler="nearest+conv", - resi_connection="3conv", - ) + resi_connection="1conv", + ) + params = None + else: + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", 
+ resi_connection="3conv", + ) + params = "params_ema" pretrained_model = torch.load(filename) - model.load_state_dict(pretrained_model["params_ema"], strict=True) + if params is not None: + model.load_state_dict(pretrained_model[params], strict=True) + else: + model.load_state_dict(pretrained_model, strict=True) if not cmd_opts.no_half: model = model.half() return model From af62ad4d25dcd0454944368f4925d83101cdedbc Mon Sep 17 00:00:00 2001 From: ssysm Date: Mon, 10 Oct 2022 13:25:28 -0400 Subject: [PATCH 23/28] change vae loading method --- modules/sd_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index b0e1d8bda..7a42d924d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -150,9 +150,16 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if os.path.exists(vae_file): + print(f"Found VAE Weights: {vae_file}") + elif shared.cmd_opts.vae_path != None: + vae_file = shared.cmd_opts.vae_path + print(f'No VAE found for inside the model folder. Using CLI specified : {vae_file}') + else: + print("No VAE found for inside the model folder. Passing.") + if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} From 39919c40dd18f5a14ae21403efea1b0f819756c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:32:37 +0300 Subject: [PATCH 24/28] add eta noise seed delta option --- javascript/hints.js | 1 + modules/processing.py | 6 +++++- modules/shared.py | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/javascript/hints.js b/javascript/hints.js index 8e352e94a..47b807764 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -79,6 +79,7 @@ titles = { "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Scale latent": "Upscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", + "Eta noise seed delta": "If this value is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", } diff --git a/modules/processing.py b/modules/processing.py index 50ba4fc5f..698b3069e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: + if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None: cnt = p.sampler.number_of_needed_noises(p) + if opts.eta_noise_seed_delta > 0: + torch.manual_seed(seed + opts.eta_noise_seed_delta) + for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) @@ -301,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, + "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, } generation_params.update(p.extra_generation_params) diff --git a/modules/shared.py b/modules/shared.py index 5dfc344cc..b1c65ecf6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,6 +260,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) From 727e4d108674dc2813507e2a973a733ef21e8d53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:46:55 +0300 Subject: [PATCH 25/28] no to different messages plus fix using != to compare to None --- modules/sd_models.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4c06051e5..0a55b4c32 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -152,15 +152,12 @@ def load_model_weights(model, checkpoint_info): devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - if os.path.exists(vae_file): - print(f"Found VAE Weights: {vae_file}") - elif shared.cmd_opts.vae_path != None: + + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: vae_file = shared.cmd_opts.vae_path - print(f'No VAE found for inside the model folder. Using CLI specified : {vae_file}') - else: - print("No VAE found for inside the model folder. 
Passing.") if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} From 5da1ba0e91a81804dc911d34c9a2e6956a23199c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 21:24:11 +0300 Subject: [PATCH 26/28] remove batch size restriction from X/Y plot --- scripts/xy_grid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 771eb8e49..42e1489c4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -205,7 +205,10 @@ class Script(scripts.Script): if not no_fixed_seeds: modules.processing.fix_seed(p) - p.batch_size = 1 + if not opts.return_grid: + p.batch_size = 1 + + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): From f98338faa84ecce503e68d8ba13d5f7bbae52730 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 23:15:48 +0300 Subject: [PATCH 27/28] add an option to not add watermark to created images --- javascript/hints.js | 1 + modules/shared.py | 1 + 2 files changed, 2 insertions(+) diff --git a/javascript/hints.js b/javascript/hints.js index 47b807764..045f2d3c1 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -80,6 +80,7 @@ titles = { "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", + "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. 
Warning: if you do not add watermark, you may be behaving in an unethical manner.", } diff --git a/modules/shared.py b/modules/shared.py index da389f9c0..ecd15ef5d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { From 42bf5fa3256bff5e4640e5a626e750d4e49e01e1 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 21:54:21 +0100 Subject: [PATCH 28/28] Make cancel generate forever let the current gen complete (#2206) --- javascript/contextMenus.js | 4 ---- 1 file changed, 4 deletions(-) diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 2d82269fc..7852793c1 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -147,10 +147,6 @@ generateOnRepeatId = appendContextMenuOption('#txt2img_generate','Generate forev cancelGenerateForever = function(){ clearInterval(window.generateOnRepeatInterval) - let interruptbutton = gradioApp().querySelector('#txt2img_interrupt'); - if(interruptbutton.offsetParent){ - interruptbutton.click(); - } } appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
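
As a closing note on the series: the "Eta noise seed delta" (ENSD) option introduced in [PATCH 24/28] above is the least self-explanatory change, so a minimal sketch of the behaviour it adds follows. This is an illustrative approximation rather than the webui's actual code path: the real implementation routes the noise through devices.randn_without_seed and the sampler's number_of_needed_noises as shown in the diff, while the function name, tensor shapes and the delta value below are invented for the example.

import torch

def sampler_noises(seed, shape, steps, eta_noise_seed_delta=0):
    # The initial latent noise is drawn from the image seed, exactly as before.
    torch.manual_seed(seed)
    base_noise = torch.randn(shape)

    # With a non-zero ENSD, the RNG is re-seeded with seed + delta before the
    # extra per-step noises consumed by eta/ancestral samplers are generated,
    # so only these tensors change.
    if eta_noise_seed_delta > 0:
        torch.manual_seed(seed + eta_noise_seed_delta)
    extra_noises = [torch.randn(shape) for _ in range(steps)]
    return base_noise, extra_noises

# Same seed, different ENSD: identical starting latents, different eta noise.
b0, e0 = sampler_noises(100, (1, 4, 64, 64), steps=20)
b1, e1 = sampler_noises(100, (1, 4, 64, 64), steps=20, eta_noise_seed_delta=31337)
assert torch.equal(b0, b1)
assert not torch.equal(e0[0], e1[0])

Re-seeding with seed + delta, rather than swapping the seed outright, keeps the initial latent tied to the image seed; that is why the option can reproduce another installation's output only when both the seed and the ENSD value match.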