add unet ldm in init
This commit is contained in:
parent
b903d3d3c1
commit
4d53a52150
|
@ -7,5 +7,6 @@ __version__ = "0.0.1"
|
|||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .models.unet_glide import UNetGLIDEModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
||||
|
|
|
@ -18,3 +18,4 @@
|
|||
|
||||
from .unet import UNetModel
|
||||
from .unet_glide import UNetGLIDEModel
|
||||
from .unet_ldm import UNetLDMModel
|
||||
|
|
|
@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.dtype_ = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
assert y.shape == (x.shape[0],)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x.type(self.dtype_)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
|
|
Loading…
Reference in New Issue