Access Underlying Model
This commit is contained in:
parent
3cefb57fc6
commit
34715bcc97
|
@ -500,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.ucg = ucg
|
self.ucg = ucg
|
||||||
|
|
||||||
|
if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel:
|
||||||
|
self.text_encoder = self.text_encoder.module
|
||||||
|
|
||||||
self.transforms = torchvision.transforms.Compose([
|
self.transforms = torchvision.transforms.Compose([
|
||||||
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
||||||
torchvision.transforms.ToTensor(),
|
torchvision.transforms.ToTensor(),
|
||||||
|
|
Loading…
Reference in New Issue