Merge pull request #9734 from deciare/cpu-randn

Option to make images generated from a given manual seed consistent across CUDA and MPS devices
This commit is contained in:
AUTOMATIC1111 2023-04-29 11:16:06 +03:00 committed by GitHub
commit cb9571e37f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 3 deletions

View File

@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape): def randn(seed, shape):
from modules.shared import opts
torch.manual_seed(seed) torch.manual_seed(seed)
if device.type == 'mps': if opts.use_cpu_randn or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def randn_without_seed(shape): def randn_without_seed(shape):
if device.type == 'mps': from modules.shared import opts
if opts.use_cpu_randn or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device) return torch.randn(shape, device=device)

View File

@ -60,3 +60,12 @@ def store_latent(decoded):
class InterruptedException(BaseException): class InterruptedException(BaseException):
pass pass
if opts.use_cpu_randn:
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn

View File

@ -190,7 +190,7 @@ class TorchHijack:
if noise.shape == x.shape: if noise.shape == x.shape:
return noise return noise
if x.device.type == 'mps': if opts.use_cpu_randn or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device) return torch.randn_like(x, device=devices.cpu).to(x.device)
else: else:
return torch.randn_like(x) return torch.randn_like(x)

View File

@ -334,6 +334,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"use_cpu_randn": OptionInfo(False, "Use CPU for random number generation to make manual seeds generate the same image across platforms. This may change existing seeds."),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('compatibility', "Compatibility"), {