From fe30e8942e3c8c0fce3de9c456c57cf8ecb33c8f Mon Sep 17 00:00:00 2001 From: Maw-Fox Date: Tue, 8 Nov 2022 18:17:31 -0700 Subject: [PATCH 1/4] Add resize and optional resize arg --- diffusers_trainer.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/diffusers_trainer.py b/diffusers_trainer.py index b51c317..e317705 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -38,6 +38,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.optimization import get_scheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from PIL import Image +from PIL import ImageOps from typing import Dict, List, Generator, Tuple from scipy.interpolate import interp1d @@ -81,6 +82,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.') parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding') parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits') +parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.") args = parser.parse_args() def setup(): @@ -162,16 +164,16 @@ class ImageStore: # iterator returns images as PIL images and their index in the store def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]: for f in range(len(self)): - yield Image.open(self.image_files[f]), f + yield Image.open(self.image_files[f]).convert(mode='RGB'), f # get image by index - def get_image(self, index: int) -> Image.Image: - return Image.open(self.image_files[index]) + def get_image(self, ref: Tuple[int, int, int]) -> Image.Image: + return Image.open(self.image_files[ref[0]]).convert(mode='RGB') 
# gets caption by removing the extension from the filename and replacing it with .txt - def get_caption(self, index: int) -> str: - filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt' - with open(filename, 'r') as f: + def get_caption(self, ref: Tuple[int, int, int]) -> str: + filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt' + with open(filename, 'r', encoding='UTF-8') as f: return f.read() @@ -269,12 +271,12 @@ class AspectBucket: def get_bucket_info(self): return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios }) - def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]: + def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]: """ Generator that provides batches where the images in a batch fall on the same bucket Each element generated will be: - ((w, h), [image1, image2, ..., image{batch_size}]) + (index, w, h) where each image is an index into the dataset :return: @@ -318,7 +320,7 @@ class AspectBucket: total_generated_by_bucket[b] += self.batch_size bucket_pos[b] = i - yield [idx for idx in batch] + yield [(idx, *b) for idx in batch] def fill_buckets(self): entries = self.store.entries_iterator() @@ -383,16 +385,26 @@ class AspectDataset(torch.utils.data.Dataset): self.transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomHorizontalFlip(p=0.5), torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.5], [0.5]), + torchvision.transforms.Normalize([0.5], [0.5]) ]) def __len__(self): return len(self.store) - def __getitem__(self, item: int): + def __getitem__(self, item: Tuple[int, int, int]): return_dict = {'pixel_values': None, 'input_ids': None} image_file = self.store.get_image(item) + + if args.resize: + image_file = ImageOps.fit( + image_file, + (item[1], item[2]), + bleed=0.0, + centering=(0.5, 0.5), + method=Image.Resampling.LANCZOS + ) + return_dict['pixel_values'] = 
self.transforms(image_file) if random.random() > self.ucg: caption_file = self.store.get_caption(item) @@ -576,7 +588,6 @@ def main(): ) # load dataset - store = ImageStore(args.dataset) dataset = AspectDataset(store, tokenizer) bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0) From 3fe9df14508bfb9fdf4254d06aa12ccc3f047d3d Mon Sep 17 00:00:00 2001 From: Maw-Fox Date: Tue, 8 Nov 2022 20:50:57 -0700 Subject: [PATCH 2/4] Cleanup. --- diffusers_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/diffusers_trainer.py b/diffusers_trainer.py index e317705..00c076d 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -37,8 +37,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMSc from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.optimization import get_scheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from PIL import Image -from PIL import ImageOps +from PIL import Image, ImageOps from typing import Dict, List, Generator, Tuple from scipy.interpolate import interp1d From 0be39a4887918e9cba1cf02c6ec66499768e62d2 Mon Sep 17 00:00:00 2001 From: Maw-Fox Date: Thu, 10 Nov 2022 05:42:38 -0700 Subject: [PATCH 3/4] Fix deprecated enum --- diffusers_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffusers_trainer.py b/diffusers_trainer.py index 00c076d..c5d0bd7 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -401,7 +401,7 @@ class AspectDataset(torch.utils.data.Dataset): (item[1], item[2]), bleed=0.0, centering=(0.5, 0.5), - method=Image.Resampling.LANCZOS + method=Image.Resampling(Image.LANCZOS) ) return_dict['pixel_values'] = self.transforms(image_file) @@ -548,7 +548,7 @@ def main(): if args.resume: args.model = args.resume - + tokenizer = CLIPTokenizer.from_pretrained(args.model, 
subfolder='tokenizer', use_auth_token=args.hf_token) text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token) vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token) From 46e8d98d2b5fd323a28b73756964f7b634b91af0 Mon Sep 17 00:00:00 2001 From: Maw-Fox Date: Thu, 10 Nov 2022 06:11:09 -0700 Subject: [PATCH 4/4] Revert --- diffusers_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffusers_trainer.py b/diffusers_trainer.py index c5d0bd7..00c076d 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -401,7 +401,7 @@ class AspectDataset(torch.utils.data.Dataset): (item[1], item[2]), bleed=0.0, centering=(0.5, 0.5), - method=Image.Resampling(Image.LANCZOS) + method=Image.Resampling.LANCZOS ) return_dict['pixel_values'] = self.transforms(image_file) @@ -548,7 +548,7 @@ def main(): if args.resume: args.model = args.resume - + tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token) text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token) vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)