fix cropping
This commit is contained in:
parent
01b9440d50
commit
c690d005bc
|
@ -84,7 +84,6 @@ data:
|
||||||
normalize: true
|
normalize: true
|
||||||
caption_shuffle: true
|
caption_shuffle: true
|
||||||
|
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
modelcheckpoint:
|
modelcheckpoint:
|
||||||
params:
|
params:
|
||||||
|
|
|
@ -59,7 +59,17 @@ class CaptionProcessor(object):
|
||||||
|
|
||||||
# preprocess image
|
# preprocess image
|
||||||
image = sample['image']
|
image = sample['image']
|
||||||
image = self.transforms(Image.open(io.BytesIO(image)))
|
|
||||||
|
image = Image.open(io.BytesIO(image))
|
||||||
|
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
|
||||||
|
image = self.transforms(image)
|
||||||
image = np.array(image).astype(np.uint8)
|
image = np.array(image).astype(np.uint8)
|
||||||
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
return sample
|
return sample
|
||||||
|
@ -119,9 +129,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
||||||
|
|
||||||
def make_loader(self, dataset_config, train=True):
|
def make_loader(self, dataset_config, train=True):
|
||||||
image_transforms = []
|
image_transforms = []
|
||||||
image_transforms.extend([torchvision.transforms.CenterCrop(self.size),
|
image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
||||||
torchvision.transforms.Resize(self.size),
|
|
||||||
torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
|
||||||
image_transforms = torchvision.transforms.Compose(image_transforms)
|
image_transforms = torchvision.transforms.Compose(image_transforms)
|
||||||
|
|
||||||
transform_dict = {}
|
transform_dict = {}
|
||||||
|
@ -177,6 +185,9 @@ def example():
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
print(batch["image"].shape)
|
print(batch["image"].shape)
|
||||||
print(batch['caption'])
|
print(batch['caption'])
|
||||||
|
image = ((batch["image"][0] + 1) * 127.5).numpy().astype(np.uint8)
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
image.save('example.png')
|
||||||
break
|
break
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue