Access Underlying Model
This commit is contained in:
parent
3cefb57fc6
commit
34715bcc97
|
@ -500,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset):
|
|||
self.device = device
|
||||
self.ucg = ucg
|
||||
|
||||
if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel:
|
||||
self.text_encoder = self.text_encoder.module
|
||||
|
||||
self.transforms = torchvision.transforms.Compose([
|
||||
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
||||
torchvision.transforms.ToTensor(),
|
||||
|
|
Loading…
Reference in New Issue