diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index 444941a..2171a8d 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -19,6 +19,7 @@ import gc
 import time
 import itertools
 import numpy as np
+import json
 
 try:
     pynvml.nvmlInit()
@@ -31,6 +32,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from PIL import Image
+from PIL.Image import Resampling
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
 
@@ -43,6 +45,7 @@ parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
 parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
 parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
 parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
+parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
 parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
 parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
 parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
@@ -67,6 +70,7 @@ parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
+parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
 args = parser.parse_args()
 
 def setup():
@@ -143,15 +147,24 @@ class ImageStore:
         self.image_files = []
         [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
+        self.image_files = [x for x in self.image_files if self.__valid_file(x)]
 
     def __len__(self) -> int:
         return len(self.image_files)
-    
+
+    def __valid_file(self, f) -> bool:
+        try:
+            Image.open(f)
+            return True
+        except:
+            print(f'WARNING: Unable to open file: {f}')
+            return False
+
     # iterator returns images as PIL images and their index in the store
     def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
         for f in range(len(self)):
             yield Image.open(self.image_files[f]), f
-    
+
     # get image by index
     def get_image(self, index: int) -> Image.Image:
         return Image.open(self.image_files[index])
 
@@ -177,7 +190,7 @@ class AspectBucket:
                  bucket_side_increment: int = 64,
                  max_image_area: int = 512 * 768,
                  max_ratio: float = 2):
-        
+
         self.requested_bucket_count = num_buckets
         self.bucket_length_min = bucket_side_min
         self.bucket_length_max = bucket_side_max
@@ -190,7 +203,7 @@ class AspectBucket:
             self.max_ratio = float('inf')
         else:
             self.max_ratio = max_ratio
-        
+
         self.store = store
         self.buckets = []
         self._bucket_ratios = []
@@ -198,7 +211,7 @@ class AspectBucket:
         self.bucket_data: Dict[tuple, List[int]] = dict()
         self.init_buckets()
         self.fill_buckets()
-    
+
     def init_buckets(self):
         possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
         possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
@@ -253,7 +266,10 @@ class AspectBucket:
 
     def get_batch_count(self):
         return sum(len(b) // self.batch_size for b in self.bucket_data.values())
-    
+
+    def get_bucket_info(self):
+        return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
+
     def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
         """
         Generator that provides batches where the images in a batch fall on the same bucket
@@ -526,11 +542,15 @@ def main():
 
     store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
-    bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
+    bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
     sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
 
     print(f'STORE_LEN: {len(store)}')
 
+    if args.output_bucket_info:
+        print(bucket.get_bucket_info())
+        exit(0)
+
     train_dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_sampler=sampler,
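For reference, a minimal sketch (not part of the patch) of how the JSON printed by the new --output_bucket_info path could be consumed after the script exits; the sample payload below is hypothetical, with only its shape mirroring get_bucket_info(), and the ratio values are assumed to be the per-bucket aspect ratios.

import json

# Hypothetical payload shaped like AspectBucket.get_bucket_info() output:
# "buckets" holds (width, height) pairs and "bucket_ratios" one value per bucket.
# The numbers here are made up for illustration.
sample = '{"buckets": [[512, 768], [576, 704], [640, 640]], "bucket_ratios": [0.667, 0.818, 1.0]}'

info = json.loads(sample)
for (w, h), ratio in zip(info["buckets"], info["bucket_ratios"]):
    print(f"bucket {w}x{h} -> ratio {ratio:.3f}")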