From bbdbbd36eda870cf0bd49fdf28476c78919a123e Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 5 Oct 2022 04:43:05 +0100
Subject: [PATCH 01/28] shared.state.interrupt when restart is requested

---
 modules/ui.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/ui.py b/modules/ui.py
index de6342a48..523ab25b3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1210,6 +1210,7 @@ def create_ui(wrap_gradio_gpu_call):
         )

     def request_restart():
+        shared.state.interrupt()
         settings_interface.gradio_ref.do_restart = True

     restart_gradio.click(

From 67d011b02eddc20202b654dfea56528de3d5edf7 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 5 Oct 2022 04:44:22 +0100
Subject: [PATCH 02/28] Show generation progress in window title

---
 javascript/progressbar.js | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 1e297abbe..3e3220c3f 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -4,6 +4,21 @@ global_progressbars = {}
 function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
     var progressbar = gradioApp().getElementById(id_progressbar)
     var interrupt = gradioApp().getElementById(id_interrupt)
+
+    if(progressbar && progressbar.offsetParent){
+        if(progressbar.innerText){
+            let newtitle = 'Stable Diffusion - ' + progressbar.innerText
+            if(document.title != newtitle){
+                document.title = newtitle;
+            }
+        }else{
+            let newtitle = 'Stable Diffusion'
+            if(document.title != newtitle){
+                document.title = newtitle;
+            }
+        }
+    }
+
     if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
         global_progressbars[id_progressbar] = progressbar

From c26732fbee2a57e621ac22bf70decf7496daa4cd Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 5 Oct 2022 23:16:27 +0300
Subject: [PATCH 03/28] added support for AND from
 https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/

---
 modules/processing.py    |   2 +-
 modules/prompt_parser.py | 114 ++++++++++++++++++++++++++++++++++++---
 modules/sd_samplers.py   |  35 ++++++++----
 modules/ui.py            |   6 ++-
 4 files changed, 138 insertions(+), 19 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index bb94033b1..d8c6b8d57 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -360,7 +360,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             #c = p.sd_model.get_learned_conditioning(prompts)
             with devices.autocast():
                 uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
-                c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
+                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index a3b124219..f7420daf9 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):


 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])


 def get_learned_conditioning(model, prompts, steps):
+    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
+    and the sampling step at which this condition is to be replaced by the next one.
+
+    Input:
+    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
+
+    Output:
+    [
+        [
+            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
+        ],
+        [
+            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
+            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
+        ]
+    ]
+    """
     res = []

     prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
         cache[prompt] = cond_schedule
         res.append(cond_schedule)

-    return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+    return res


-def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
-    param = c.schedules[0][0].cond
-    res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
-    for i, cond_schedule in enumerate(c.schedules):
+re_AND = re.compile(r"\bAND\b")
+re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")
+
+
+def get_multicond_prompt_list(prompts):
+    res_indexes = []
+
+    prompt_flat_list = []
+    prompt_indexes = {}
+
+    for prompt in prompts:
+        subprompts = re_AND.split(prompt)
+
+        indexes = []
+        for subprompt in subprompts:
+            text, weight = re_weight.search(subprompt).groups()
+
+            weight = float(weight) if weight is not None else 1.0
+
+            index = prompt_indexes.get(text, None)
+            if index is None:
+                index = len(prompt_flat_list)
+                prompt_flat_list.append(text)
+                prompt_indexes[text] = index
+
+            indexes.append((index, weight))
+
+        res_indexes.append(indexes)
+
+    return res_indexes, prompt_flat_list, prompt_indexes
+
+
+class ComposableScheduledPromptConditioning:
+    def __init__(self, schedules, weight=1.0):
+        self.schedules: list[ScheduledPromptConditioning] = schedules
+        self.weight: float = weight
+
+
+class MulticondLearnedConditioning:
+    def __init__(self, shape, batch):
+        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
+        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
+
+
+def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
+    For each prompt, the list is obtained by splitting the prompt using the AND separator.
+
+    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
+    """
+
+    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
+
+    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+
+    res = []
+    for indexes in res_indexes:
+        res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
+
+    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+
+
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
+    param = c[0][0].cond
+    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, (end_at, cond) in enumerate(cond_schedule):
             if current_step <= end_at:
@@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
     return res


+def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+    param = c.batch[0][0].schedules[0].cond
+
+    tensors = []
+    conds_list = []
+
+    for batch_no, composable_prompts in enumerate(c.batch):
+        conds_for_batch = []
+
+        for cond_index, composable_prompt in enumerate(composable_prompts):
+            target_index = 0
+            for current, (end_at, cond) in enumerate(composable_prompt.schedules):
+                if current_step <= end_at:
+                    target_index = current
+                    break
+
+            conds_for_batch.append((len(tensors), composable_prompt.weight))
+            tensors.append(composable_prompt.schedules[target_index].cond)
+
+        conds_list.append(conds_for_batch)
+
+    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+
+
 re_attention = re.compile(r"""
 \\\(|
 \\\)|

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index dbf570d2c..d27c547b3 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -109,9 +109,12 @@ class VanillaStableDiffusionSampler:
         return 0

     def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
-        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

+        assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+        cond = tensor
+
         if self.mask is not None:
             img_orig = self.sampler.model.q_sample(self.init_latent, ts)
             x_dec = img_orig * self.mask + self.nmask * x_dec
@@ -183,19 +186,31 @@ class CFGDenoiser(torch.nn.Module):
         self.step = 0

     def forward(self, x, sigma, uncond, cond, cond_scale):
-        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

+        batch_size = len(conds_list)
+        repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+        cond_in = torch.cat([tensor, uncond])
+
         if shared.batch_cond_uncond:
-            x_in = torch.cat([x] * 2)
-            sigma_in = torch.cat([sigma] * 2)
-            cond_in = torch.cat([uncond, cond])
-            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-            denoised = uncond + (cond - uncond) * cond_scale
+            x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
         else:
-            uncond = self.inner_model(x, sigma, cond=uncond)
-            cond = self.inner_model(x, sigma, cond=cond)
-            denoised = uncond + (cond - uncond) * cond_scale
+            x_out = torch.zeros_like(x_in)
+            for batch_offset in range(0, x_out.shape[0], batch_size):
+                a = batch_offset
+                b = a + batch_size
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+
+        denoised_uncond = x_out[-batch_size:]
+        denoised = torch.clone(denoised_uncond)
+
+        for i, conds in enumerate(conds_list):
+            for cond_index, weight in conds:
+                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

         if self.mask is not None:
             denoised = self.init_latent * self.mask + self.nmask * denoised

diff --git a/modules/ui.py b/modules/ui.py
index 523ab25b3..9620350fc 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -34,7 +34,7 @@ import modules.gfpgan_model
 import modules.codeformer_model
 import modules.styles
 import modules.generation_parameters_copypaste
-from modules.prompt_parser import get_learned_conditioning_prompt_schedules
+from modules import prompt_parser
 from modules.images import apply_filename_pattern, get_next_sequence_number
 import modules.textual_inversion.ui
@@ -394,7 +394,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
 def update_token_counter(text, steps):
     try:
-        prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
     except Exception:
         # a parsing error can happen here during typing, and we don't want to bother the user with
         # messages related to it in console

From f8e41a96bb30a04dd5e294c7e1178c1c3b09d481 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 5 Oct 2022 23:52:05 +0300
Subject: [PATCH 04/28] fix various float parsing errors

---
 modules/prompt_parser.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f7420daf9..800b12c75 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -143,8 +143,7 @@ def get_learned_conditioning(model, prompts, steps):

 re_AND = re.compile(r"\bAND\b")
-re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")
-
+re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")

 def get_multicond_prompt_list(prompts):
     res_indexes = []

From 20f8ec877a99ce2ebf193cb1e2e773cfc77b7c41 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 00:09:32 +0300
Subject: [PATCH 05/28] remove type annotations in new code because presumably
 they don't work in 3.7

---
 modules/prompt_parser.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 800b12c75..ee4c5d02d 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -175,14 +175,14 @@ def get_multicond_prompt_list(prompts):

 class ComposableScheduledPromptConditioning:
     def __init__(self, schedules, weight=1.0):
-        self.schedules: list[ScheduledPromptConditioning] = schedules
+        self.schedules = schedules  # : list[ScheduledPromptConditioning]
         self.weight: float = weight


 class MulticondLearnedConditioning:
     def __init__(self, shape, batch):
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
-        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
+        self.batch = batch  # : list[list[ComposableScheduledPromptConditioning]]


 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
@@ -203,7 +203,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)


-def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c, current_step):  # c: list[list[ScheduledPromptConditioning]]
     param = c[0][0].cond
     res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
     for i, cond_schedule in enumerate(c):

From 34c358d10d52817f7a889ae4c52096ee654f3fe6 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 5 Oct 2022 22:11:30 +0100
Subject: [PATCH 06/28] use typing.List in prompt_parser.py for wider python
 version support

---
 modules/prompt_parser.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 800b12c75..fdfa21ae6 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,6 +1,6 @@
 import re
 from collections import namedtuple
-
+from typing import List
 import lark

 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
@@ -175,14 +175,14 @@ def get_multicond_prompt_list(prompts):

 class ComposableScheduledPromptConditioning:
     def __init__(self, schedules, weight=1.0):
-        self.schedules: list[ScheduledPromptConditioning] = schedules
+        self.schedules: List[ScheduledPromptConditioning] = schedules
         self.weight: float = weight


 class MulticondLearnedConditioning:
     def __init__(self, shape, batch):
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
-        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
+        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch


 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
@@ -203,7 +203,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)


-def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
     res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
     for i, cond_schedule in enumerate(c):

From 55400c981b7c1389482057a35ed6ea11f08da194 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 03:11:15 +0100
Subject: [PATCH 07/28] Set gradio-img2img-tool default to 'editor'

---
 modules/shared.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/shared.py b/modules/shared.py
index e52c9b1d1..bab0fe6ee 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -55,7 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
 parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
 parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
 parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch")
+parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
 parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
 parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
 parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)

From 2499fb4e1910d31ff12c24110f161b20641b8835 Mon Sep 17 00:00:00 2001
From: Raphael Stoeckli
Date: Wed, 5 Oct 2022 21:57:18 +0200
Subject: [PATCH 08/28] Add sanitizer for captions in Textual inversion

---
 modules/textual_inversion/preprocess.py | 28 +++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f545a9937..4f3df4bd9 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,5 +1,8 @@
+from cmath import log
 import os
 from PIL import Image, ImageOps
+import platform
+import sys
 import tqdm

 from modules import shared, images
@@ -25,6 +28,7 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
     def save_pic_with_caption(image, index):
         if process_caption:
             caption = "-" + shared.interrogator.generate_caption(image)
+            caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
         else:
             caption = filename
             caption = os.path.splitext(caption)[0]
@@ -75,3 +79,27 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca

     if process_caption:
         shared.interrogator.send_blip_to_ram()
+
+def sanitize_caption(base_path, original_caption, suffix):
+    operating_system = platform.system().lower()
+    if (operating_system == "windows"):
+        invalid_path_characters = "\\/:*?\"<>|"
+        max_path_length = 259
+    else:
+        invalid_path_characters = "/" #linux/macos
+        max_path_length = 1023
+    caption = original_caption
+    for invalid_character in invalid_path_characters:
+        caption = caption.replace(invalid_character, "")
+    fixed_path_length = len(base_path) + len(suffix)
+    if fixed_path_length + len(caption) <= max_path_length:
+        return caption
+    caption_tokens = caption.split()
+    new_caption = ""
+    for token in caption_tokens:
+        last_caption = new_caption
+        new_caption = new_caption + token + " "
+        if (len(new_caption) + fixed_path_length - 1 > max_path_length):
+            break
+    print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
+    return last_caption.strip()

From 4288e53fc2ea25fa49715bf5b7f14603553c9e38 Mon Sep 17 00:00:00 2001
From: Raphael Stoeckli
Date: Wed, 5 Oct 2022 23:11:32 +0200
Subject: [PATCH 09/28] removed unused import, fixed typo

---
 modules/textual_inversion/preprocess.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 4f3df4bd9..f1c002a2b 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,4 +1,3 @@
-from cmath import log
 import os
 from PIL import Image, ImageOps
 import platform
@@ -13,7 +12,7 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
     src = os.path.abspath(process_src)
     dst = os.path.abspath(process_dst)

-    assert src != dst, 'same directory specified as source and desitnation'
+    assert src != dst, 'same directory specified as source and destination'

     os.makedirs(dst, exist_ok=True)

From a93c3ffbfd264ed6b5d989922352300c9d3efbe4 Mon Sep 17 00:00:00 2001
From: Jocke
Date: Wed, 5 Oct 2022 16:31:48 +0200
Subject: [PATCH 10/28] Outpainting mk2, prevent generation of a completely
 random image every time even when global seed is static

---
 scripts/outpainting_mk_2.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index 11613ca36..a6468e09a 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -85,8 +85,11 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
     src_dist = np.absolute(src_fft)
     src_phase = src_fft / src_dist

+    # create a generator with a static seed to make outpainting deterministic / only follow global seed
+    rng = np.random.default_rng(0)
+
     noise_window = _get_gaussian_window(width, height, mode=1)  # start with simple gaussian noise
-    noise_rgb = np.random.random_sample((width, height, num_channels))
+    noise_rgb = rng.random((width, height, num_channels))
     noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
     noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter
     for c in range(num_channels):

From 6e7057b31b9762a9720282c7da486e4f264dee28 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 12:08:06 +0300
Subject: [PATCH 11/28] support for downloading new commit hash for git repos

---
 launch.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/launch.py b/launch.py
index 57405feab..2f91f586c 100644
--- a/launch.py
+++ b/launch.py
@@ -86,6 +86,15 @@ def git_clone(url, dir, name, commithash=None):
     # TODO clone into temporary dir and move if successful

     if os.path.exists(dir):
+        if commithash is None:
+            return
+
+        current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, "Couldn't determine {name}'s hash: {commithash}").strip()
+        if current_hash == commithash:
+            return
+
+        run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+        run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
         return

     run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

From 5f24b7bcf4a074fbdec757617fcd1bc82e76551b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 12:08:48 +0300
Subject: [PATCH 12/28] option to let users select which samplers they want to
 hide

---
 modules/processing.py  | 13 ++++++-------
 modules/sd_samplers.py | 19 +++++++++++++++++--
 modules/shared.py      | 15 +++++++++------
 webui.py               |  4 +++-
 4 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index d8c6b8d57..e01c8b3f6 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -11,9 +11,8 @@ import cv2
 from skimage import exposure

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, sd_samplers
 from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.face_restoration
@@ -110,7 +109,7 @@ class Processed:
         self.width = p.width
         self.height = p.height
         self.sampler_index = p.sampler_index
-        self.sampler = samplers[p.sampler_index].name
+        self.sampler = sd_samplers.samplers[p.sampler_index].name
         self.cfg_scale = p.cfg_scale
         self.steps = p.steps
         self.batch_size = p.batch_size
@@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
     generation_params = {
         "Steps": p.steps,
-        "Sampler": samplers[p.sampler_index].name,
+        "Sampler": sd_samplers.samplers[p.sampler_index].name,
         "CFG scale": p.cfg_scale,
         "Seed": all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.firstphase_height_truncated = int(scale * self.height)

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
-        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+        self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)

         if not self.enable_hr:
             x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

         shared.state.nextjob()

-        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+        self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

         # GC now before running the next img2img to prevent running out of memory
@@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.nmask = None

     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
+        self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)

         crop_region = None

         if self.image_mask is not None:

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d27c547b3..2e1f77153 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -32,12 +32,27 @@ samplers_data_k_diffusion = [
     if hasattr(k_diffusion.sampling, funcname)
 ]

-samplers = [
+all_samplers = [
     *samplers_data_k_diffusion,
     SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
     SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
 ]
-samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
+
+samplers = []
+samplers_for_img2img = []
+
+
+def set_samplers():
+    global samplers, samplers_for_img2img
+
+    hidden = set(opts.hide_samplers)
+    hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
+
+    samplers = [x for x in all_samplers if x.name not in hidden]
+    samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+
+
+set_samplers()

 sampler_extra_params = {
     'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],

diff --git a/modules/shared.py b/modules/shared.py
index bab0fe6ee..ca2e4c742 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,6 +13,7 @@ import modules.memmon
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
+from modules import sd_samplers
 from modules.paths import script_path, sd_path

 sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -238,14 +239,16 @@ options_templates.update(options_section(('ui', "User interface"), {
 }))

 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
-    "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
-    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
+    "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
+    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
 }))

+
 class Options:
     data = None
     data_labels = options_templates

diff --git a/webui.py b/webui.py
index 47848ba58..9ef124274 100644
--- a/webui.py
+++ b/webui.py
@@ -2,7 +2,7 @@ import os
 import threading
 import time
 import importlib
-from modules import devices
+from modules import devices, sd_samplers
 from modules.paths import script_path
 import signal
 import threading
@@ -109,6 +109,8 @@ def webui():
                 time.sleep(0.5)
                 break

+        sd_samplers.set_samplers()
+
         print('Reloading Custom Scripts')
         modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
         print('Reloading modules: modules.ui')

From 2d3ea42a2d1e909bbccdb6b49561b187c60a9402 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 13:21:12 +0300
Subject: [PATCH 13/28] workaround for a mysterious bug where prompt weights
 can't be matched

---
 modules/prompt_parser.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index a7a6aa314..f00256f28 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -156,7 +156,9 @@ def get_multicond_prompt_list(prompts):

         indexes = []
         for subprompt in subprompts:
-            text, weight = re_weight.search(subprompt).groups()
+            match = re_weight.search(subprompt)
+
+            text, weight = match.groups() if match is not None else (subprompt, 1.0)

             weight = float(weight) if weight is not None else 1.0

From 2a532804957e47bc36c67c8f5b104dcfa8e8f3f0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 13:21:32 +0300
Subject: [PATCH 14/28] reorder imports to fix the bug with k-diffusion on some
 versions

---
 webui.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/webui.py b/webui.py
index 9ef124274..480360fe0 100644
--- a/webui.py
+++ b/webui.py
@@ -2,11 +2,12 @@ import os
 import threading
 import time
 import importlib
-from modules import devices, sd_samplers
-from modules.paths import script_path
 import signal
 import threading

+from modules.paths import script_path
+
+from modules import devices, sd_samplers
 import modules.codeformer_model as codeformer
 import modules.extras
 import modules.face_restoration

From c30c06db207a580d76544fd10fc1e03cd58ce85e Mon Sep 17 00:00:00 2001
From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>
Date: Mon, 3 Oct 2022 12:48:16 +0300
Subject: [PATCH 15/28] update k-diffusion

---
 launch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/launch.py b/launch.py
index 2f91f586c..c2713c645 100644
--- a/launch.py
+++ b/launch.py
@@ -19,7 +19,7 @@ clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLI

 stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
 taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
-k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
+k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "567e11f7062ba20ae32b5a8cd07fb0fc4b9410cf")
 codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
 blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

From c1a068ed0acc788774afc1541ca69342fd1d94ad Mon Sep 17 00:00:00 2001
From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>
Date: Mon, 3 Oct 2022 12:49:17 +0300
Subject: [PATCH 16/28] Create alternate_sampler_noise_schedules.py

---
 scripts/alternate_sampler_noise_schedules.py | 53 ++++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 scripts/alternate_sampler_noise_schedules.py

diff --git a/scripts/alternate_sampler_noise_schedules.py b/scripts/alternate_sampler_noise_schedules.py
new file mode 100644
index 000000000..4f3ed8fb1
--- /dev/null
+++ b/scripts/alternate_sampler_noise_schedules.py
@@ -0,0 +1,53 @@
+import inspect
+from modules.processing import Processed, process_images
+import gradio as gr
+import modules.scripts as scripts
+import k_diffusion.sampling
+import torch
+
+
+class Script(scripts.Script):
+
+    def title(self):
+        return "Alternate Sampler Noise Schedules"
+
+    def ui(self, is_img2img):
+        noise_scheduler = gr.Dropdown(label="Noise Scheduler", choices=['Default','Karras','Exponential', 'Variance Preserving'], value='Default', type="index")
+        sched_smin = gr.Slider(value=0.1, label="Sigma min", minimum=0.0, maximum=100.0, step=0.5,)
+        sched_smax = gr.Slider(value=10.0, label="Sigma max", minimum=0.0, maximum=100.0, step=0.5)
+        sched_rho = gr.Slider(value=7.0, label="Sigma rho (Karras only)", minimum=7.0, maximum=100.0, step=0.5)
+        sched_beta_d = gr.Slider(value=19.9, label="Beta distribution (VP only)",minimum=0.0, maximum=40.0, step=0.5)
+        sched_beta_min = gr.Slider(value=0.1, label="Beta min (VP only)", minimum=0.0, maximum=40.0, step=0.1)
+        sched_eps_s = gr.Slider(value=0.001, label="Epsilon (VP only)", minimum=0.001, maximum=1.0, step=0.001)
+
+        return [noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s]
+
+    def run(self, p, noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s):
+
+        noise_scheduler_func_name = ['-','get_sigmas_karras','get_sigmas_exponential','get_sigmas_vp'][noise_scheduler]
+
+        base_params = {
+            "sigma_min":sched_smin,
+            "sigma_max":sched_smax,
+            "rho":sched_rho,
+            "beta_d":sched_beta_d,
+            "beta_min":sched_beta_min,
+            "eps_s":sched_eps_s,
+            "device":"cuda" if torch.cuda.is_available() else "cpu"
+        }
+
+        if hasattr(k_diffusion.sampling,noise_scheduler_func_name):
+
+            sigma_func = getattr(k_diffusion.sampling,noise_scheduler_func_name)
+            sigma_func_kwargs = {}
+
+            for k,v in base_params.items():
+                if k in inspect.signature(sigma_func).parameters:
+                    sigma_func_kwargs[k] = v
+
+            def substitute_noise_scheduler(n):
+                return sigma_func(n,**sigma_func_kwargs)
+
+            p.sampler_noise_scheduler_override = substitute_noise_scheduler
+
+        return process_images(p)

From 71901b3d3bea1d035bf4a7229d19356b4b062151 Mon Sep 17 00:00:00 2001
From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>
Date: Wed, 5 Oct 2022 14:30:57 +0300
Subject: [PATCH 17/28] add karras scheduling variants

---
 modules/sd_samplers.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 2e1f77153..8d6eb7620 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -26,6 +26,17 @@ samplers_k_diffusion = [
     ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
 ]

+if opts.show_karras_scheduler_variants:
+    k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
+    k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
+    k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
+    samplers_k_diffusion_ka = [
+        ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
+        ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
+        ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
+    ]
+    samplers_k_diffusion.extend(samplers_k_diffusion_ka)
+
 samplers_data_k_diffusion = [
     SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
     for label, funcname, aliases in samplers_k_diffusion
@@ -345,6 +356,8 @@ class KDiffusionSampler:

         if p.sampler_noise_scheduler_override:
             sigmas = p.sampler_noise_scheduler_override(steps)
+        elif self.funcname.endswith('ka'):
+            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
         x = x * sigmas[0]

From 3ddf80a9db8793188e2fe9488233d2b272cceb33 Mon Sep 17 00:00:00 2001
From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>
Date: Wed, 5 Oct 2022 14:31:51 +0300
Subject: [PATCH 18/28] add variant setting

---
 modules/shared.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/shared.py b/modules/shared.py
index ca2e4c742..9e4860a28 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -236,6 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
+    "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."),
 }))

 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

From a971e4a767118ec41ec0f129770122babfb16a16 Mon Sep 17 00:00:00 2001
From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>
Date: Thu, 6 Oct 2022 13:34:42 +0300
Subject: [PATCH 19/28] update k-diff once again

---
 launch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/launch.py b/launch.py
index c2713c645..9fe0fd675 100644
--- a/launch.py
+++ b/launch.py
@@ -19,7 +19,7 @@ clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLI

 stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
 taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
-k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "567e11f7062ba20ae32b5a8cd07fb0fc4b9410cf")
+k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
 codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
 blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

From 5993df24a1026225cb8af89237547c1d9101ce69 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 14:12:52 +0300
Subject: [PATCH 20/28] integrate the new samplers PR

---
 modules/processing.py                        |  7 ++-
 modules/sd_samplers.py                       | 59 ++++++++++----------
 modules/shared.py                            |  1 -
 scripts/alternate_sampler_noise_schedules.py | 53 ------------------
 scripts/img2imgalt.py                        |  3 +-
 5 files changed, 36 insertions(+), 87 deletions(-)
 delete mode 100644 scripts/alternate_sampler_noise_schedules.py

diff --git a/modules/processing.py b/modules/processing.py
index e01c8b3f6..e567956ce 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -477,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.firstphase_height_truncated = int(scale * self.height)

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
-        self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
+        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

         if not self.enable_hr:
             x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -521,7 +520,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

         shared.state.nextjob()

-        self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
+        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

         # GC now before running the next img2img to prevent running out of memory
@@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.nmask = None

     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
+        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)

         crop_region = None

         if self.image_mask is not None:

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 8d6eb7620..497df9430 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -13,46 +13,46 @@ from modules.shared import opts, cmd_opts, state
 import modules.shared as shared


-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

 samplers_k_diffusion = [
-    ('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
-    ('Euler', 'sample_euler', ['k_euler']),
-    ('LMS', 'sample_lms', ['k_lms']),
-    ('Heun', 'sample_heun', ['k_heun']),
-    ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
-    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
-    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
-    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
+    ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
+    ('Euler', 'sample_euler', ['k_euler'], {}),
+    ('LMS', 'sample_lms', ['k_lms'], {}),
+    ('Heun', 'sample_heun', ['k_heun'], {}),
+    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
+    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
+    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
+    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
+    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
 ]

-if opts.show_karras_scheduler_variants:
-    k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
-    k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
-    k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
-    samplers_k_diffusion_ka = [
-        ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
-        ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
-        ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
-    ]
-    samplers_k_diffusion.extend(samplers_k_diffusion_ka)
-
 samplers_data_k_diffusion = [
-    SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
-    for label, funcname, aliases in samplers_k_diffusion
+    SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
+    for label, funcname, aliases, options in samplers_k_diffusion
     if hasattr(k_diffusion.sampling, funcname)
 ]

 all_samplers = [
     *samplers_data_k_diffusion,
-    SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
-    SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
+    SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+    SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
 ]

 samplers = []
 samplers_for_img2img = []

+
+def create_sampler_with_index(list_of_configs, index, model):
+    config = list_of_configs[index]
+    sampler = config.constructor(model)
+    sampler.config = config
+
+    return sampler
+
+
 def set_samplers():
     global samplers, samplers_for_img2img

@@ -130,6 +130,7 @@ class VanillaStableDiffusionSampler:
         self.step = 0
         self.eta = None
         self.default_eta = 0.0
+        self.config = None

     def number_of_needed_noises(self, p):
         return 0
@@ -291,6 +292,7 @@ class KDiffusionSampler:
         self.stop_at = None
         self.eta = None
         self.default_eta = 1.0
+        self.config = None

     def callback_state(self, d):
         store_latent(d["denoised"])
@@ -355,11 +357,12 @@ class KDiffusionSampler:
         steps = steps or p.steps

         if p.sampler_noise_scheduler_override:
-            sigmas = p.sampler_noise_scheduler_override(steps)
-        elif self.funcname.endswith('ka'):
-            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+            sigmas = p.sampler_noise_scheduler_override(steps)
+        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
         else:
-            sigmas = self.model_wrap.get_sigmas(steps)
+            sigmas = self.model_wrap.get_sigmas(steps)
+
         x = x * sigmas[0]

         extra_params_kwargs = self.initialize(p)

diff --git a/modules/shared.py b/modules/shared.py
index 9e4860a28..ca2e4c742 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -236,7 +236,6 @@ options_templates.update(options_section(('ui', "User interface"), {
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
-    "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."),
 }))

 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

diff --git a/scripts/alternate_sampler_noise_schedules.py b/scripts/alternate_sampler_noise_schedules.py
deleted file mode 100644
index 4f3ed8fb1..000000000
--- a/scripts/alternate_sampler_noise_schedules.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import inspect
-from modules.processing import Processed, process_images
-import gradio as gr
-import modules.scripts as scripts
-import k_diffusion.sampling
-import torch
-
-
-class Script(scripts.Script):
-
-    def title(self):
-        return "Alternate Sampler Noise Schedules"
-
-    def ui(self, is_img2img):
-        noise_scheduler = gr.Dropdown(label="Noise Scheduler", choices=['Default','Karras','Exponential', 'Variance Preserving'], value='Default', type="index")
-        sched_smin = gr.Slider(value=0.1, label="Sigma min", minimum=0.0, maximum=100.0, step=0.5,)
-        sched_smax = gr.Slider(value=10.0, label="Sigma max", minimum=0.0, maximum=100.0, step=0.5)
-        sched_rho = gr.Slider(value=7.0, label="Sigma rho (Karras only)", minimum=7.0, maximum=100.0, step=0.5)
-        sched_beta_d = gr.Slider(value=19.9, label="Beta distribution (VP only)",minimum=0.0, maximum=40.0, step=0.5)
-        sched_beta_min = gr.Slider(value=0.1, label="Beta min (VP only)", minimum=0.0, maximum=40.0, step=0.1)
-        sched_eps_s = gr.Slider(value=0.001, label="Epsilon (VP only)", minimum=0.001, maximum=1.0, step=0.001)
-
-        return [noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s]
-
-    def run(self, p, noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s):
-
-        noise_scheduler_func_name = ['-','get_sigmas_karras','get_sigmas_exponential','get_sigmas_vp'][noise_scheduler]
-
-        base_params = {
-            "sigma_min":sched_smin,
-            "sigma_max":sched_smax,
-            "rho":sched_rho,
-            "beta_d":sched_beta_d,
-            "beta_min":sched_beta_min,
-            "eps_s":sched_eps_s,
-            "device":"cuda" if torch.cuda.is_available() else "cpu"
-        }
-
-        if hasattr(k_diffusion.sampling,noise_scheduler_func_name):
-
-            sigma_func = getattr(k_diffusion.sampling,noise_scheduler_func_name)
-            sigma_func_kwargs = {}
-
-            for k,v in base_params.items():
-                if k in inspect.signature(sigma_func).parameters:
-                    sigma_func_kwargs[k] = v
-
-            def substitute_noise_scheduler(n):
-                return sigma_func(n,**sigma_func_kwargs)
-
-            p.sampler_noise_scheduler_override = substitute_noise_scheduler
-
-        return process_images(p)

diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index 0ef137f7d..f9894cb01 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -8,7 +8,6 @@ import gradio as gr

 from modules import processing, shared, sd_samplers, prompt_parser
 from modules.processing import Processed
-from modules.sd_samplers import samplers
 from modules.shared import opts, cmd_opts, state

 import torch
@@ -159,7 +158,7 @@ class Script(scripts.Script):

         combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

-        sampler = samplers[p.sampler_index].constructor(p.sd_model)
+        sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)

         sigmas = sampler.model_wrap.get_sigmas(p.steps)

From f5490674a8fd84162b4e80c045e675633afb9ee7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 6 Oct 2022 17:41:49 +0300
Subject: [PATCH 21/28] fix bad output for error when updating a git repo

---
 launch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/launch.py b/launch.py
index 9fe0fd675..75edb66a9 100644
--- a/launch.py
+++ b/launch.py
@@ -89,7 +89,7 @@ def git_clone(url, dir, name, commithash=None):
         if commithash is None:
             return

-        current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, "Couldn't determine {name}'s hash: {commithash}").strip()
+        current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
         if current_hash == commithash:
             return

From be71115b1a1201d04f0e2a11e718fb31cbd26474 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 01:09:44 +0100
Subject: [PATCH 22/28] Update shared.py

---
 modules/shared.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/shared.py b/modules/shared.py
index ca2e4c742..9f7c6efe5 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -236,6 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
+    "show_progress_in_title": OptionInfo(False, "Show generation progress in window title."),
 }))

 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

From c06298d1d003aa034007978ee7508af636c18124 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 01:10:38 +0100
Subject: [PATCH 23/28] add check for progress in title setting

---
 javascript/progressbar.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 3e3220c3f..f9e9290e2 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -5,7 +5,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
     var progressbar = gradioApp().getElementById(id_progressbar)
     var interrupt = gradioApp().getElementById(id_interrupt)

-    if(progressbar && progressbar.offsetParent){
+    if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
         if(progressbar.innerText){
             let newtitle = 'Stable Diffusion - ' + progressbar.innerText
             if(document.title != newtitle){

From fec71e4de24b65b0f205a3c071b71651bbcb0dfc Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 01:35:07 +0100
Subject: [PATCH 24/28] Default window title progress updates on

---
 modules/shared.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/shared.py b/modules/shared.py
index 9f7c6efe5..5c16f0257 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -236,7 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
-    "show_progress_in_title": OptionInfo(False, "Show generation progress in window title."),
+    "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
 }))

 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

From 5d0e6ab8567bda2ee8f5ed31f332ca07c1b84b98 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 04:04:50 +0100
Subject: [PATCH 25/28] Allow escaping of commas in xy_grid

---
 scripts/xy_grid.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 1237e754d..210829a79 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -168,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
 re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
 re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+re_non_escaped_comma = re.compile(r"(?<!\\),")

From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 11:55:21 +0100
Subject: [PATCH 26/28] use csv.reader

---
 scripts/xy_grid.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 210829a79..1a625898f 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -1,8 +1,9 @@
 from collections import namedtuple
 from copy import copy
-from itertools import permutations
+from itertools import permutations, chain
 import random
-
+import csv
+from io import StringIO
 from PIL import Image
 import numpy as np
@@ -168,8 +169,6 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
 re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
 re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
-re_non_escaped_comma = re.compile(r"(?<!\\),")

From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 12:32:17 +0100
Subject: [PATCH 27/28] strip() split comma delimited lines

---
 scripts/xy_grid.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 1a625898f..ec27e58bc 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -197,7 +197,7 @@ class Script(scripts.Script):
         if opt.label == 'Nothing':
             return [0]

-        valslist = list(chain.from_iterable(csv.reader(StringIO(s))))
+        valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(s)))))

         if opt.type == int:
             valslist_ext = []

From 82eb8ea452b1e63535c58d15ec6db2ad2342faa8 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 6 Oct 2022 15:22:51 +0100
Subject: [PATCH 28/28] Update xy_grid.py split vals not 's' from tests

---
 scripts/xy_grid.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index ec27e58bc..210c7b6e9 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -197,7 +197,7 @@ class Script(scripts.Script):
         if opt.label == 'Nothing':
             return [0]

-        valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(s)))))
+        valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(vals)))))

         if opt.type == int:
             valslist_ext = []
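
For reference, the AND-prompt parsing introduced in PATCH 03 (with the weight regex as fixed in PATCH 04 and the match guard added in PATCH 13) can be exercised outside the webui. The following is a minimal standalone sketch; split_weighted_subprompts is an illustrative name, not a function that exists in these patches, while the two regexes are copied verbatim from modules/prompt_parser.py as of PATCH 04.

import re

# Regexes copied from modules/prompt_parser.py as of PATCH 04.
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")


def split_weighted_subprompts(prompt):
    # Split the prompt on the word AND, then parse an optional trailing
    # ":weight" from each subprompt, defaulting the weight to 1.0.
    # This mirrors the loop body of get_multicond_prompt_list, including
    # the PATCH 13 fallback for a non-matching regex.
    result = []
    for subprompt in re_AND.split(prompt):
        match = re_weight.search(subprompt)
        text, weight = match.groups() if match is not None else (subprompt, 1.0)
        result.append((text, float(weight) if weight is not None else 1.0))
    return result


print(split_weighted_subprompts("a red crown AND a jeweled crown :0.4"))
# [('a red crown', 1.0), (' a jeweled crown', 0.4)]

During sampling, CFGDenoiser.forward (PATCH 03) applies each parsed weight as denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale), so the weight directly scales that subprompt's share of the classifier-free guidance.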