fix cropping

This commit is contained in:
harubaru 2022-09-20 22:24:04 -07:00
parent 01b9440d50
commit c690d005bc
2 changed files with 15 additions and 5 deletions

View File

@ -84,7 +84,6 @@ data:
normalize: true
caption_shuffle: true
lightning:
modelcheckpoint:
params:

View File

@ -59,7 +59,17 @@ class CaptionProcessor(object):
# preprocess image
image = sample['image']
image = self.transforms(Image.open(io.BytesIO(image)))
image = Image.open(io.BytesIO(image))
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
image = self.transforms(image)
image = np.array(image).astype(np.uint8)
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
return sample
@ -119,9 +129,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule):
def make_loader(self, dataset_config, train=True):
image_transforms = []
image_transforms.extend([torchvision.transforms.CenterCrop(self.size),
torchvision.transforms.Resize(self.size),
torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],)
image_transforms = torchvision.transforms.Compose(image_transforms)
transform_dict = {}
@ -177,6 +185,9 @@ def example():
for batch in dataloader:
print(batch["image"].shape)
print(batch['caption'])
image = ((batch["image"][0] + 1) * 127.5).numpy().astype(np.uint8)
image = Image.fromarray(image)
image.save('example.png')
break
if __name__ == '__main__':