Merge branch 'batch-seed-attempt'
This commit is contained in:
commit
83bce1a604
|
@ -48,3 +48,13 @@ def randn(seed, shape):
|
|||
torch.manual_seed(seed)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
|
|
@ -119,8 +119,14 @@ def slerp(val, low, high):
|
|||
return res
|
||||
|
||||
|
||||
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
|
||||
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
||||
xs = []
|
||||
|
||||
if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
|
||||
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
||||
else:
|
||||
sampler_noises = None
|
||||
|
||||
for i, seed in enumerate(seeds):
|
||||
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
||||
|
||||
|
@ -155,9 +161,17 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
||||
noise = x
|
||||
|
||||
if sampler_noises is not None:
|
||||
cnt = p.sampler.number_of_needed_noises(p)
|
||||
|
||||
for j in range(cnt):
|
||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
||||
|
||||
xs.append(noise)
|
||||
|
||||
if sampler_noises is not None:
|
||||
p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
|
||||
|
||||
x = torch.stack(xs).to(shared.device)
|
||||
return x
|
||||
|
||||
|
@ -257,7 +271,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
comments[comment] = 1
|
||||
|
||||
# we manually generate all input noises because each one should have a specific seed
|
||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
|
||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
|
||||
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
|
|
@ -80,8 +80,12 @@ class VanillaStableDiffusionSampler:
|
|||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
@ -185,16 +189,46 @@ def extended_trange(count, *args, **kwargs):
|
|||
shared.total_tqdm.update()
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, kdiff_sampler):
|
||||
self.kdiff_sampler = kdiff_sampler
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.kdiff_sampler.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.sampler_noise_index = 0
|
||||
|
||||
def callback_state(self, d):
|
||||
store_latent(d["denoised"])
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def randn_like(self, x):
|
||||
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
|
||||
|
||||
if noise is not None and x.shape == noise.shape:
|
||||
res = noise
|
||||
else:
|
||||
res = torch.randn_like(x)
|
||||
|
||||
self.sampler_noise_index += 1
|
||||
return res
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
|
||||
sigmas = self.model_wrap.get_sigmas(p.steps)
|
||||
|
@ -213,6 +247,9 @@ class KDiffusionSampler:
|
|||
if hasattr(k_diffusion.sampling, 'trange'):
|
||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
|
||||
|
||||
if self.sampler_noises is not None:
|
||||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
|
||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning):
|
||||
|
@ -224,6 +261,9 @@ class KDiffusionSampler:
|
|||
if hasattr(k_diffusion.sampling, 'trange'):
|
||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
|
||||
|
||||
if self.sampler_noises is not None:
|
||||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
|
||||
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
|
||||
return samples_ddim
|
||||
|
||||
|
|
|
@ -128,6 +128,7 @@ class Options:
|
|||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
|
|
Loading…
Reference in New Issue