From 34715bcc9726f89f81858273800431b9e6522cab Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:57:00 +0900 Subject: [PATCH] Access Underlying Model --- trainer/diffusers_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 21f28e5..64e4021 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -500,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset): self.device = device self.ucg = ucg + if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel: + self.text_encoder = self.text_encoder.module + self.transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomHorizontalFlip(p=0.5), torchvision.transforms.ToTensor(),