From 0d5dc9a6e7f6362e423a06bf0e75dd5854025394 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 08:43:31 +0300 Subject: [PATCH] rework RNG to use generators instead of generating noises beforehand --- modules/devices.py | 81 +--------------- modules/processing.py | 89 ++---------------- modules/rng.py | 171 ++++++++++++++++++++++++++++++++++ modules/sd_samplers_common.py | 24 ++--- modules/shared.py | 2 +- 5 files changed, 196 insertions(+), 171 deletions(-) create mode 100644 modules/rng.py diff --git a/modules/devices.py b/modules/devices.py index 00a00b18a..ce59dc534 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors, rng_philox +from modules import errors if sys.platform == "darwin": from modules import mac_specific @@ -96,84 +96,6 @@ def cond_cast_float(input): nv_rng = None -def randn(seed, shape): - """Generate a tensor with random numbers from a normal distribution using seed. - - Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" - - from modules.shared import opts - - manual_seed(seed) - - if opts.randn_source == "NV": - return torch.asarray(nv_rng.randn(shape), device=device) - - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - - return torch.randn(shape, device=device) - - -def randn_local(seed, shape): - """Generate a tensor with random numbers from a normal distribution using seed. - - Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" - - from modules.shared import opts - - if opts.randn_source == "NV": - rng = rng_philox.Generator(seed) - return torch.asarray(rng.randn(shape), device=device) - - local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device - local_generator = torch.Generator(local_device).manual_seed(int(seed)) - return torch.randn(shape, device=local_device, generator=local_generator).to(device) - - -def randn_like(x): - """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. - - Use either randn() or manual_seed() to initialize the generator.""" - - from modules.shared import opts - - if opts.randn_source == "NV": - return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) - - if opts.randn_source == "CPU" or x.device.type == 'mps': - return torch.randn_like(x, device=cpu).to(x.device) - - return torch.randn_like(x) - - -def randn_without_seed(shape): - """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. - - Use either randn() or manual_seed() to initialize the generator.""" - - from modules.shared import opts - - if opts.randn_source == "NV": - return torch.asarray(nv_rng.randn(shape), device=device) - - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - - return torch.randn(shape, device=device) - - -def manual_seed(seed): - """Set up a global random number generator using the specified seed.""" - from modules.shared import opts - - if opts.randn_source == "NV": - global nv_rng - nv_rng = rng_philox.Generator(seed) - return - - torch.manual_seed(seed) - - def autocast(disable=False): from modules import shared @@ -236,3 +158,4 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + diff --git a/modules/processing.py b/modules/processing.py index aa72b1329..2df5e8c75 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -14,7 +14,7 @@ from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state @@ -186,6 +186,7 @@ class StableDiffusionProcessing: self.cached_c = StableDiffusionProcessing.cached_c self.uc = None self.c = None + self.rng: rng.ImageRNG = None self.user = None @@ -475,82 +476,9 @@ class Processed: return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio -# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 -def slerp(val, low, high): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - dot = (low_norm*high_norm).sum(1) - - if dot.mean() > 0.9995: - return low * val + high * (1 - val) - - omega = torch.acos(dot) - so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high - return res - - def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): - eta_noise_seed_delta = opts.eta_noise_seed_delta or 0 - xs = [] - - # if we have multiple seeds, this means we are working with batch size>1; this then - # enables the generation of additional tensors with noise that the sampler will use during its processing. - # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to - # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0): - sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] - else: - sampler_noises = None - - for i, seed in enumerate(seeds): - noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) - - subnoise = None - if subseeds is not None and subseed_strength != 0: - subseed = 0 if i >= len(subseeds) else subseeds[i] - - subnoise = devices.randn(subseed, noise_shape) - - # randn results depend on device; gpu and cpu get different results for same seed; - # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this, so I do not dare change it for now because - # it will break everyone's seeds. - noise = devices.randn(seed, noise_shape) - - if subnoise is not None: - noise = slerp(subseed_strength, noise, subnoise) - - if noise_shape != shape: - x = devices.randn(seed, shape) - dx = (shape[2] - noise_shape[2]) // 2 - dy = (shape[1] - noise_shape[1]) // 2 - w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx - h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy - tx = 0 if dx < 0 else dx - ty = 0 if dy < 0 else dy - dx = max(-dx, 0) - dy = max(-dy, 0) - - x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] - noise = x - - if sampler_noises is not None: - cnt = p.sampler.number_of_needed_noises(p) - - if eta_noise_seed_delta > 0: - devices.manual_seed(seed + eta_noise_seed_delta) - - for j in range(cnt): - sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) - - xs.append(noise) - - if sampler_noises is not None: - p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises] - - x = torch.stack(xs).to(shared.device) - return x + g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w) + return g.next() class DecodedSamples(list): @@ -769,6 +697,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + if p.scripts is not None: p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -1072,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = self.rng.next() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1160,7 +1090,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) + self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w) + noise = self.rng.next() # GC now before running the next img2img to prevent running out of memory devices.torch_gc() @@ -1418,7 +1349,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = self.rng.next() if self.initial_noise_multiplier != 1.0: self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier diff --git a/modules/rng.py b/modules/rng.py new file mode 100644 index 000000000..2d7baea5b --- /dev/null +++ b/modules/rng.py @@ -0,0 +1,171 @@ +import torch + +from modules import devices, rng_philox, shared + + +def randn(seed, shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + + manual_seed(seed) + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + if shared.opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=devices.device) + + local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device) + + +def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if shared.opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + + return torch.randn_like(x) + + +def randn_without_seed(shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" + from modules.shared import opts + + if opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + +def create_generator(seed): + if shared.opts.randn_source == "NV": + return rng_philox.Generator(seed) + + device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + generator = torch.Generator(device).manual_seed(int(seed)) + return generator + + +# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 +def slerp(val, low, high): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + dot = (low_norm*high_norm).sum(1) + + if dot.mean() > 0.9995: + return low * val + high * (1 - val) + + omega = torch.acos(dot) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res + + +class ImageRNG: + def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0): + self.shape = shape + self.seeds = seeds + self.subseeds = subseeds + self.subseed_strength = subseed_strength + self.seed_resize_from_h = seed_resize_from_h + self.seed_resize_from_w = seed_resize_from_w + + self.generators = [create_generator(seed) for seed in seeds] + + self.is_first = True + + def first(self): + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + + xs = [] + + for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)): + subnoise = None + if self.subseeds is not None and self.subseed_strength != 0: + subseed = 0 if i >= len(self.subseeds) else self.subseeds[i] + subnoise = randn(subseed, noise_shape) + + if noise_shape != self.shape: + noise = randn(seed, noise_shape) + else: + noise = randn(seed, self.shape, generator=generator) + + if subnoise is not None: + noise = slerp(self.subseed_strength, noise, subnoise) + + if noise_shape != self.shape: + x = randn(seed, self.shape, generator=generator) + dx = (self.shape[2] - noise_shape[2]) // 2 + dy = (self.shape[1] - noise_shape[1]) // 2 + w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx + h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] + noise = x + + xs.append(noise) + + eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0 + if eta_noise_seed_delta: + self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds] + + return torch.stack(xs).to(shared.device) + + def next(self): + if self.is_first: + self.is_first = False + return self.first() + + xs = [] + for generator in self.generators: + x = randn_without_seed(self.shape, generator=generator) + xs.append(x) + + return torch.stack(xs).to(shared.device) + + +devices.randn = randn +devices.randn_local = randn_local +devices.randn_like = randn_like +devices.randn_without_seed = randn_without_seed +devices.manual_seed = manual_seed diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index adda963bc..97bc08041 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,5 +1,5 @@ import inspect -from collections import namedtuple, deque +from collections import namedtuple import numpy as np import torch from PIL import Image @@ -132,10 +132,15 @@ replace_torchsde_browinan() class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) + """This is here to replace torch.randn_like of k-diffusion. + + k-diffusion has random_sampler argument for most samplers, but not for all, so + this is needed to properly replace every use of torch.randn_like. + + We need to replace to make images generated in batches to be same as images generated individually.""" + + def __init__(self, p): + self.rng = p.rng def __getattr__(self, item): if item == 'randn_like': @@ -147,12 +152,7 @@ class TorchHijack: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) + return self.rng.next() class Sampler: @@ -215,7 +215,7 @@ class Sampler: self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + k_diffusion.sampling.torch = TorchHijack(p) extra_params_kwargs = {} for param_name in self.extra_params: diff --git a/modules/shared.py b/modules/shared.py index e34847cef..e9b980a43 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,7 +16,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args +from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args, rng # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from ldm.models.diffusion.ddpm import LatentDiffusion from typing import Optional