prompt editing
This commit is contained in:
parent
b28cf84c36
commit
f2693bec08
|
@ -12,7 +12,7 @@ import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices
|
from modules import devices, prompt_parser
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
@ -247,8 +247,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
||||||
c = p.sd_model.get_learned_conditioning(prompts)
|
#c = p.sd_model.get_learned_conditioning(prompts)
|
||||||
|
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
|
||||||
|
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
import re
|
||||||
|
from collections import namedtuple
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
re_prompt = re.compile(r'''
|
||||||
|
(.*?)
|
||||||
|
\[
|
||||||
|
([^]:]+):
|
||||||
|
(?:([^]:]*):)?
|
||||||
|
([0-9]*\.?[0-9]+)
|
||||||
|
]
|
||||||
|
|
|
||||||
|
(.+)
|
||||||
|
''', re.X)
|
||||||
|
|
||||||
|
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
||||||
|
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||||
|
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
||||||
|
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
||||||
|
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
|
||||||
|
# [75, 'fantasy landscape with a lake and an oak in background masterful']
|
||||||
|
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
|
||||||
|
|
||||||
|
|
||||||
|
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
|
res = []
|
||||||
|
cache = {}
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
prompt_schedule: list[list[str | int]] = [[steps, ""]]
|
||||||
|
|
||||||
|
cached = cache.get(prompt, None)
|
||||||
|
if cached is not None:
|
||||||
|
res.append(cached)
|
||||||
|
|
||||||
|
for m in re_prompt.finditer(prompt):
|
||||||
|
plaintext = m.group(1) if m.group(5) is None else m.group(5)
|
||||||
|
concept_from = m.group(2)
|
||||||
|
concept_to = m.group(3)
|
||||||
|
if concept_to is None:
|
||||||
|
concept_to = concept_from
|
||||||
|
concept_from = ""
|
||||||
|
swap_position = float(m.group(4)) if m.group(4) is not None else None
|
||||||
|
|
||||||
|
if swap_position is not None:
|
||||||
|
if swap_position < 1:
|
||||||
|
swap_position = swap_position * steps
|
||||||
|
swap_position = int(min(swap_position, steps))
|
||||||
|
|
||||||
|
swap_index = None
|
||||||
|
found_exact_index = False
|
||||||
|
for i in range(len(prompt_schedule)):
|
||||||
|
end_step = prompt_schedule[i][0]
|
||||||
|
prompt_schedule[i][1] += plaintext
|
||||||
|
|
||||||
|
if swap_position is not None and swap_index is None:
|
||||||
|
if swap_position == end_step:
|
||||||
|
swap_index = i
|
||||||
|
found_exact_index = True
|
||||||
|
|
||||||
|
if swap_position < end_step:
|
||||||
|
swap_index = i
|
||||||
|
|
||||||
|
if swap_index is not None:
|
||||||
|
if not found_exact_index:
|
||||||
|
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
|
||||||
|
|
||||||
|
for i in range(len(prompt_schedule)):
|
||||||
|
end_step = prompt_schedule[i][0]
|
||||||
|
must_replace = swap_position < end_step
|
||||||
|
|
||||||
|
prompt_schedule[i][1] += concept_to if must_replace else concept_from
|
||||||
|
|
||||||
|
res.append(prompt_schedule)
|
||||||
|
cache[prompt] = prompt_schedule
|
||||||
|
#for t in prompt_schedule:
|
||||||
|
# print(t)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||||
|
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_learned_conditioning(prompts, steps):
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||||
|
cache = {}
|
||||||
|
|
||||||
|
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
||||||
|
|
||||||
|
cached = cache.get(prompt, None)
|
||||||
|
if cached is not None:
|
||||||
|
res.append(cached)
|
||||||
|
|
||||||
|
texts = [x[1] for x in prompt_schedule]
|
||||||
|
conds = shared.sd_model.get_learned_conditioning(texts)
|
||||||
|
|
||||||
|
cond_schedule = []
|
||||||
|
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||||
|
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||||
|
|
||||||
|
cache[prompt] = cond_schedule
|
||||||
|
res.append(cond_schedule)
|
||||||
|
|
||||||
|
return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
||||||
|
res = torch.zeros(c.shape)
|
||||||
|
for i, cond_schedule in enumerate(c.schedules):
|
||||||
|
target_index = 0
|
||||||
|
for curret_index, (end_at, cond) in enumerate(cond_schedule):
|
||||||
|
if current_step <= end_at:
|
||||||
|
target_index = curret_index
|
||||||
|
break
|
||||||
|
res[i] = cond_schedule[target_index].cond
|
||||||
|
|
||||||
|
return res.to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
|
|
@ -7,6 +7,7 @@ from PIL import Image
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
|
from modules import prompt_parser
|
||||||
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -53,20 +54,6 @@ def store_latent(decoded):
|
||||||
shared.state.current_image = sample_to_image(decoded)
|
shared.state.current_image = sample_to_image(decoded)
|
||||||
|
|
||||||
|
|
||||||
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
|
|
||||||
if sampler_wrapper.mask is not None:
|
|
||||||
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
|
|
||||||
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
|
|
||||||
|
|
||||||
res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
|
|
||||||
|
|
||||||
if sampler_wrapper.mask is not None:
|
|
||||||
store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
|
|
||||||
else:
|
|
||||||
store_latent(res[1])
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
||||||
state.sampling_steps = len(sequence)
|
state.sampling_steps = len(sequence)
|
||||||
|
@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler:
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
|
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
||||||
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
|
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||||
|
|
||||||
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
store_latent(self.init_latent * self.mask + self.nmask * res[1])
|
||||||
|
else:
|
||||||
|
store_latent(res[1])
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
return res
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
|
||||||
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
|
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
|
||||||
|
@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler:
|
||||||
|
|
||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||||
|
|
||||||
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
|
self.sampler.p_sample_ddim = self.p_sample_ddim_hook
|
||||||
self.mask = p.mask
|
self.mask = p.mask
|
||||||
self.nmask = p.nmask
|
self.nmask = p.nmask
|
||||||
self.init_latent = p.init_latent
|
self.init_latent = p.init_latent
|
||||||
|
@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler:
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning):
|
def sample(self, p, x, conditioning, unconditional_conditioning):
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||||
if hasattr(self.sampler, fieldname):
|
if hasattr(self.sampler, fieldname):
|
||||||
setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
|
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.step = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
||||||
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
if shared.batch_cond_uncond:
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
|
@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue