rework RNG to use generators instead of generating noises beforehand

parent d81d3fa8cd
commit 0d5dc9a6e7
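
In outline: instead of create_random_tensors pre-computing every noise tensor a sampler might request (the old sampler_noises lists), each job now keeps one generator per seed and draws noise on demand. A minimal sketch of the pattern in plain PyTorch (illustrative only, not code from this commit):

import torch

class LazyImageNoise:
    # one generator per seed; noise is drawn lazily instead of precomputed
    def __init__(self, seeds, shape):
        self.shape = shape
        self.generators = [torch.Generator("cpu").manual_seed(s) for s in seeds]

    def next(self):
        # each call advances every per-seed generator by exactly one tensor, so a
        # batch with seeds [100, 101] matches two single-image runs [100], [101]
        return torch.stack([torch.randn(self.shape, generator=g) for g in self.generators])

noise = LazyImageNoise([100, 101], (4, 64, 64))
x0 = noise.next()  # initial latent noise, shape (2, 4, 64, 64)
x1 = noise.next()  # the next noise a sampler would request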
modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache
 
 import torch
-from modules import errors, rng_philox
+from modules import errors
 
 if sys.platform == "darwin":
     from modules import mac_specific
@@ -96,84 +96,6 @@ def cond_cast_float(input):
 nv_rng = None
 
 
-def randn(seed, shape):
-    """Generate a tensor with random numbers from a normal distribution using seed.
-
-    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
-
-    from modules.shared import opts
-
-    manual_seed(seed)
-
-    if opts.randn_source == "NV":
-        return torch.asarray(nv_rng.randn(shape), device=device)
-
-    if opts.randn_source == "CPU" or device.type == 'mps':
-        return torch.randn(shape, device=cpu).to(device)
-
-    return torch.randn(shape, device=device)
-
-
-def randn_local(seed, shape):
-    """Generate a tensor with random numbers from a normal distribution using seed.
-
-    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
-
-    from modules.shared import opts
-
-    if opts.randn_source == "NV":
-        rng = rng_philox.Generator(seed)
-        return torch.asarray(rng.randn(shape), device=device)
-
-    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
-    local_generator = torch.Generator(local_device).manual_seed(int(seed))
-    return torch.randn(shape, device=local_device, generator=local_generator).to(device)
-
-
-def randn_like(x):
-    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
-
-    Use either randn() or manual_seed() to initialize the generator."""
-
-    from modules.shared import opts
-
-    if opts.randn_source == "NV":
-        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
-
-    if opts.randn_source == "CPU" or x.device.type == 'mps':
-        return torch.randn_like(x, device=cpu).to(x.device)
-
-    return torch.randn_like(x)
-
-
-def randn_without_seed(shape):
-    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
-
-    Use either randn() or manual_seed() to initialize the generator."""
-
-    from modules.shared import opts
-
-    if opts.randn_source == "NV":
-        return torch.asarray(nv_rng.randn(shape), device=device)
-
-    if opts.randn_source == "CPU" or device.type == 'mps':
-        return torch.randn(shape, device=cpu).to(device)
-
-    return torch.randn(shape, device=device)
-
-
-def manual_seed(seed):
-    """Set up a global random number generator using the specified seed."""
-    from modules.shared import opts
-
-    if opts.randn_source == "NV":
-        global nv_rng
-        nv_rng = rng_philox.Generator(seed)
-        return
-
-    torch.manual_seed(seed)
-
-
 def autocast(disable=False):
     from modules import shared
 
@@ -236,3 +158,4 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)
+
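
Note: the helpers removed above are not deleted for good; they reappear in the new modules/rng.py further down, which re-binds them onto the devices module at import time (devices.randn = randn and so on), so existing call sites keep working. Assuming modules.rng has been imported, both spellings reach the same function:

from modules import devices, rng  # importing rng performs the re-binding

a = devices.randn(100, (4, 64, 64))  # legacy call site, still valid
b = rng.randn(100, (4, 64, 64))      # the function's new home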
modules/processing.py
@@ -14,7 +14,7 @@ from skimage import exposure
 from typing import Any, Dict, List
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
 from modules.shared import opts, cmd_opts, state
@@ -186,6 +186,7 @@ class StableDiffusionProcessing:
         self.cached_c = StableDiffusionProcessing.cached_c
         self.uc = None
         self.c = None
+        self.rng: rng.ImageRNG = None
 
         self.user = None
 
@@ -475,82 +476,9 @@ class Processed:
         return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
 
 
-# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-def slerp(val, low, high):
-    low_norm = low/torch.norm(low, dim=1, keepdim=True)
-    high_norm = high/torch.norm(high, dim=1, keepdim=True)
-    dot = (low_norm*high_norm).sum(1)
-
-    if dot.mean() > 0.9995:
-        return low * val + high * (1 - val)
-
-    omega = torch.acos(dot)
-    so = torch.sin(omega)
-    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
-    return res
-
-
 def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
-    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
-    xs = []
-
-    # if we have multiple seeds, this means we are working with batch size>1; this then
-    # enables the generation of additional tensors with noise that the sampler will use during its processing.
-    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
-    # produce the same images as with two batches [100], [101].
-    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
-        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
-    else:
-        sampler_noises = None
-
-    for i, seed in enumerate(seeds):
-        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
-
-        subnoise = None
-        if subseeds is not None and subseed_strength != 0:
-            subseed = 0 if i >= len(subseeds) else subseeds[i]
-
-            subnoise = devices.randn(subseed, noise_shape)
-
-        # randn results depend on device; gpu and cpu get different results for same seed;
-        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
-        # but the original script had it like this, so I do not dare change it for now because
-        # it will break everyone's seeds.
-        noise = devices.randn(seed, noise_shape)
-
-        if subnoise is not None:
-            noise = slerp(subseed_strength, noise, subnoise)
-
-        if noise_shape != shape:
-            x = devices.randn(seed, shape)
-            dx = (shape[2] - noise_shape[2]) // 2
-            dy = (shape[1] - noise_shape[1]) // 2
-            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
-            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
-            tx = 0 if dx < 0 else dx
-            ty = 0 if dy < 0 else dy
-            dx = max(-dx, 0)
-            dy = max(-dy, 0)
-
-            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
-            noise = x
-
-        if sampler_noises is not None:
-            cnt = p.sampler.number_of_needed_noises(p)
-
-            if eta_noise_seed_delta > 0:
-                devices.manual_seed(seed + eta_noise_seed_delta)
-
-            for j in range(cnt):
-                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
-
-        xs.append(noise)
-
-    if sampler_noises is not None:
-        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
-
-    x = torch.stack(xs).to(shared.device)
-    return x
+    g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
+    return g.next()
 
 
 class DecodedSamples(list):
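
create_random_tensors survives only as a compatibility shim for outside callers: it builds a throwaway ImageRNG and returns its first draw, and the p argument is now unused. A sketch of the equivalence (seed values arbitrary):

# these two are equivalent after this commit
x = create_random_tensors((4, 64, 64), seeds=[100, 101])

g = rng.ImageRNG((4, 64, 64), [100, 101])
x = g.next()  # first draw only; repeated sampler noise now comes from p.rng.next()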
@@ -769,6 +697,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
+            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+
            if p.scripts is not None:
                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
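
Here opt_C is the latent channel count and opt_f the VAE downscale factor (4 and 8 for Stable Diffusion 1.x models), so the per-batch ImageRNG covers one latent per seed. For example, assuming a 512×512 job with two seeds:

# illustrative values; subseed-related keyword arguments elided
p.rng = rng.ImageRNG((4, 512 // 8, 512 // 8), [100, 101])
latents = p.rng.next()  # shape (2, 4, 64, 64), one latent per seed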
@@ -1072,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        x = self.rng.next()
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
         del x
@@ -1160,7 +1090,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
 
-        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
+        self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
+        noise = self.rng.next()
 
         # GC now before running the next img2img to prevent running out of memory
         devices.torch_gc()
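
The hires pass allocates a fresh ImageRNG instead of reusing p.rng because the upscaled latent has a different shape; samples.shape[1:] drops the batch dimension. For instance, if the upscaled latent batch is (2, 4, 128, 128):

self.rng = rng.ImageRNG(samples.shape[1:], self.seeds)  # per-image shape (4, 128, 128); other kwargs elided
noise = self.rng.next()                                 # shape (2, 4, 128, 128)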
@@ -1418,7 +1349,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        x = self.rng.next()
 
         if self.initial_noise_multiplier != 1.0:
             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
modules/rng.py (new file)
@@ -0,0 +1,171 @@
+import torch
+
+from modules import devices, rng_philox, shared
+
+
+def randn(seed, shape, generator=None):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
+    manual_seed(seed)
+
+    if shared.opts.randn_source == "NV":
+        return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
+
+    if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
+        return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
+
+    return torch.randn(shape, device=devices.device, generator=generator)
+
+
+def randn_local(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+    if shared.opts.randn_source == "NV":
+        rng = rng_philox.Generator(seed)
+        return torch.asarray(rng.randn(shape), device=devices.device)
+
+    local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
+    local_generator = torch.Generator(local_device).manual_seed(int(seed))
+    return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
+
+
+def randn_like(x):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    if shared.opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+    if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
+        return torch.randn_like(x, device=devices.cpu).to(x.device)
+
+    return torch.randn_like(x)
+
+
+def randn_without_seed(shape, generator=None):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    if shared.opts.randn_source == "NV":
+        return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
+
+    if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
+        return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
+
+    return torch.randn(shape, device=devices.device, generator=generator)
+
+
+def manual_seed(seed):
+    """Set up a global random number generator using the specified seed."""
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        global nv_rng
+        nv_rng = rng_philox.Generator(seed)
+        return
+
+    torch.manual_seed(seed)
+
+
+def create_generator(seed):
+    if shared.opts.randn_source == "NV":
+        return rng_philox.Generator(seed)
+
+    device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
+    generator = torch.Generator(device).manual_seed(int(seed))
+    return generator
+
+
+# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
+def slerp(val, low, high):
+    low_norm = low/torch.norm(low, dim=1, keepdim=True)
+    high_norm = high/torch.norm(high, dim=1, keepdim=True)
+    dot = (low_norm*high_norm).sum(1)
+
+    if dot.mean() > 0.9995:
+        return low * val + high * (1 - val)
+
+    omega = torch.acos(dot)
+    so = torch.sin(omega)
+    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+    return res
+
+
+class ImageRNG:
+    def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
+        self.shape = shape
+        self.seeds = seeds
+        self.subseeds = subseeds
+        self.subseed_strength = subseed_strength
+        self.seed_resize_from_h = seed_resize_from_h
+        self.seed_resize_from_w = seed_resize_from_w
+
+        self.generators = [create_generator(seed) for seed in seeds]
+
+        self.is_first = True
+
+    def first(self):
+        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
+
+        xs = []
+
+        for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
+            subnoise = None
+            if self.subseeds is not None and self.subseed_strength != 0:
+                subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
+                subnoise = randn(subseed, noise_shape)
+
+            if noise_shape != self.shape:
+                noise = randn(seed, noise_shape)
+            else:
+                noise = randn(seed, self.shape, generator=generator)
+
+            if subnoise is not None:
+                noise = slerp(self.subseed_strength, noise, subnoise)
+
+            if noise_shape != self.shape:
+                x = randn(seed, self.shape, generator=generator)
+                dx = (self.shape[2] - noise_shape[2]) // 2
+                dy = (self.shape[1] - noise_shape[1]) // 2
+                w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
+                h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
+                tx = 0 if dx < 0 else dx
+                ty = 0 if dy < 0 else dy
+                dx = max(-dx, 0)
+                dy = max(-dy, 0)
+
+                x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
+                noise = x
+
+            xs.append(noise)
+
+        eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
+        if eta_noise_seed_delta:
+            self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
+
+        return torch.stack(xs).to(shared.device)
+
+    def next(self):
+        if self.is_first:
+            self.is_first = False
+            return self.first()
+
+        xs = []
+        for generator in self.generators:
+            x = randn_without_seed(self.shape, generator=generator)
+            xs.append(x)
+
+        return torch.stack(xs).to(shared.device)
+
+
+devices.randn = randn
+devices.randn_local = randn_local
+devices.randn_like = randn_like
+devices.randn_without_seed = randn_without_seed
+devices.manual_seed = manual_seed
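
A short usage sketch of the new class (seed values arbitrary): the first next() reproduces what create_random_tensors used to return for the initial latent, and every later next() replaces the old pre-generated sampler noises. eta_noise_seed_delta is honored by reseeding the per-seed generators after the first draw, mirroring the old devices.manual_seed(seed + eta_noise_seed_delta) call:

g = ImageRNG((4, 64, 64), seeds=[100, 101], subseeds=[200, 201], subseed_strength=0.5)
x = g.next()   # initial latents: slerp between seed noise and subseed noise
n1 = g.next()  # first sampler noise, drawn from the per-seed generators
n2 = g.next()  # second sampler noise; the global torch seed is untouched on this path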
modules/sd_samplers_common.py
@@ -1,5 +1,5 @@
 import inspect
-from collections import namedtuple, deque
+from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
@@ -132,10 +132,15 @@ replace_torchsde_browinan()
 
 
 class TorchHijack:
-    def __init__(self, sampler_noises):
-        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
-        # implementation.
-        self.sampler_noises = deque(sampler_noises)
+    """This is here to replace torch.randn_like of k-diffusion.
+
+    k-diffusion has a random_sampler argument for most samplers, but not for all, so
+    this is needed to properly replace every use of torch.randn_like.
+
+    We need to replace it to make images generated in batches the same as images generated individually."""
+
+    def __init__(self, p):
+        self.rng = p.rng
 
     def __getattr__(self, item):
         if item == 'randn_like':
@@ -147,12 +152,7 @@ class TorchHijack:
         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
 
     def randn_like(self, x):
-        if self.sampler_noises:
-            noise = self.sampler_noises.popleft()
-            if noise.shape == x.shape:
-                return noise
-
-        return devices.randn_like(x)
+        return self.rng.next()
 
 
 class Sampler:
@@ -215,7 +215,7 @@ class Sampler:
         self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
         self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
 
-        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+        k_diffusion.sampling.torch = TorchHijack(p)
 
         extra_params_kwargs = {}
         for param_name in self.extra_params:
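
This works because k-diffusion samplers reach torch through their module-global reference: assigning k_diffusion.sampling.torch = TorchHijack(p) reroutes their torch.randn_like calls to p.rng.next(), while __getattr__ forwards every other attribute to the real torch. A self-contained toy version of the mechanism (illustrative, not webui code):

import torch

class MiniHijack:
    # stands in for a module's `torch` global; only randn_like is rerouted
    def __init__(self, rng_next):
        self._rng_next = rng_next

    def randn_like(self, x):
        return self._rng_next()      # deterministic, per-seed noise

    def __getattr__(self, item):
        return getattr(torch, item)  # everything else is real torch

hijack = MiniHijack(lambda: torch.zeros(2))  # toy stand-in for rng.next
print(hijack.randn_like(torch.ones(2)))      # tensor([0., 0.])
print(hijack.stack([torch.ones(1)]))         # forwarded to the real torch.stack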
modules/shared.py
@@ -16,7 +16,7 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
+from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args, rng  # noqa: F401
 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
 from ldm.models.diffusion.ddpm import LatentDiffusion
 from typing import Optional