commit 0fb34b57b8
Merge branch 'dev' into test-fp8
@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
         self.lin_module = None
         self.org_module: list[torch.Module] = [self.sd_module]
 
+        self.scale = 1.0
+
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
@@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule):
             self.constraint = None
             self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
 
-    def calc_updown_kb(self, orig_weight, multiplier):
+    def calc_updown(self, orig_weight):
         oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-        oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
+        eye = torch.eye(self.block_size, device=self.oft_blocks.device)
+
+        if self.is_kohya:
+            block_Q = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
+            norm_Q = torch.norm(block_Q.flatten())
+            new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
+            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+            oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
 
         R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-        R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
 
         # This errors out for MultiheadAttention, might need to be handled up-stream
         merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
         updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
-
-    def calc_updown(self, orig_weight):
-        # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
-        multiplier = self.multiplier()
-        return self.calc_updown_kb(orig_weight, multiplier)
-
-    # override to remove the multiplier/scale factor; it's already multiplied in get_weight
-    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
-        if self.bias is not None:
-            updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
-            updown = updown.reshape(output_shape)
-
-        if len(output_shape) == 4:
-            updown = updown.reshape(output_shape)
-
-        if orig_weight.size().numel() == updown.size().numel():
-            updown = updown.reshape(orig_weight.shape)
-
-        if ex_bias is not None:
-            ex_bias = ex_bias * self.multiplier()
-
-        return updown, ex_bias
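The kohya branch above clamps the norm of the skew-symmetric matrix Q (the COFT constraint) and then maps it to an orthogonal block with the Cayley transform R = (I + Q)(I - Q)^-1. A minimal standalone sketch, not from the commit, of why that product is orthogonal (the diff applies the same identity per block with `transpose(1, 2)`):

```python
# Cayley transform sketch: any skew-symmetric Q (Q.T == -Q) maps to an
# orthogonal R = (I + Q)(I - Q)^-1, so multiplying a weight block by R
# rotates it without changing its norm.
import torch

n = 4
A = torch.randn(n, n)
Q = A - A.T                                  # skew-symmetric, like block_Q above
eye = torch.eye(n)
R = (eye + Q) @ torch.linalg.inv(eye - Q)    # I - Q is always invertible for skew Q

print(torch.allclose(R @ R.T, eye, atol=1e-5))  # True: R is orthogonal
```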
@@ -159,7 +159,8 @@ def load_network(name, network_on_disk):
    bundle_embeddings = {}
 
    for key_network, weight in sd.items():
-        key_network_without_network_parts, network_part = key_network.split(".", 1)
+        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
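`str.partition` always returns a 3-tuple, so a key with no "." no longer breaks the two-way unpacking that `split(".", 1)` required. A quick illustration of the difference:

```python
key = "bundle_emb"                 # a key with no "." separator

# old approach: raises ValueError (not enough values to unpack)
try:
    name, part = key.split(".", 1)
except ValueError as e:
    print(e)

# new approach: always a 3-tuple, the trailing part is simply empty
name, _, part = key.partition(".")
print(name, repr(part))            # bundle_emb ''
```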
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
         self.setting_names = []
         self.infotext_fields = []
         extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
+        elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
 
         mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
 
         with gr.Blocks() as interface:
-            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
+            with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
 
                 row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
@@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i
     """),
     "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
     "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
-    "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(),
+    "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
     "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
 }))
@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
     if (modalImage && modalImage.offsetParent) {
         let currentButton = selected_gallery_button();
         let preview = gradioApp().querySelectorAll('.livePreview > img');
-        if (preview.length > 0) {
+        if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
             // show preview image if available
             modalImage.src = preview[preview.length - 1].src;
         } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
@@ -215,9 +215,33 @@ function restoreProgressImg2img() {
 }
 
 
+/**
+ * Configure the width and height elements on `tabname` to accept
+ * pasting of resolutions in the form of "width x height".
+ */
+function setupResolutionPasting(tabname) {
+    var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
+    var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
+    for (const el of [width, height]) {
+        el.addEventListener('paste', function(event) {
+            var pasteData = event.clipboardData.getData('text/plain');
+            var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
+            if (parsed) {
+                width.value = parsed[1];
+                height.value = parsed[2];
+                updateInput(width);
+                updateInput(height);
+                event.preventDefault();
+            }
+        });
+    }
+}
+
 onUiLoaded(function() {
     showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
     showRestoreProgressButton('img2img', localGet("img2img_task_id"));
+    setupResolutionPasting('txt2img');
+    setupResolutionPasting('img2img');
 });
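The added paste handler accepts any two integers separated by a run of non-digits, so "1920x1080", "768 x 512", and "1024, 1024" all parse. The same pattern demonstrated in Python (the JS regex above is character-for-character identical):

```python
import re

pattern = re.compile(r"^\s*(\d+)\D+(\d+)\s*$")

for text in ["1920x1080", " 768 x 512 ", "1024, 1024", "banana"]:
    m = pattern.match(text)
    print(text, "->", m.groups() if m else "no match")
# 1920x1080 -> ('1920', '1080') ... banana -> no match
```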
@@ -791,3 +791,4 @@ def flatten(img, bgcolor):
         img = background
 
     return img.convert('RGB')
+
@@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image):
     return image.convert('RGB')
 
 
-def apply_overlay(image, paste_loc, index, overlays):
-    if overlays is None or index >= len(overlays):
+def uncrop(image, dest_size, paste_loc):
+    x, y, w, h = paste_loc
+    base_image = Image.new('RGBA', dest_size)
+    image = images.resize_image(1, image, w, h)
+    base_image.paste(image, (x, y))
+    image = base_image
+
+    return image
+
+
+def apply_overlay(image, paste_loc, overlay):
+    if overlay is None:
         return image
 
-    overlay = overlays[index]
-
     if paste_loc is not None:
-        x, y, w, h = paste_loc
-        base_image = Image.new('RGBA', (overlay.width, overlay.height))
-        image = images.resize_image(1, image, w, h)
-        base_image.paste(image, (x, y))
-        image = base_image
+        image = uncrop(image, (overlay.width, overlay.height), paste_loc)
 
     image = image.convert('RGBA')
     image.alpha_composite(overlay)
@@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays):
 
     return image
 
-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        if round:
+            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        else:
+            image = image.split()[-1].convert("L")
     else:
         image = image.convert('L')
     return image
@@ -308,7 +315,7 @@ class StableDiffusionProcessing:
             c_adm = torch.cat((c_adm, noise_level_emb), 1)
         return c_adm
 
-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         self.is_using_inpainting_conditioning = True
 
         # Handle the different mask inputs
@@ -320,8 +327,10 @@ class StableDiffusionProcessing:
             conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
             conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
 
-            # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
-            conditioning_mask = torch.round(conditioning_mask)
+            if round_image_mask:
+                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+
         else:
             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +354,7 @@ class StableDiffusionProcessing:
 
         return image_conditioning
 
-    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         source_image = devices.cond_cast_float(source_image)
 
         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,7 +366,7 @@ class StableDiffusionProcessing:
             return self.edit_image_conditioning(source_image)
 
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
 
         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)
@@ -867,6 +876,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
+            if p.scripts is not None:
+                ps = scripts.PostSampleArgs(samples_ddim)
+                p.scripts.post_sample(p, ps)
+                samples_ddim = ps.samples
+
             if getattr(samples_ddim, 'already_decoded', False):
                 x_samples_ddim = samples_ddim
             else:
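The new post_sample hook hands always-on scripts the raw latents before VAE decode. A minimal sketch of a consumer, assuming it is installed as an always-on script; the class name and the gain factor are made up, only the hook signature comes from this commit:

```python
# Hypothetical always-on script using the post_sample hook to scale
# latents before they are decoded by the VAE.
import modules.scripts as scripts

class LatentGainScript(scripts.Script):
    def title(self):
        return "Latent gain (example)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        return []

    def post_sample(self, p, ps: scripts.PostSampleArgs, *args):
        # ps.samples holds the sampled latents; replacing them here
        # changes what gets decoded downstream
        ps.samples = ps.samples * 1.05
```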
@@ -922,13 +936,31 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     pp = scripts.PostprocessImageArgs(image)
                     p.scripts.postprocess_image(p, pp)
                     image = pp.image
+
+            mask_for_overlay = getattr(p, "mask_for_overlay", None)
+            overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None
+
+            if p.scripts is not None:
+                ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+                p.scripts.postprocess_maskoverlay(p, ppmo)
+                mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
             if p.color_corrections is not None and i < len(p.color_corrections):
                 if save_samples and opts.save_images_before_color_correction:
-                    image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                    image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                     images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                 image = apply_color_correction(p.color_corrections[i], image)
 
-            image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+            # If the intention is to show the output from the model
+            # that is being composited over the original image,
+            # we need to keep the original image around
+            # and use it in the composite step.
+            original_denoised_image = image.copy()
+
+            if p.paste_to is not None:
+                original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
+
+            image = apply_overlay(image, p.paste_to, overlay_image)
 
             if save_samples:
                 images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,16 +970,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+
+                if mask_for_overlay is not None:
                     if opts.return_mask or opts.save_mask:
-                        image_mask = p.mask_for_overlay.convert('RGB')
+                        image_mask = mask_for_overlay.convert('RGB')
                         if save_samples and opts.save_mask:
                             images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                         if opts.return_mask:
                             output_images.append(image_mask)
 
                     if opts.return_mask_composite or opts.save_mask_composite:
-                        image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                         if save_samples and opts.save_mask_composite:
                             images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                         if opts.return_mask_composite:
@@ -1351,6 +1384,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     mask_blur_x: int = 4
     mask_blur_y: int = 4
     mask_blur: int = None
+    mask_round: bool = True
     inpainting_fill: int = 0
     inpaint_full_res: bool = True
     inpaint_full_res_padding: int = 0
@@ -1396,7 +1430,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         if image_mask is not None:
             # image_mask is passed in as RGBA by Gradio to support alpha masks,
             # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask)
+            image_mask = create_binary_mask(image_mask, round=self.mask_round)
 
             if self.inpainting_mask_invert:
                 image_mask = ImageOps.invert(image_mask)
@@ -1503,7 +1537,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
-            latmask = np.around(latmask)
+            if self.mask_round:
+                latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))
 
             self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1515,7 +1550,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask
 
-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = self.rng.next()
@@ -1527,7 +1562,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
         if self.mask is not None:
-            samples = samples * self.nmask + self.init_latent * self.mask
+            blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+            if self.scripts is not None:
+                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+                self.scripts.on_mask_blend(self, mba)
+                blended_samples = mba.blended_latent
+
+            samples = blended_samples
 
         del x
         devices.torch_gc()
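Scripts can now replace this final latent blend (and, via CFGDenoiser below, the per-step blends). A hypothetical on_mask_blend body, assuming it sits in a Script subclass like the sketch earlier; the weighting scheme here is invented purely for illustration, only the MaskBlendArgs fields come from this commit:

```python
# Hypothetical on_mask_blend: bias towards the original latents as the
# noise level drops. mba.sigma is only set for per-step blends; at the
# final blend mba.is_final_blend is True and mba.sigma is None.
def on_mask_blend(self, p, mba, *args):
    if mba.is_final_blend:
        return  # keep the default final blend computed by the caller

    w = mba.nmask * min(1.0, float(mba.sigma[0]))  # made-up schedule
    mba.blended_latent = mba.current_latent * w + mba.init_latent * (1 - w)
```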
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
 
 AlwaysVisible = object()
 
+
+class MaskBlendArgs:
+    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
+        self.current_latent = current_latent
+        self.nmask = nmask
+        self.init_latent = init_latent
+        self.mask = mask
+        self.blended_latent = blended_latent
+
+        self.denoiser = denoiser
+        self.is_final_blend = denoiser is None
+        self.sigma = sigma
+
+
+class PostSampleArgs:
+    def __init__(self, samples):
+        self.samples = samples
+
+
 class PostprocessImageArgs:
     def __init__(self, image):
         self.image = image
 
+
+class PostProcessMaskOverlayArgs:
+    def __init__(self, index, mask_for_overlay, overlay_image):
+        self.index = index
+        self.mask_for_overlay = mask_for_overlay
+        self.overlay_image = overlay_image
+
+
 class PostprocessBatchListArgs:
     def __init__(self, images):
@@ -206,6 +226,25 @@ class Script:
 
         pass
 
+    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+        """
+        Called in inpainting mode when the original content is blended with the inpainted content.
+        This is called at every step in the denoising process and once at the end.
+        If is_final_blend is true, this is called for the final blending stage.
+        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+        """
+
+        pass
+
+    def post_sample(self, p, ps: PostSampleArgs, *args):
+        """
+        Called after the samples have been generated,
+        but before they have been decoded by the VAE, if applicable.
+        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+        """
+
+        pass
+
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
         """
         Called for every image after it has been generated.
@@ -213,6 +252,13 @@ class Script:
 
         pass
 
+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+        """
+        Called for every image after it has been generated.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
         """
         This function is called after processing ends for AlwaysVisible scripts.
@@ -767,6 +813,22 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
 
+    def post_sample(self, p, ps: PostSampleArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.post_sample(p, ps, *script_args)
+            except Exception:
+                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+    def on_mask_blend(self, p, mba: MaskBlendArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.on_mask_blend(p, mba, *script_args)
+            except Exception:
+                errors.report(f"Error running on_mask_blend: {script.filename}", exc_info=True)
+
     def postprocess_image(self, p, pp: PostprocessImageArgs):
         for script in self.alwayson_scripts:
             try:
@@ -775,6 +837,14 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
 
+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_maskoverlay(p, ppmo, *script_args)
+            except Exception:
+                errors.report(f"Error running postprocess_maskoverlay: {script.filename}", exc_info=True)
+
     def before_component(self, component, **kwargs):
         for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
             try:
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             would be on the meta device.
             """
 
-            if state_dict == sd:
+            if state_dict is sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
 
             original(module, state_dict, strict=strict)
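The intent here is an identity check, and `==` on dicts of tensors is the wrong tool: dict equality compares values with `==`, which for multi-element tensors yields a tensor whose truth value is ambiguous. It only appeared to work before because CPython short-circuits when the values are the very same objects. A small demonstration:

```python
import torch

sd = {"w": torch.ones(2)}
other = {"w": torch.ones(2)}   # equal contents, distinct tensor objects

print(sd is sd)     # True  - identity, which is what the code wants
print(sd is other)  # False
# sd == other raises RuntimeError: Boolean value of Tensor with more
# than one element is ambiguous (dict __eq__ compares the tensors).
```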
@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
         self.sampler = sampler
         self.model_wrap = None
         self.p = None
+
+        # NOTE: masking before denoising can cause the original latents to be oversmoothed
+        # as the original latents do not have noise
         self.mask_before_denoising = False
 
     @property
@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):
 
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
 
+        # If we use masks, blending between the denoised and original latent images occurs here.
+        def apply_blend(current_latent):
+            blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+            if self.p.scripts is not None:
+                from modules import scripts
+                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+                self.p.scripts.on_mask_blend(self.p, mba)
+                blended_latent = mba.blended_latent
+
+            return blended_latent
+
+        # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            x = self.init_latent * self.mask + self.nmask * x
+            x = apply_blend(x)
 
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
 
+        # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
+            denoised = apply_blend(denoised)
 
         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
@@ -258,6 +258,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
     "keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
     "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
+    "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
 }))
@@ -332,6 +333,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
     "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
     "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
     "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
+    "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"),
 }))
 
 options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
@@ -98,10 +98,8 @@ class StyleDatabase:
         self.path = path
 
         folder, file = os.path.split(self.path)
-        self.default_file = file.split("*")[0] + ".csv"
-        if self.default_file == ".csv":
-            self.default_file = "styles.csv"
-        self.default_path = os.path.join(folder, self.default_file)
+        filename, _, ext = file.partition('*')
+        self.default_path = os.path.join(folder, filename + ext)
 
         self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -155,10 +153,8 @@ class StyleDatabase:
                 row["name"], prompt, negative_prompt, path
             )
 
-    def get_style_paths(self) -> list():
-        """
-        Returns a list of all distinct paths, including the default path, of
-        files that styles are loaded from."""
+    def get_style_paths(self) -> set:
+        """Returns a set of all distinct paths of files that styles are loaded from."""
+
         # Update any styles without a path to the default path
         for style in list(self.styles.values()):
             if not style.path:
@@ -172,9 +168,9 @@ class StyleDatabase:
                 style_paths.add(style.path)
 
         # Remove any paths for styles that are just list dividers
-        style_paths.remove("do_not_save")
+        style_paths.discard("do_not_save")
 
-        return list(style_paths)
+        return style_paths
 
     def get_style_prompts(self, styles):
         return [self.styles.get(x, self.no_style).prompt for x in styles]
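`set.remove` raises when the element is absent, while `discard` is a no-op, so style databases that never contained a "do_not_save" divider no longer crash here:

```python
paths = {"styles.csv"}          # no "do_not_save" divider present

paths.discard("do_not_save")    # fine: silently does nothing
paths.remove("do_not_save")     # KeyError: 'do_not_save'
```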
@@ -196,20 +192,7 @@ class StyleDatabase:
         # The path argument is deprecated, but kept for backwards compatibility
         _ = path
 
-        # Update any styles without a path to the default path
-        for style in list(self.styles.values()):
-            if not style.path:
-                self.styles[style.name] = style._replace(path=self.default_path)
-
-        # Create a list of all distinct paths, including the default path
-        style_paths = set()
-        style_paths.add(self.default_path)
-        for _, style in self.styles.items():
-            if style.path:
-                style_paths.add(style.path)
-
-        # Remove any paths for styles that are just list dividers
-        style_paths.remove("do_not_save")
+        style_paths = self.get_style_paths()
 
         csv_names = [os.path.split(path)[1].lower() for path in style_paths]
@@ -79,11 +79,11 @@ class Toprow:
     def create_prompts(self):
         with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
             with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
-                self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+                self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
                 self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
 
             with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
-                self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+                self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
 
         self.prompt_img.change(
             fn=modules.images.image_data,
@@ -48,3 +48,12 @@ if has_xpu:
     CondFunc('torch.nn.modules.conv.Conv2d.forward',
         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+    CondFunc('torch.bmm',
+        lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+        lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
+    CondFunc('torch.cat',
+        lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
+        lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
+    CondFunc('torch.nn.functional.scaled_dot_product_attention',
+        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
+        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
script.js
29
script.js
|
@ -121,16 +121,22 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a ctrl+enter as a shortcut to start a generation
|
* Add keyboard shortcuts:
|
||||||
|
* Ctrl+Enter to start/restart a generation
|
||||||
|
* Alt/Option+Enter to skip a generation
|
||||||
|
* Esc to interrupt a generation
|
||||||
*/
|
*/
|
||||||
document.addEventListener('keydown', function(e) {
|
document.addEventListener('keydown', function(e) {
|
||||||
const isEnter = e.key === 'Enter' || e.keyCode === 13;
|
const isEnter = e.key === 'Enter' || e.keyCode === 13;
|
||||||
const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
|
const isCtrlKey = e.metaKey || e.ctrlKey;
|
||||||
|
const isAltKey = e.altKey;
|
||||||
|
const isEsc = e.key === 'Escape';
|
||||||
|
|
||||||
const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
|
|
||||||
const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||||
|
const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
|
||||||
|
const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]');
|
||||||
|
|
||||||
if (isEnter && isModifierKey) {
|
if (isCtrlKey && isEnter) {
|
||||||
if (interruptButton.style.display === 'block') {
|
if (interruptButton.style.display === 'block') {
|
||||||
interruptButton.click();
|
interruptButton.click();
|
||||||
const callback = (mutationList) => {
|
const callback = (mutationList) => {
|
||||||
|
@@ -150,6 +156,21 @@ document.addEventListener('keydown', function(e) {
         }
         e.preventDefault();
     }
+
+    if (isAltKey && isEnter) {
+        skipButton.click();
+        e.preventDefault();
+    }
+
+    if (isEsc) {
+        const globalPopup = document.querySelector('.global-popup');
+        const lightboxModal = document.querySelector('#lightboxModal');
+        if (!globalPopup || globalPopup.style.display === 'none') {
+            if (document.activeElement === lightboxModal) return;
+            interruptButton.click();
+            e.preventDefault();
+        }
+    }
 });
 
 /**
@@ -0,0 +1,747 @@
+import numpy as np
+import gradio as gr
+import math
+from modules.ui_components import InputAccordion
+import modules.scripts as scripts
+
+
+class SoftInpaintingSettings:
+    def __init__(self,
+                 mask_blend_power,
+                 mask_blend_scale,
+                 inpaint_detail_preservation,
+                 composite_mask_influence,
+                 composite_difference_threshold,
+                 composite_difference_contrast):
+        self.mask_blend_power = mask_blend_power
+        self.mask_blend_scale = mask_blend_scale
+        self.inpaint_detail_preservation = inpaint_detail_preservation
+        self.composite_mask_influence = composite_mask_influence
+        self.composite_difference_threshold = composite_difference_threshold
+        self.composite_difference_contrast = composite_difference_contrast
+
+    def add_generation_params(self, dest):
+        dest[enabled_gen_param_label] = True
+        dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
+        dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
+        dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
+        dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
+        dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
+        dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
+
+
+# ------------------- Methods -------------------
+
+def processing_uses_inpainting(p):
+    # TODO: Figure out a better way to determine if inpainting is being used by p
+    if getattr(p, "image_mask", None) is not None:
+        return True
+
+    if getattr(p, "mask", None) is not None:
+        return True
+
+    if getattr(p, "nmask", None) is not None:
+        return True
+
+    return False
+
+
+def latent_blend(settings, a, b, t):
+    """
+    Interpolates two latent image representations according to the parameter t,
+    where the interpolated vectors' magnitudes are also interpolated separately.
+    The "detail_preservation" factor biases the magnitude interpolation towards
+    the larger of the two magnitudes.
+    """
+    import torch
+
+    # NOTE: We use inplace operations wherever possible.
+
+    # [4][w][h] to [1][4][w][h]
+    t2 = t.unsqueeze(0)
+    # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
+    t3 = t[0].unsqueeze(0).unsqueeze(0)
+
+    one_minus_t2 = 1 - t2
+    one_minus_t3 = 1 - t3
+
+    # Linearly interpolate the image vectors.
+    a_scaled = a * one_minus_t2
+    b_scaled = b * t2
+    image_interp = a_scaled
+    image_interp.add_(b_scaled)
+    result_type = image_interp.dtype
+    del a_scaled, b_scaled, t2, one_minus_t2
+
+    # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+    # 64-bit operations are used here to allow large exponents.
+    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
+
+    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
+        settings.inpaint_detail_preservation) * one_minus_t3
+    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
+        settings.inpaint_detail_preservation) * t3
+    desired_magnitude = a_magnitude
+    desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
+    del a_magnitude, b_magnitude, t3, one_minus_t3
+
+    # Change the linearly interpolated image vectors' magnitudes to the value we want.
+    # This is the last 64-bit operation.
+    image_interp_scaling_factor = desired_magnitude
+    image_interp_scaling_factor.div_(current_magnitude)
+    image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
+    image_interp_scaled = image_interp
+    image_interp_scaled.mul_(image_interp_scaling_factor)
+    del current_magnitude
+    del desired_magnitude
+    del image_interp
+    del image_interp_scaling_factor
+    del result_type
+
+    return image_interp_scaled
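Stripped of the in-place bookkeeping, latent_blend lerps the vectors and then restores an interpolated magnitude, with the detail-preservation exponent p biasing that magnitude toward the larger of the two norms (for large p it approaches max(|a|, |b|)). A small numeric sketch of the same math on plain vectors, assuming p stands for inpaint_detail_preservation:

```python
# Compact restatement of latent_blend's magnitude handling.
import numpy as np

def blend(a, b, t, p):
    interp = a * (1 - t) + b * t
    target_mag = ((1 - t) * np.linalg.norm(a) ** p
                  + t * np.linalg.norm(b) ** p) ** (1 / p)
    return interp * (target_mag / (np.linalg.norm(interp) + 1e-5))

a = np.array([3.0, 0.0])   # |a| = 3
b = np.array([0.0, 1.0])   # |b| = 1
print(np.linalg.norm(blend(a, b, 0.5, 1)))   # 2.0: plain average of the norms
print(np.linalg.norm(blend(a, b, 0.5, 32)))  # ~2.9: biased towards the larger norm
```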
+
+
+def get_modified_nmask(settings, nmask, sigma):
+    """
+    Converts a negative mask representing the transparency of the original latent vectors being overlayed
+    to a mask that is scaled according to the denoising strength for this step.
+
+    Where:
+        0 = fully opaque, infinite density, fully masked
+        1 = fully transparent, zero density, fully unmasked
+
+    We bring this transparency to a power, as this allows one to simulate N number of blending operations
+    where N can be any positive real value. Using this one can control the balance of influence between
+    the denoiser and the original latents according to the sigma value.
+
+    NOTE: "mask" is not used
+    """
+    import torch
+    return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
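Because nmask lies in [0, 1], raising it to the exponent sigma ** mask_blend_power * mask_blend_scale pushes values toward 0 (more original content) when the exponent is large, early in sampling, and toward 1 (more denoised content) as it shrinks. A tiny check of how the exponent reshapes a transparency of 0.5, assuming those two settings produce the exponent as above:

```python
for exponent in [4.0, 1.0, 0.25]:
    print(exponent, 0.5 ** exponent)
# 4.0 -> 0.0625 (mostly original), 1.0 -> 0.5, 0.25 -> ~0.84 (mostly denoised)
```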
+
+
+def apply_adaptive_masks(
+        settings: SoftInpaintingSettings,
+        nmask,
+        latent_orig,
+        latent_processed,
+        overlay_images,
+        width, height,
+        paste_to):
+    import torch
+    import modules.processing as proc
+    import modules.images as images
+    from PIL import Image, ImageOps, ImageFilter
+
+    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
+    latent_mask = nmask[0].float()
+
+    # convert the original mask into a form we use to scale distances for thresholding
+    mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
+    mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
+                   + mask_scalar * settings.composite_mask_influence)
+    mask_scalar = mask_scalar / (1.00001 - mask_scalar)
+    mask_scalar = mask_scalar.cpu().numpy()
+
+    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
+
+    kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+    masks_for_overlay = []
+
+    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
+        converted_mask = distance_map.float().cpu().numpy()
+        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
+                                                   percentile_min=0.9, percentile_max=1, min_width=1)
+        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
+                                                   percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+        # The distance at which opacity of original decreases to 50%
+        half_weighted_distance = settings.composite_difference_threshold * mask_scalar
+        converted_mask = converted_mask / half_weighted_distance
+
+        converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
+        converted_mask = smootherstep(converted_mask)
+        converted_mask = 1 - converted_mask
+        converted_mask = 255. * converted_mask
+        converted_mask = converted_mask.astype(np.uint8)
+        converted_mask = Image.fromarray(converted_mask)
+        converted_mask = images.resize_image(2, converted_mask, width, height)
+        converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+        # Remove aliasing artifacts using a gaussian blur.
+        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+        # Expand the mask to fit the whole image if needed.
+        if paste_to is not None:
+            converted_mask = proc.uncrop(converted_mask,
+                                         (overlay_image.width, overlay_image.height),
+                                         paste_to)
+
+        masks_for_overlay.append(converted_mask)
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+
+    return masks_for_overlay
+
+
+def apply_masks(
+        settings,
+        nmask,
+        overlay_images,
+        width, height,
+        paste_to):
+    import torch
+    import modules.processing as proc
+    import modules.images as images
+    from PIL import Image, ImageOps, ImageFilter
+
+    converted_mask = nmask[0].float()
+    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
+    converted_mask = 255. * converted_mask
+    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+    converted_mask = Image.fromarray(converted_mask)
+    converted_mask = images.resize_image(2, converted_mask, width, height)
+    converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+    # Remove aliasing artifacts using a gaussian blur.
+    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+    # Expand the mask to fit the whole image if needed.
+    if paste_to is not None:
+        converted_mask = proc.uncrop(converted_mask,
+                                     (width, height),
+                                     paste_to)
+
+    masks_for_overlay = []
+
+    for i, overlay_image in enumerate(overlay_images):
+        # every image shares the same mask; append rather than index into the empty list
+        masks_for_overlay.append(converted_mask)
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+
+    return masks_for_overlay
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+    """
+    Generalization convolution filter capable of applying
+    weighted mean, median, maximum, and minimum filters
+    parametrically using an arbitrary kernel.
+
+    Args:
+        img (nparray):
+            The image, a 2-D array of floats, to which the filter is being applied.
+        kernel (nparray):
+            The kernel, a 2-D array of floats.
+        kernel_center (nparray):
+            The kernel center coordinate, a 1-D array with two elements.
+        percentile_min (float):
+            The lower bound of the histogram window used by the filter,
+            from 0 to 1.
+        percentile_max (float):
+            The upper bound of the histogram window used by the filter,
+            from 0 to 1.
+        min_width (float):
+            The minimum size of the histogram window bounds, in weight units.
+            Must be greater than 0.
+
+    Returns:
+        (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+    """
+
+    # Converts an index tuple into a vector.
+    def vec(x):
+        return np.array(x)
+
+    kernel_min = -kernel_center
+    kernel_max = vec(kernel.shape) - kernel_center
+
+    def weighted_histogram_filter_single(idx):
+        idx = vec(idx)
+        min_index = np.maximum(0, idx + kernel_min)
+        max_index = np.minimum(vec(img.shape), idx + kernel_max)
+        window_shape = max_index - min_index
+
+        class WeightedElement:
+            """
+            An element of the histogram, its weight
+            and bounds.
+            """
+
+            def __init__(self, value, weight):
+                self.value: float = value
+                self.weight: float = weight
+                self.window_min: float = 0.0
+                self.window_max: float = 1.0
+
+        # Collect the values in the image as WeightedElements,
+        # weighted by their corresponding kernel values.
+        values = []
+        for window_tup in np.ndindex(tuple(window_shape)):
+            window_index = vec(window_tup)
+            image_index = window_index + min_index
+            centered_kernel_index = image_index - idx
+            kernel_index = centered_kernel_index + kernel_center
+            element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
+            values.append(element)
+
+        def sort_key(x: WeightedElement):
+            return x.value
+
+        values.sort(key=sort_key)
+
+        # Calculate the height of the stack (sum)
+        # and each sample's range they occupy in the stack
+        sum = 0
+        for i in range(len(values)):
+            values[i].window_min = sum
+            sum += values[i].weight
+            values[i].window_max = sum
+
+        # Calculate what range of this stack ("window")
+        # we want to get the weighted average across.
+        window_min = sum * percentile_min
+        window_max = sum * percentile_max
+        window_width = window_max - window_min
+
+        # Ensure the window is within the stack and at least a certain size.
+        if window_width < min_width:
+            window_center = (window_min + window_max) / 2
+            window_min = window_center - min_width / 2
+            window_max = window_center + min_width / 2
+
+            if window_max > sum:
+                window_max = sum
+                window_min = sum - min_width
|
||||||
|
|
||||||
|
if window_min < 0:
|
||||||
|
window_min = 0
|
||||||
|
window_max = min_width
|
||||||
|
|
||||||
|
value = 0
|
||||||
|
value_weight = 0
|
||||||
|
|
||||||
|
# Get the weighted average of all the samples
|
||||||
|
# that overlap with the window, weighted
|
||||||
|
# by the size of their overlap.
|
||||||
|
for i in range(len(values)):
|
||||||
|
if window_min >= values[i].window_max:
|
||||||
|
continue
|
||||||
|
if window_max <= values[i].window_min:
|
||||||
|
break
|
||||||
|
|
||||||
|
s = max(window_min, values[i].window_min)
|
||||||
|
e = min(window_max, values[i].window_max)
|
||||||
|
w = e - s
|
||||||
|
|
||||||
|
value += values[i].value * w
|
||||||
|
value_weight += w
|
||||||
|
|
||||||
|
return value / value_weight if value_weight != 0 else 0
|
||||||
|
|
||||||
|
img_out = img.copy()
|
||||||
|
|
||||||
|
# Apply the kernel operation over each pixel.
|
||||||
|
for index in np.ndindex(img.shape):
|
||||||
|
img_out[index] = weighted_histogram_filter_single(index)
|
||||||
|
|
||||||
|
return img_out
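
# A usage sketch (illustrative only): with percentile_min == percentile_max
# == 0.5 and a flat kernel, the filter above acts as a weighted median and
# removes isolated outliers such as salt-and-pepper noise.
def _demo_weighted_median():
    img = np.zeros((5, 5))
    img[2, 2] = 1.0                     # a single salt-noise pixel
    kernel = np.ones((3, 3))            # uniform weights
    kernel_center = np.array([1, 1])    # center of the 3x3 kernel
    filtered = weighted_histogram_filter(img, kernel, kernel_center,
                                         percentile_min=0.5,
                                         percentile_max=0.5,
                                         min_width=1.0)
    return filtered                     # the spike at (2, 2) is filtered out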


def smoothstep(x):
    """
    The smoothstep function, input should be clamped to 0-1 range.
    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
    """
    return x * x * (3 - 2 * x)


def smootherstep(x):
    """
    The smootherstep function, input should be clamped to 0-1 range.
    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
    """
    return x * x * x * (x * (6 * x - 15) + 10)
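
# Quick numeric check (illustrative only): both curves fix 0 and 1 and cross
# 0.5 at x == 0.5; smootherstep also has zero second derivative at the
# endpoints, e.g. smoothstep(0.25) ~= 0.156 vs smootherstep(0.25) ~= 0.104.
def _demo_step_curves():
    return [(t, smoothstep(t), smootherstep(t)) for t in (0.0, 0.25, 0.5, 1.0)]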


def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
    """
    Creates a Gaussian kernel with thresholded edges.

    Args:
        stddev_radius (float):
            Standard deviation of the gaussian kernel, in pixels.
        max_radius (int):
            The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
            The kernel is thresholded so that any values one pixel beyond this radius
            are weighted at 0.

    Returns:
        (nparray, int): A kernel array (shape: (N, N)) and its center index along each axis.
    """

    # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
    def gaussian(sqr_mag):
        return math.exp(-sqr_mag / (stddev_radius * stddev_radius))

    # Helper function for converting a tuple to an array.
    def vec(x):
        return np.array(x)

    # Since a gaussian is unbounded, we need to limit ourselves
    # to a finite range.
    # We taper the ends off at the end of that range so they equal zero
    # while preserving the maximum value of 1 at the mean.
    zero_radius = max_radius + 1.0
    gauss_zero = gaussian(zero_radius * zero_radius)
    gauss_kernel_scale = 1 / (1 - gauss_zero)

    def gaussian_kernel_func(coordinate):
        x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
        x = gaussian(x)
        x -= gauss_zero
        x *= gauss_kernel_scale
        x = max(0.0, x)
        return x

    size = max_radius * 2 + 1
    kernel_center = max_radius
    kernel = np.zeros((size, size))

    for index in np.ndindex(kernel.shape):
        kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)

    return kernel, kernel_center
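
# A usage sketch (illustrative only): the kernel peaks at 1.0 in the center
# and tapers toward 0 at the thresholded edge; paired with
# weighted_histogram_filter it gives, e.g., a gaussian-weighted median.
def _demo_gaussian_kernel():
    kernel, center = get_gaussian_kernel(stddev_radius=1.0, max_radius=2)
    assert kernel.shape == (5, 5)
    assert abs(kernel[center, center] - 1.0) < 1e-9   # ~1.0 at the mean
    assert kernel[0, 0] < 0.001                       # near zero at the corner
    return kernel, center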


# ------------------- Constants -------------------


default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)

enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
enabled_el_id = "soft_inpainting_enabled"

ui_labels = SoftInpaintingSettings(
    "Schedule bias",
    "Preservation strength",
    "Transition contrast boost",
    "Mask influence",
    "Difference threshold",
    "Difference contrast")

ui_info = SoftInpaintingSettings(
    "Shifts when preservation of original content occurs during denoising.",
    "How strongly partially masked content should be preserved.",
    "Amplifies the contrast that may be lost in partially masked regions.",
    "How strongly the original mask should bias the difference threshold.",
    "How much an image region can change before the original pixels are not blended in anymore.",
    "How sharp the transition should be between blended and not blended.")

gen_param_labels = SoftInpaintingSettings(
    "Soft inpainting schedule bias",
    "Soft inpainting preservation strength",
    "Soft inpainting transition contrast boost",
    "Soft inpainting mask influence",
    "Soft inpainting difference threshold",
    "Soft inpainting difference contrast")

el_ids = SoftInpaintingSettings(
    "mask_blend_power",
    "mask_blend_scale",
    "inpaint_detail_preservation",
    "composite_mask_influence",
    "composite_difference_threshold",
    "composite_difference_contrast")


# ------------------- Script -------------------


class Script(scripts.Script):
    def __init__(self):
        self.section = "inpaint"
        self.masks_for_overlay = None
        self.overlay_images = None

    def title(self):
        return "Soft Inpainting"

    def show(self, is_img2img):
        return scripts.AlwaysVisible if is_img2img else False

    def ui(self, is_img2img):
        if not is_img2img:
            return

        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
            with gr.Group():
                gr.Markdown(
                    """
                    Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
                    **High _Mask blur_** values are recommended!
                    """)

                power = \
                    gr.Slider(label=ui_labels.mask_blend_power,
                              info=ui_info.mask_blend_power,
                              minimum=0,
                              maximum=8,
                              step=0.1,
                              value=default.mask_blend_power,
                              elem_id=el_ids.mask_blend_power)
                scale = \
                    gr.Slider(label=ui_labels.mask_blend_scale,
                              info=ui_info.mask_blend_scale,
                              minimum=0,
                              maximum=8,
                              step=0.05,
                              value=default.mask_blend_scale,
                              elem_id=el_ids.mask_blend_scale)
                detail = \
                    gr.Slider(label=ui_labels.inpaint_detail_preservation,
                              info=ui_info.inpaint_detail_preservation,
                              minimum=1,
                              maximum=32,
                              step=0.5,
                              value=default.inpaint_detail_preservation,
                              elem_id=el_ids.inpaint_detail_preservation)

                gr.Markdown(
                    """
                    ### Pixel Composite Settings
                    """)

                mask_inf = \
                    gr.Slider(label=ui_labels.composite_mask_influence,
                              info=ui_info.composite_mask_influence,
                              minimum=0,
                              maximum=1,
                              step=0.05,
                              value=default.composite_mask_influence,
                              elem_id=el_ids.composite_mask_influence)

                dif_thresh = \
                    gr.Slider(label=ui_labels.composite_difference_threshold,
                              info=ui_info.composite_difference_threshold,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_threshold,
                              elem_id=el_ids.composite_difference_threshold)

                dif_contr = \
                    gr.Slider(label=ui_labels.composite_difference_contrast,
                              info=ui_info.composite_difference_contrast,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_contrast,
                              elem_id=el_ids.composite_difference_contrast)

            with gr.Accordion("Help", open=False):
                gr.Markdown(
                    f"""
                    ### {ui_labels.mask_blend_power}

                    The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
                    This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
                    This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.

                    - **Below 1**: Stronger preservation near the end (with low sigma)
                    - **1**: Balanced (proportional to sigma)
                    - **Above 1**: Stronger preservation in the beginning (with high sigma)
                    """)
                gr.Markdown(
                    f"""
                    ### {ui_labels.mask_blend_scale}

                    Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
                    This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.

                    - **Low values**: Favors generated content.
                    - **High values**: Favors original content.
                    """)
                gr.Markdown(
                    f"""
                    ### {ui_labels.inpaint_detail_preservation}

                    This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
                    With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
                    This can prevent the loss of contrast that occurs with linear interpolation.

                    - **Low values**: Softer blending, details may fade.
                    - **High values**: Stronger contrast, may over-saturate colors.
                    """)

                gr.Markdown(
                    """
                    ## Pixel Composite Settings

                    Masks are generated based on how much a part of the image changed after denoising.
                    These masks are used to blend the original and final images together.
                    If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
                    """)

                gr.Markdown(
                    f"""
                    ### {ui_labels.composite_mask_influence}

                    This parameter controls how much the mask should bias this sensitivity to difference.

                    - **0**: Ignore the mask, only consider differences in image content.
                    - **1**: Follow the mask closely despite image content changes.
                    """)

                gr.Markdown(
                    f"""
                    ### {ui_labels.composite_difference_threshold}

                    This value represents the difference at which the original pixels will have less than 50% opacity.

                    - **Low values**: Two image patches must be almost the same to retain the original pixels.
                    - **High values**: Two image patches can be very different and still retain the original pixels.
                    """)

                gr.Markdown(
                    f"""
                    ### {ui_labels.composite_difference_contrast}

                    This value represents the contrast between the opacity of the original and inpainted content.

                    - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
                    - **High values**: Ghosting will be less common, but transitions may be very sudden.
                    """)

        self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
                                (power, gen_param_labels.mask_blend_power),
                                (scale, gen_param_labels.mask_blend_scale),
                                (detail, gen_param_labels.inpaint_detail_preservation),
                                (mask_inf, gen_param_labels.composite_mask_influence),
                                (dif_thresh, gen_param_labels.composite_difference_threshold),
                                (dif_contr, gen_param_labels.composite_difference_contrast)]

        self.paste_field_names = []
        for _, field_name in self.infotext_fields:
            self.paste_field_names.append(field_name)

        return [soft_inpainting_enabled,
                power,
                scale,
                detail,
                mask_inf,
                dif_thresh,
                dif_contr]

    def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        # Shut off the rounding it normally does.
        p.mask_round = False

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # p.extra_generation_params["Mask rounding"] = False
        settings.add_generation_params(p.extra_generation_params)

    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
                      dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        if mba.is_final_blend:
            mba.blended_latent = mba.current_latent
            return

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # todo: Why is sigma 2D? Both values are the same.
        mba.blended_latent = latent_blend(settings,
                                          mba.init_latent,
                                          mba.current_latent,
                                          get_modified_nmask(settings, mba.nmask, mba.sigma[0]))

    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
                    dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        nmask = getattr(p, "nmask", None)
        if nmask is None:
            return

        from modules import images
        from modules.shared import opts

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # Since the original code puts holes in the existing overlay images,
        # we have to rebuild them.
        self.overlay_images = []
        for img in p.init_images:
            image = images.flatten(img, opts.img2img_background_color)

            if p.paste_to is None and p.resize_mode != 3:
                image = images.resize_image(p.resize_mode, image, p.width, p.height)

            self.overlay_images.append(image.convert('RGBA'))

        if len(p.init_images) == 1:
            self.overlay_images = self.overlay_images * p.batch_size

        if getattr(ps.samples, 'already_decoded', False):
            self.masks_for_overlay = apply_masks(settings=settings,
                                                 nmask=nmask,
                                                 overlay_images=self.overlay_images,
                                                 width=p.width,
                                                 height=p.height,
                                                 paste_to=p.paste_to)
        else:
            self.masks_for_overlay = apply_adaptive_masks(settings=settings,
                                                          nmask=nmask,
                                                          latent_orig=p.init_latent,
                                                          latent_processed=ps.samples,
                                                          overlay_images=self.overlay_images,
                                                          width=p.width,
                                                          height=p.height,
                                                          paste_to=p.paste_to)

    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
                                detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        if self.masks_for_overlay is None:
            return

        if self.overlay_images is None:
            return

        ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
        ppmo.overlay_image = self.overlay_images[ppmo.index]
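
# Conceptual sketch of the schedule bias described in the help text above.
# This is a hypothetical form for illustration only; get_modified_nmask,
# defined earlier in this module, is the authoritative implementation. The
# idea: the per-pixel keep-weight is raised to a sigma-dependent exponent,
# so mask_blend_power above 1 strengthens preservation at high sigma (early
# steps) and below 1 strengthens it at low sigma (late steps).
def _demo_schedule_bias(nmask, sigma, power, scale):
    return nmask ** (sigma ** power * scale)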

webui.sh

@@ -133,7 +133,7 @@ case "$gpu_info" in
     if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]]
     then
         # Navi users will still use torch 1.13 because 2.0 does not seem to work.
-        export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
+        export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
     else
         printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
         exit 1
@@ -143,8 +143,7 @@ case "$gpu_info" in
     *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
     ;;
     *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
-        export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
-        # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
+        export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
     ;;
     *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
         printf "\n%s\n" "${delimiter}"