fix cropping
This commit is contained in:
parent
01b9440d50
commit
c690d005bc
|
@ -84,7 +84,6 @@ data:
|
|||
normalize: true
|
||||
caption_shuffle: true
|
||||
|
||||
|
||||
lightning:
|
||||
modelcheckpoint:
|
||||
params:
|
||||
|
|
|
@ -59,7 +59,17 @@ class CaptionProcessor(object):
|
|||
|
||||
# preprocess image
|
||||
image = sample['image']
|
||||
image = self.transforms(Image.open(io.BytesIO(image)))
|
||||
|
||||
image = Image.open(io.BytesIO(image))
|
||||
|
||||
img = np.array(image).astype(np.uint8)
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = img.shape[0], img.shape[1]
|
||||
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||
(w - crop) // 2:(w + crop) // 2]
|
||||
image = Image.fromarray(img)
|
||||
|
||||
image = self.transforms(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
return sample
|
||||
|
@ -119,9 +129,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
|
|||
|
||||
def make_loader(self, dataset_config, train=True):
|
||||
image_transforms = []
|
||||
image_transforms.extend([torchvision.transforms.CenterCrop(self.size),
|
||||
torchvision.transforms.Resize(self.size),
|
||||
torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
||||
image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
|
||||
image_transforms = torchvision.transforms.Compose(image_transforms)
|
||||
|
||||
transform_dict = {}
|
||||
|
@ -177,6 +185,9 @@ def example():
|
|||
for batch in dataloader:
|
||||
print(batch["image"].shape)
|
||||
print(batch['caption'])
|
||||
image = ((batch["image"][0] + 1) * 127.5).numpy().astype(np.uint8)
|
||||
image = Image.fromarray(image)
|
||||
image.save('example.png')
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue