From c690d005bcf8a139e297d32d66dd9a5e85c02c80 Mon Sep 17 00:00:00 2001 From: harubaru Date: Tue, 20 Sep 2022 22:24:04 -0700 Subject: [PATCH] fix cropping --- .../v1-finetune-danbooru-8gpu.yaml | 1 - ldm/data/localdanbooru.py | 19 +++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml b/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml index 659a6c3..22e6194 100644 --- a/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml +++ b/configs/stable-diffusion/v1-finetune-danbooru-8gpu.yaml @@ -84,7 +84,6 @@ data: normalize: true caption_shuffle: true - lightning: modelcheckpoint: params: diff --git a/ldm/data/localdanbooru.py b/ldm/data/localdanbooru.py index c9457a8..7617057 100644 --- a/ldm/data/localdanbooru.py +++ b/ldm/data/localdanbooru.py @@ -59,7 +59,17 @@ class CaptionProcessor(object): # preprocess image image = sample['image'] - image = self.transforms(Image.open(io.BytesIO(image))) + + image = Image.open(io.BytesIO(image)) + + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + image = Image.fromarray(img) + + image = self.transforms(image) image = np.array(image).astype(np.uint8) sample['image'] = (image / 127.5 - 1.0).astype(np.float32) return sample @@ -119,9 +129,7 @@ class DanbooruWebDataModuleFromConfig(pl.LightningDataModule): def make_loader(self, dataset_config, train=True): image_transforms = [] - image_transforms.extend([torchvision.transforms.CenterCrop(self.size), - torchvision.transforms.Resize(self.size), - torchvision.transforms.RandomHorizontalFlip(self.flip_p)],) + image_transforms.extend([torchvision.transforms.Resize(self.size), torchvision.transforms.RandomHorizontalFlip(self.flip_p)],) image_transforms = torchvision.transforms.Compose(image_transforms) transform_dict = {} @@ -177,6 +185,9 @@ def example(): for batch in dataloader: print(batch["image"].shape) print(batch['caption']) + image = ((batch["image"][0] + 1) * 127.5).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save('example.png') break if __name__ == '__main__':