add unet ldm in init

This commit is contained in:
patil-suraj 2022-06-08 11:44:27 +02:00
parent b903d3d3c1
commit 4d53a52150
3 changed files with 4 additions and 2 deletions

View File

@ -7,5 +7,6 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler

View File

@ -18,3 +18,4 @@
from .unet import UNetModel
from .unet_glide import UNetGLIDEModel
from .unet_ldm import UNetLDMModel

View File

@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
self.dtype_ = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
h = x.type(self.dtype_)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)