Access Underlying Model

This commit is contained in:
cafeai 2022-12-03 19:57:00 +09:00
parent 3cefb57fc6
commit 34715bcc97
1 changed files with 3 additions and 0 deletions

View File

@ -500,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset):
self.device = device self.device = device
self.ucg = ucg self.ucg = ucg
if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel:
self.text_encoder = self.text_encoder.module
self.transforms = torchvision.transforms.Compose([ self.transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(p=0.5), torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(), torchvision.transforms.ToTensor(),