diff --git a/modules/processing.py b/modules/processing.py index 8180c63d8..bb94033b1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -84,7 +84,7 @@ class StableDiffusionProcessing: self.s_tmin = opts.s_tmin self.s_tmax = float('inf') # not representable as a standard ui option self.s_noise = opts.s_noise - + if not seed_enable_extras: self.subseed = -1 self.subseed_strength = 0 @@ -296,7 +296,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: assert(len(p.prompt) > 0) else: assert p.prompt is not None - + devices.torch_gc() seed = get_fixed_seed(p.seed) @@ -359,8 +359,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(prompts, p.steps) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -527,7 +527,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) return samples diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 5d58c4ed9..a3b124219 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,10 +1,7 @@ import re from collections import namedtuple -import torch -from lark import Lark, Transformer, Visitor -import functools -import modules.shared as shared +import lark # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): @@ -14,25 +11,48 @@ import modules.shared as shared # [75, 'fantasy landscape with a lake and an oak in background masterful'] # [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] +schedule_parser = lark.Lark(r""" +!start: (prompt | /[][():]/+)* +prompt: (emphasized | scheduled | plain | WHITESPACE)* +!emphasized: "(" prompt ")" + | "(" prompt ":" prompt ")" + | "[" prompt "]" +scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]" +WHITESPACE: /\s+/ +plain: /([^\\\[\]():]|\\.)+/ +%import common.SIGNED_NUMBER -> NUMBER +""") def get_learned_conditioning_prompt_schedules(prompts, steps): - grammar = r""" - start: prompt - prompt: (emphasized | scheduled | weighted | plain)* - !emphasized: "(" prompt ")" - | "(" prompt ":" prompt ")" - | "[" prompt "]" - scheduled: "[" (prompt ":")? prompt ":" NUMBER "]" - !weighted: "{" weighted_item ("|" weighted_item)* "}" - !weighted_item: prompt (":" prompt)? - plain: /([^\\\[\](){}:|]|\\.)+/ - %import common.SIGNED_NUMBER -> NUMBER """ - parser = Lark(grammar, parser='lalr') + >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] + >>> g("test") + [[10, 'test']] + >>> g("a [b:3]") + [[3, 'a '], [10, 'a b']] + >>> g("a [b: 3]") + [[3, 'a '], [10, 'a b']] + >>> g("a [[[b]]:2]") + [[2, 'a '], [10, 'a [[b]]']] + >>> g("[(a:2):3]") + [[3, ''], [10, '(a:2)']] + >>> g("a [b : c : 1] d") + [[1, 'a b d'], [10, 'a c d']] + >>> g("a[b:[c:d:2]:1]e") + [[1, 'abe'], [2, 'ace'], [10, 'ade']] + >>> g("a [unbalanced") + [[10, 'a [unbalanced']] + >>> g("a [b:.5] c") + [[5, 'a c'], [10, 'a b c']] + >>> g("a [{b|d{:.5] c") # not handling this right now + [[5, 'a c'], [10, 'a {b|d{ c']] + >>> g("((a][:b:c [d:3]") + [[3, '((a][:b:c '], [10, '((a][:b:c d']] + """ def collect_steps(steps, tree): l = [steps] - class CollectSteps(Visitor): + class CollectSteps(lark.Visitor): def scheduled(self, tree): tree.children[-1] = float(tree.children[-1]) if tree.children[-1] < 1: @@ -43,13 +63,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): return sorted(set(l)) def at_step(step, tree): - class AtStep(Transformer): + class AtStep(lark.Transformer): def scheduled(self, args): - if len(args) == 2: - before, after, when = (), *args - else: - before, after, when = args - yield before if step <= when else after + before, after, _, when = args + yield before or () if step <= when else after def start(self, args): def flatten(x): if type(x) == str: @@ -57,16 +74,22 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): else: for gen in x: yield from flatten(gen) - return ''.join(flatten(args[0])) + return ''.join(flatten(args)) def plain(self, args): yield args[0].value def __default__(self, data, children, meta): for child in children: yield from child return AtStep().transform(tree) - + def get_schedule(prompt): - tree = parser.parse(prompt) + try: + tree = schedule_parser.parse(prompt) + except lark.exceptions.LarkError as e: + if 0: + import traceback + traceback.print_exc() + return [[steps, prompt]] return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)} @@ -77,8 +100,7 @@ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) -def get_learned_conditioning(prompts, steps): - +def get_learned_conditioning(model, prompts, steps): res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) @@ -92,7 +114,7 @@ def get_learned_conditioning(prompts, steps): continue texts = [x[1] for x in prompt_schedule] - conds = shared.sd_model.get_learned_conditioning(texts) + conds = model.get_learned_conditioning(texts) cond_schedule = [] for i, (end_at_step, text) in enumerate(prompt_schedule): @@ -105,12 +127,13 @@ def get_learned_conditioning(prompts, steps): def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): - res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype) + param = c.schedules[0][0].cond + res = torch.zeros(c.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c.schedules): target_index = 0 - for curret_index, (end_at, cond) in enumerate(cond_schedule): + for current, (end_at, cond) in enumerate(cond_schedule): if current_step <= end_at: - target_index = curret_index + target_index = current break res[i] = cond_schedule[target_index].cond @@ -148,23 +171,26 @@ def parse_prompt_attention(text): \\ - literal character '\' anything else - just text - Example: - - 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' - - produces: - - [ - ['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1] - ] + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] """ res = [] @@ -206,4 +232,19 @@ def parse_prompt_attention(text): if len(res) == 0: res = [["", 1.0]] + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + return res + +if __name__ == "__main__": + import doctest + doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) +else: + import torch # doctest faster