Make DDIM and PLMS work on Mac OS

Fix register_buffer error on Mac OS
This commit is contained in:
thesved 2022-11-03 19:44:47 +01:00 committed by GitHub
parent c2465f67db
commit 86b7fc6e5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 1 deletions

View File

@ -1,4 +1,5 @@
import torch
import modules.devices as devices
from einops import repeat
from omegaconf import ListConfig
@ -316,6 +317,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
self.concat_keys = concat_keys
# =================================================================================================
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
# =================================================================================================
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
optimal_type = devices.get_optimal_device()
if attr.device != optimal_type:
if getattr(torch, 'has_mps', False):
attr = attr.to(device="mps", dtype=torch.float32)
else:
attr = attr.to(optimal_type)
setattr(self, name, attr)
def should_hijack_inpainting(checkpoint_info):
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
@ -326,6 +341,8 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer