From d51847c184f62244a7b330f6115435e5d6155c84 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 14 Sep 2022 19:41:55 +0300 Subject: [PATCH] fix caching for img2imgalt --- scripts/img2imgalt.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 8ff4c2103..3153afa7a 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -1,3 +1,5 @@ +from collections import namedtuple + import numpy as np from tqdm import trange @@ -56,9 +58,14 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps): return x / x.std() -cache = [None, None, None, None, None] + +Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt"]) + class Script(scripts.Script): + def __init__(self): + self.cache = None + def title(self): return "img2img alternative test" @@ -67,7 +74,7 @@ class Script(scripts.Script): def ui(self, is_img2img): original_prompt = gr.Textbox(label="Original prompt", lines=1) - cfg = gr.Slider(label="Decode CFG scale", minimum=0.1, maximum=3.0, step=0.1, value=1.0) + cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0) st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) return [original_prompt, cfg, st] @@ -77,19 +84,18 @@ class Script(scripts.Script): p.batch_count = 1 def sample_extra(x, conditioning, unconditional_conditioning): - lat = tuple([int(x*10) for x in p.init_latent.cpu().numpy().flatten().tolist()]) + lat = (p.init_latent.cpu().numpy() * 10).astype(int) - if cache[0] is not None and cache[1] == cfg and cache[2] == st and len(cache[3]) == len(lat) and sum(np.array(cache[3])-np.array(lat)) < 100 and cache[4] == original_prompt: - noise = cache[0] + same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st and self.cache.original_prompt == original_prompt + same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 + + if same_everything: + noise = self.cache.noise else: shared.state.job_count += 1 cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) noise = find_noise_for_image(p, cond, unconditional_conditioning, cfg, st) - cache[0] = noise - cache[1] = cfg - cache[2] = st - cache[3] = lat - cache[4] = original_prompt + self.cache = Cached(noise, cfg, st, lat, original_prompt) sampler = samplers[p.sampler_index].constructor(p.sd_model)