diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 21f28e5..64e4021 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -500,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset): self.device = device self.ucg = ucg + if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel: + self.text_encoder = self.text_encoder.module + self.transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomHorizontalFlip(p=0.5), torchvision.transforms.ToTensor(),