add infotext entry for emphasis; put emphasis into a separate file, add an option to parse but still ignore emphasis
This commit is contained in:
parent
3732cf2f97
commit
e2b19900ec
|
@ -356,6 +356,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||
res["Cache FP16 weight for LoRA"] = False
|
||||
|
||||
if "Emphasis" not in res:
|
||||
res["Emphasis"] = "Original"
|
||||
|
||||
infotext_versions.backcompat(res)
|
||||
|
||||
for key in skip_fields:
|
||||
|
|
|
@ -455,6 +455,7 @@ class StableDiffusionProcessing:
|
|||
self.height,
|
||||
opts.fp8_storage,
|
||||
opts.cache_fp16_weight,
|
||||
opts.emphasis,
|
||||
)
|
||||
|
||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
class Emphasis:
|
||||
"""Emphasis class decides how to death with (emphasized:1.1) text in prompts"""
|
||||
|
||||
name: str = "Base"
|
||||
description: str = ""
|
||||
|
||||
tokens: list[list[int]]
|
||||
"""tokens from the chunk of the prompt"""
|
||||
|
||||
multipliers: torch.Tensor
|
||||
"""tensor with multipliers, once for each token"""
|
||||
|
||||
z: torch.Tensor
|
||||
"""output of cond transformers network (CLIP)"""
|
||||
|
||||
def after_transformers(self):
|
||||
"""Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EmphasisNone(Emphasis):
|
||||
name = "None"
|
||||
description = "disable the mechanism entirely and treat (:.1.1) as literal characters"
|
||||
|
||||
|
||||
class EmphasisIgnore(Emphasis):
|
||||
name = "Ignore"
|
||||
description = "treat all empasised words as if they have no emphasis"
|
||||
|
||||
|
||||
class EmphasisOriginal(Emphasis):
|
||||
name = "Original"
|
||||
description = "the orginal emphasis implementation"
|
||||
|
||||
def after_transformers(self):
|
||||
original_mean = self.z.mean()
|
||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
new_mean = self.z.mean()
|
||||
self.z = self.z * (original_mean / new_mean)
|
||||
|
||||
|
||||
class EmphasisOriginalNoNorm(EmphasisOriginal):
|
||||
name = "No norm"
|
||||
description = "same as orginal, but without normalization (seems to work better for SDXL)"
|
||||
|
||||
def after_transformers(self):
|
||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||
|
||||
|
||||
def get_current_option(emphasis_option_name):
|
||||
return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)
|
||||
|
||||
|
||||
def get_options_descriptions():
|
||||
return ", ".join(f"{x.name}: {x.description}" for x in options)
|
||||
|
||||
|
||||
options = [
|
||||
EmphasisNone,
|
||||
EmphasisIgnore,
|
||||
EmphasisOriginal,
|
||||
EmphasisOriginalNoNorm,
|
||||
]
|
|
@ -3,7 +3,7 @@ from collections import namedtuple
|
|||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices, sd_hijack
|
||||
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
|
@ -88,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
Returns the list and the total number of tokens in the prompt.
|
||||
"""
|
||||
|
||||
if opts.enable_emphasis:
|
||||
if opts.emphasis != "None":
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
@ -249,6 +249,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
||||
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
return torch.hstack(zs), zs[0].pooled
|
||||
else:
|
||||
|
@ -274,14 +277,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
|
||||
pooled = getattr(z, 'pooled', None)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
emphasis = sd_emphasis.get_current_option(opts.emphasis)()
|
||||
emphasis.tokens = remade_batch_tokens
|
||||
emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
emphasis.z = z
|
||||
|
||||
if not getattr(opts, "disable_normalize_embeddings", False):
|
||||
z = z * (original_mean / new_mean)
|
||||
emphasis.after_transformers()
|
||||
|
||||
z = emphasis.z
|
||||
|
||||
if pooled is not None:
|
||||
z.pooled = pooled
|
||||
|
|
|
@ -32,7 +32,7 @@ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase,
|
|||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
|
||||
mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util, sd_emphasis
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.options import options_section, OptionInfo, OptionHTML, categories
|
||||
|
@ -154,8 +154,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
|||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"disable_normalize_embeddings": OptionInfo(False, "Disable normalize embeddings").info("Do not normalize embeddings after calculating emphasis. It can be expected to be effective in preventing artifacts in SDXL."),
|
||||
"emphasis": OptionInfo("Original", "Emphasis mode", gr.Radio, lambda: {"choices": [x.name for x in sd_emphasis.options]}, infotext="Emphasis").info("makes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt; " + sd_emphasis.get_options_descriptions()),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||
|
|
Loading…
Reference in New Issue