diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index b51c317..e317705 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -38,6 +38,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from PIL import Image
+from PIL import ImageOps
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
 
@@ -81,6 +82,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
 parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
+parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
 args = parser.parse_args()
 
 def setup():
@@ -162,16 +164,16 @@ class ImageStore:
     # iterator returns images as PIL images and their index in the store
     def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
         for f in range(len(self)):
-            yield Image.open(self.image_files[f]), f
+            yield Image.open(self.image_files[f]).convert(mode='RGB'), f
 
     # get image by index
-    def get_image(self, index: int) -> Image.Image:
-        return Image.open(self.image_files[index])
+    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
+        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
 
     # gets caption by removing the extension from the filename and replacing it with .txt
-    def get_caption(self, index: int) -> str:
-        filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
-        with open(filename, 'r') as f:
+    def get_caption(self, ref: Tuple[int, int, int]) -> str:
+        filename = re.sub(r'\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
+        with open(filename, 'r', encoding='UTF-8') as f:
             return f.read()
 
@@ -269,12 +271,12 @@ class AspectBucket:
     def get_bucket_info(self):
         return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
 
-    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
+    def get_batch_iterator(self) -> Generator[List[Tuple[int, int, int]], None, None]:
         """
         Generator that provides batches where the images in a batch fall on the same bucket
 
         Each element generated will be:
-            ((w, h), [image1, image2, ..., image{batch_size}])
+            [(index1, w, h), (index2, w, h), ..., (index{batch_size}, w, h)]
 
         where each image is an index into the dataset
         :return:
@@ -318,7 +320,7 @@ class AspectBucket:
                 total_generated_by_bucket[b] += self.batch_size
                 bucket_pos[b] = i
 
-                yield [idx for idx in batch]
+                yield [(idx, *b) for idx in batch]
 
     def fill_buckets(self):
         entries = self.store.entries_iterator()
@@ -383,16 +385,26 @@ class AspectDataset(torch.utils.data.Dataset):
         self.transforms = torchvision.transforms.Compose([
             torchvision.transforms.RandomHorizontalFlip(p=0.5),
             torchvision.transforms.ToTensor(),
-            torchvision.transforms.Normalize([0.5], [0.5]),
+            torchvision.transforms.Normalize([0.5], [0.5])
         ])
 
     def __len__(self):
         return len(self.store)
 
-    def __getitem__(self, item: int):
+    def __getitem__(self, item: Tuple[int, int, int]):
         return_dict = {'pixel_values': None, 'input_ids': None}
 
         image_file = self.store.get_image(item)
+
+        if args.resize:
+            image_file = ImageOps.fit(
+                image_file,
+                (item[1], item[2]),
+                bleed=0.0,
+                centering=(0.5, 0.5),
+                method=Image.Resampling.LANCZOS
+            )
+
         return_dict['pixel_values'] = self.transforms(image_file)
         if random.random() > self.ucg:
             caption_file = self.store.get_caption(item)
@@ -576,7 +588,6 @@ def main():
     )
 
     # load dataset
-    store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
     bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
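
The type changes above all serve one contract change: get_batch_iterator now tags every image index with its bucket's resolution, which is why get_image, get_caption, and __getitem__ take Tuple[int, int, int] instead of int. A toy trace of the new flow (the batch indices and bucket below are assumptions for illustration only, not values from the trainer):

    from typing import List, Tuple

    batch: List[int] = [0, 1, 2, 3]  # image indices drawn from one bucket
    b: Tuple[int, int] = (512, 768)  # that bucket's (w, h)

    # The sampler now yields each index tagged with its bucket resolution...
    refs = [(idx, *b) for idx in batch]
    # -> [(0, 512, 768), (1, 512, 768), (2, 512, 768), (3, 512, 768)]

    # ...so the dataset can recover both the file and the target size:
    for item in refs:
        index, width, height = item  # item[0] picks the file; (item[1], item[2]) is the target size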
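
On the new --resize path: ImageOps.fit scales the image to cover the requested size while preserving its aspect ratio, then center-crops the overflow, so the output is exactly the bucket's (item[1], item[2]). A minimal sketch of that behavior outside the trainer (the synthetic image and bucket size below are made-up examples):

    from PIL import Image, ImageOps

    img = Image.new(mode="RGB", size=(1024, 600))  # stand-in for a dataset image
    bucket_w, bucket_h = 640, 384                  # hypothetical bucket dimensions

    # Scale to cover (bucket_w, bucket_h), then crop the excess equally
    # from both sides (centering=(0.5, 0.5)).
    fitted = ImageOps.fit(
        img,
        (bucket_w, bucket_h),
        bleed=0.0,
        centering=(0.5, 0.5),
        method=Image.Resampling.LANCZOS,
    )
    assert fitted.size == (bucket_w, bucket_h)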