diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c378118..fa647020f 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule): updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - return updown, ex_bias diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index d22ed8438..985b2753b 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df625..ac2c3de46 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index e4dae91bc..625c5d148 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -34,7 +34,7 @@ function updateOnBackgroundChange() { if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); let preview = gradioApp().querySelectorAll('.livePreview > img'); - if (preview.length > 0) { + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { // show preview image if available modalImage.src = preview[preview.length - 1].src; } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { diff --git a/javascript/ui.js b/javascript/ui.js index 410fc44e3..18c9f891a 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -215,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); diff --git a/modules/images.py b/modules/images.py index daf4eebe4..16f9ae7cc 100644 --- a/modules/images.py +++ b/modules/images.py @@ -791,3 +791,4 @@ def flatten(img, bgcolor): img = background return img.convert('RGB') + diff --git a/modules/processing.py b/modules/processing.py index 6f01c95f5..bea01ec68 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image): return image.convert('RGB') -def apply_overlay(image, paste_loc, index, overlays): - if overlays is None or index >= len(overlays): +def uncrop(image, dest_size, paste_loc): + x, y, w, h = paste_loc + base_image = Image.new('RGBA', dest_size) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + return image + + +def apply_overlay(image, paste_loc, overlay): + if overlay is None: return image - overlay = overlays[index] - if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = image.convert('RGBA') image.alpha_composite(overlay) @@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays): return image -def create_binary_mask(image): +def create_binary_mask(image, round=True): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): - image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + if round: + image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + else: + image = image.split()[-1].convert("L") else: image = image.convert('L') return image @@ -308,7 +315,7 @@ class StableDiffusionProcessing: c_adm = torch.cat((c_adm, noise_level_emb), 1) return c_adm - def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -320,8 +327,10 @@ class StableDiffusionProcessing: conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) + if round_image_mask: + # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) @@ -345,7 +354,7 @@ class StableDiffusionProcessing: return image_conditioning - def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): source_image = devices.cond_cast_float(source_image) # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely @@ -357,7 +366,7 @@ class StableDiffusionProcessing: return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask) if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) @@ -867,6 +876,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + if p.scripts is not None: + ps = scripts.PostSampleArgs(samples_ddim) + p.scripts.post_sample(p, ps) + samples_ddim = ps.samples + if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: @@ -922,13 +936,31 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image + + mask_for_overlay = getattr(p, "mask_for_overlay", None) + overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image + if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) + image_without_cc = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) - image = apply_overlay(image, p.paste_to, i, p.overlay_images) + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. + original_denoised_image = image.copy() + + if p.paste_to is not None: + original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to) + + image = apply_overlay(image, p.paste_to, overlay_image) if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) @@ -938,16 +970,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: + + if mask_for_overlay is not None: if opts.return_mask or opts.save_mask: - image_mask = p.mask_for_overlay.convert('RGB') + image_mask = mask_for_overlay.convert('RGB') if save_samples and opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") if opts.return_mask: output_images.append(image_mask) if opts.return_mask_composite or opts.save_mask_composite: - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') if save_samples and opts.save_mask_composite: images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") if opts.return_mask_composite: @@ -1351,6 +1384,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None + mask_round: bool = True inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 @@ -1396,7 +1430,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask) + image_mask = create_binary_mask(image_mask, round=self.mask_round) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1503,7 +1537,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - latmask = np.around(latmask) + if self.mask_round: + latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) @@ -1515,7 +1550,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() @@ -1527,7 +1562,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: - samples = samples * self.nmask + self.init_latent * self.mask + blended_samples = samples * self.nmask + self.init_latent * self.mask + + if self.scripts is not None: + mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent + + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 7f9454eb5..b6fcf96e0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, AlwaysVisible = object() +class MaskBlendArgs: + def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None): + self.current_latent = current_latent + self.nmask = nmask + self.init_latent = init_latent + self.mask = mask + self.blended_latent = blended_latent + + self.denoiser = denoiser + self.is_final_blend = denoiser is None + self.sigma = sigma + +class PostSampleArgs: + def __init__(self, samples): + self.samples = samples class PostprocessImageArgs: def __init__(self, image): self.image = image +class PostProcessMaskOverlayArgs: + def __init__(self, index, mask_for_overlay, overlay_image): + self.index = index + self.mask_for_overlay = mask_for_overlay + self.overlay_image = overlay_image class PostprocessBatchListArgs: def __init__(self, images): @@ -206,6 +226,25 @@ class Script: pass + def on_mask_blend(self, p, mba: MaskBlendArgs, *args): + """ + Called in inpainting mode when the original content is blended with the inpainted content. + This is called at every step in the denoising process and once at the end. + If is_final_blend is true, this is called for the final blending stage. + Otherwise, denoiser and sigma are defined and may be used to inform the procedure. + """ + + pass + + def post_sample(self, p, ps: PostSampleArgs, *args): + """ + Called after the samples have been generated, + but before they have been decoded by the VAE, if applicable. + Check getattr(samples, 'already_decoded', False) to test if the images are decoded. + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -213,6 +252,13 @@ class Script: pass + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -767,6 +813,22 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def post_sample(self, p, ps: PostSampleArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.post_sample(p, ps, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + + def on_mask_blend(self, p, mba: MaskBlendArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.on_mask_blend(p, mba, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: @@ -775,6 +837,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_maskoverlay(p, ppmo, *script_args) + except Exception: + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 8863107ae..273a7edd8 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper): would be on the meta device. """ - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index b8101d38d..eb9d5dafa 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module): self.sampler = sampler self.model_wrap = None self.p = None + + # NOTE: masking before denoising can cause the original latents to be oversmoothed + # as the original latents do not have noise self.mask_before_denoising = False @property @@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # If we use masks, blending between the denoised and original latent images occurs here. + def apply_blend(current_latent): + blended_latent = current_latent * self.nmask + self.init_latent * self.mask + + if self.p.scripts is not None: + from modules import scripts + mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) + self.p.scripts.on_mask_blend(self.p, mba) + blended_latent = mba.blended_latent + + return blended_latent + + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = self.init_latent * self.mask + self.nmask * x + x = apply_blend(x) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised + denoised = apply_blend(denoised) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/shared_options.py b/modules/shared_options.py index a860e355e..19ba47886 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -258,6 +258,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) @@ -332,6 +333,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), { "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { diff --git a/modules/styles.py b/modules/styles.py index 7fb6c2e11..81d9800d1 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -98,10 +98,8 @@ class StyleDatabase: self.path = path folder, file = os.path.split(self.path) - self.default_file = file.split("*")[0] + ".csv" - if self.default_file == ".csv": - self.default_file = "styles.csv" - self.default_path = os.path.join(folder, self.default_file) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] @@ -155,10 +153,8 @@ class StyleDatabase: row["name"], prompt, negative_prompt, path ) - def get_style_paths(self) -> list(): - """ - Returns a list of all distinct paths, including the default path, of - files that styles are loaded from.""" + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: @@ -172,9 +168,9 @@ class StyleDatabase: style_paths.add(style.path) # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths.discard("do_not_save") - return list(style_paths) + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -196,20 +192,7 @@ class StyleDatabase: # The path argument is deprecated, but kept for backwards compatibility _ = path - # Update any styles without a path to the default path - for style in list(self.styles.values()): - if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) - - # Create a list of all distinct paths, including the default path - style_paths = set() - style_paths.add(self.default_path) - for _, style in self.styles.items(): - if style.path: - style_paths.add(style.path) - - # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths = self.get_style_paths() csv_names = [os.path.split(path)[1].lower() for path in style_paths] diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 88838f977..9caf8faa2 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ class Toprow: def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( fn=modules.images.image_data, diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c7903..d8da94a0e 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,12 @@ if has_xpu: CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) diff --git a/script.js b/script.js index c0e678ea7..be1bc317e 100644 --- a/script.js +++ b/script.js @@ -121,16 +121,22 @@ document.addEventListener("DOMContentLoaded", function() { }); /** - * Add a ctrl+enter as a shortcut to start a generation + * Add keyboard shortcuts: + * Ctrl+Enter to start/restart a generation + * Alt/Option+Enter to skip a generation + * Esc to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; + const isCtrlKey = e.metaKey || e.ctrlKey; + const isAltKey = e.altKey; + const isEsc = e.key === 'Escape'; - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isEnter && isModifierKey) { + if (isCtrlKey && isEnter) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -150,6 +156,21 @@ document.addEventListener('keydown', function(e) { } e.preventDefault(); } + + if (isAltKey && isEnter) { + skipButton.click(); + e.preventDefault(); + } + + if (isEsc) { + const globalPopup = document.querySelector('.global-popup'); + const lightboxModal = document.querySelector('#lightboxModal'); + if (!globalPopup || globalPopup.style.display === 'none') { + if (document.activeElement === lightboxModal) return; + interruptButton.click(); + e.preventDefault(); + } + } }); /** diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py new file mode 100644 index 000000000..d90243442 --- /dev/null +++ b/scripts/soft_inpainting.py @@ -0,0 +1,747 @@ +import numpy as np +import gradio as gr +import math +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast + + +# ------------------- Methods ------------------- + +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + + +def latent_blend(settings, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(settings, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) + + +def apply_adaptive_masks( + settings: SoftInpaintingSettings, + nmask, + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + latent_mask = nmask[0].float() + # convert the original mask into a form we use to scale distances for thresholding + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) + converted_mask = smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay.append(converted_mask) + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def apply_masks( + settings, + nmask, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + masks_for_overlay = [] + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") + +ui_info = SoftInpaintingSettings( + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") + + +# ------------------- Script ------------------- + + +class Script(scripts.Script): + def __init__(self): + self.section = "inpaint" + self.masks_for_overlay = None + self.overlay_images = None + + def title(self): + return "Soft Inpainting" + + def show(self, is_img2img): + return scripts.AlwaysVisible if is_img2img else False + + def ui(self, is_img2img): + if not is_img2img: + return + + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + power = \ + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power) + scale = \ + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale) + detail = \ + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. + + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the contrast between the opacity of the original and inpainted content. + + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] + + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if mba.is_final_blend: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # todo: Why is sigma 2D? Both values are the same. + mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return + + from modules import images + from modules.shared import opts + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(settings=settings, + nmask=nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=nmask, + latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] diff --git a/webui.sh b/webui.sh index cff433272..69ca2f88d 100755 --- a/webui.sh +++ b/webui.sh @@ -133,7 +133,7 @@ case "$gpu_info" in if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] then # Navi users will still use torch 1.13 because 2.0 does not seem to work. - export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2" + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" else printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" exit 1 @@ -143,8 +143,7 @@ case "$gpu_info" in *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ - export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6" - # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}"