diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index a29f38550..e6d9fa4f4 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -69,10 +69,14 @@ def setup_model(dirname): self.net = net self.face_helper = face_helper - self.net.to(devices.device_codeformer) return net, face_helper + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + def restore(self, np_image, w=None): np_image = np_image[:, :, ::-1] @@ -82,6 +86,8 @@ def setup_model(dirname): if self.net is None or self.face_helper is None: return np_image + self.send_model_to(devices.device_codeformer) + self.face_helper.clean_all() self.face_helper.read_image(np_image) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) @@ -113,8 +119,10 @@ def setup_model(dirname): if original_resolution != restored_img.shape[0:2]: restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) + self.face_helper.clean_all() + if shared.opts.face_restoration_unload: - self.net.to(devices.cpu) + self.send_model_to(devices.cpu) return restored_img diff --git a/modules/devices.py b/modules/devices.py index b78996322..0158b11fc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,3 +1,5 @@ +import contextlib + import torch from modules import errors @@ -56,3 +58,11 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) + +def autocast(): + from modules import shared + + if dtype == torch.float32 or shared.cmd_opts.precision == "full": + return contextlib.nullcontext() + + return torch.autocast("cuda") diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index f1a564b73..a9452dce5 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -36,23 +36,33 @@ def gfpgann(): else: print("Unable to load gfpgan model!") return None - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) - model.gfpgan.to(devices.device_gfpgan) + model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) loaded_gfpgan_model = model return model +def send_model_to(model, device): + model.gfpgan.to(device) + model.face_helper.face_det.to(device) + model.face_helper.face_parse.to(device) + + def gfpgan_fix_faces(np_image): model = gfpgann() if model is None: return np_image + + send_model_to(model, devices.device_gfpgan) + np_image_bgr = np_image[:, :, ::-1] cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] + model.face_helper.clean_all() + if shared.opts.face_restoration_unload: - model.gfpgan.to(devices.cpu) + send_model_to(model, devices.cpu) return np_image diff --git a/modules/processing.py b/modules/processing.py index 0a4b6198f..6f5599c7d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,3 @@ -import contextlib import json import math import os @@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext - ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + + with torch.no_grad(): p.init(all_prompts, all_seeds, all_subseeds) if state.job_count == -1: @@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) - uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(prompts, p.steps) + with devices.autocast(): + uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -361,13 +360,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + with devices.autocast(): + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + if state.interrupted: # if we are interruped, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent + samples_ddim = samples_ddim.to(devices.dtype) + x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -386,6 +389,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() image = Image.fromarray(x_sample) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index e811eb9ec..99c8ed99c 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,20 +1,11 @@ import re from collections import namedtuple import torch +from lark import Lark, Transformer, Visitor +import functools import modules.shared as shared -re_prompt = re.compile(r''' -(.*?) -\[ - ([^]:]+): - (?:([^]:]*):)? - ([0-9]*\.?[0-9]+) -] -| -(.+) -''', re.X) - # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -25,61 +16,57 @@ re_prompt = re.compile(r''' def get_learned_conditioning_prompt_schedules(prompts, steps): - res = [] - cache = {} - - for prompt in prompts: - prompt_schedule: list[list[str | int]] = [[steps, ""]] - - cached = cache.get(prompt, None) - if cached is not None: - res.append(cached) - continue - - for m in re_prompt.finditer(prompt): - plaintext = m.group(1) if m.group(5) is None else m.group(5) - concept_from = m.group(2) - concept_to = m.group(3) - if concept_to is None: - concept_to = concept_from - concept_from = "" - swap_position = float(m.group(4)) if m.group(4) is not None else None - - if swap_position is not None: - if swap_position < 1: - swap_position = swap_position * steps - swap_position = int(min(swap_position, steps)) - - swap_index = None - found_exact_index = False - for i in range(len(prompt_schedule)): - end_step = prompt_schedule[i][0] - prompt_schedule[i][1] += plaintext - - if swap_position is not None and swap_index is None: - if swap_position == end_step: - swap_index = i - found_exact_index = True - - if swap_position < end_step: - swap_index = i - - if swap_index is not None: - if not found_exact_index: - prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]]) - - for i in range(len(prompt_schedule)): - end_step = prompt_schedule[i][0] - must_replace = swap_position < end_step - - prompt_schedule[i][1] += concept_to if must_replace else concept_from - - res.append(prompt_schedule) - cache[prompt] = prompt_schedule - #for t in prompt_schedule: - # print(t) - - return res + grammar = r""" + start: prompt + prompt: (emphasized | scheduled | weighted | plain)* + !emphasized: "(" prompt ")" + | "(" prompt ":" prompt ")" + | "[" prompt "]" + scheduled: "[" (prompt ":")? prompt ":" NUMBER "]" + !weighted: "{" weighted_item ("|" weighted_item)* "}" + !weighted_item: prompt (":" prompt)? + plain: /([^\\\[\](){}:|]|\\.)+/ + %import common.SIGNED_NUMBER -> NUMBER + """ + parser = Lark(grammar, parser='lalr') + def collect_steps(steps, tree): + l = [steps] + class CollectSteps(Visitor): + def scheduled(self, tree): + tree.children[-1] = float(tree.children[-1]) + if tree.children[-1] < 1: + tree.children[-1] *= steps + tree.children[-1] = min(steps, int(tree.children[-1])) + l.append(tree.children[-1]) + CollectSteps().visit(tree) + return sorted(set(l)) + def at_step(step, tree): + class AtStep(Transformer): + def scheduled(self, args): + if len(args) == 2: + before, after, when = (), *args + else: + before, after, when = args + yield before if step <= when else after + def start(self, args): + def flatten(x): + if type(x) == str: + yield x + else: + for gen in x: + yield from flatten(gen) + return ''.join(flatten(args[0])) + def plain(self, args): + yield args[0].value + def __default__(self, data, children, meta): + for child in children: + yield from child + return AtStep().transform(tree) + @functools.cache + def get_schedule(prompt): + tree = parser.parse(prompt) + return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] + return [get_schedule(prompt) for prompt in prompts] ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) diff --git a/modules/ui.py b/modules/ui.py index 55f7aa953..20dc8c379 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -386,14 +386,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) + def update_token_counter(text, steps): - prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + try: + prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step,prompt_text in flat_prompts] + prompts = [prompt_text for step, prompt_text in flat_prompts] tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" + def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" diff --git a/requirements.txt b/requirements.txt index d4b337fce..631fe616a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ clean-fid resize-right torchdiffeq kornia +lark diff --git a/requirements_versions.txt b/requirements_versions.txt index 8a9acf205..fdff26878 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -21,3 +21,4 @@ clean-fid==0.1.29 resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 +lark==1.1.2