Bucket Helpers

cafeai 2022-11-05 22:55:56 +09:00
parent 6b73cfe448
commit 305efa20f2
1 changed file with 27 additions and 7 deletions


@@ -19,6 +19,7 @@ import gc
 import time
 import itertools
 import numpy as np
+import json
 try:
     pynvml.nvmlInit()
@@ -31,6 +32,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from PIL import Image
+from PIL.Image import Resampling
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
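
A note on the new Resampling import: Pillow 9.1 moved the resampling filters into the Image.Resampling enum and deprecated the module-level constants, so code written against older Pillow needs a small migration. The resize call sites are not part of this diff, so the sketch below only illustrates the enum usage the import enables (the file name is a placeholder):

    from PIL import Image
    from PIL.Image import Resampling

    img = Image.open('example.png')                    # placeholder path
    img = img.resize((512, 512), Resampling.LANCZOS)   # was Image.LANCZOS before Pillow 9.1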
@@ -43,6 +45,7 @@ parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
 parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
 parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
 parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
+parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
 parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
 parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
 parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
@@ -67,6 +70,7 @@ parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
+parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
 args = parser.parse_args()

 def setup():
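
A caveat on the new flag (a property of argparse itself, not something this commit changes): type=bool simply calls bool() on the raw argument string, and any non-empty string is truthy, so --output_bucket_info False still enables the flag. A minimal demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_bucket_info', type=bool, default=False)

    # bool('False') is True, because any non-empty string is truthy
    print(parser.parse_args(['--output_bucket_info', 'False']).output_bucket_info)  # True
    print(parser.parse_args([]).output_bucket_info)                                 # False

The same applies to the existing --fp16 and --clip_penultimate flags, which use the same pattern.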
@@ -143,10 +147,19 @@ class ImageStore:
         self.image_files = []
         [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
+        self.image_files = [x for x in self.image_files if self.__valid_file(x)]

     def __len__(self) -> int:
         return len(self.image_files)

+    def __valid_file(self, f) -> bool:
+        try:
+            Image.open(f)
+            return True
+        except:
+            print(f'WARNING: Unable to open file: {f}')
+            return False
+
     # iterator returns images as PIL images and their index in the store
     def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
         for f in range(len(self)):
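
One thing to be aware of with __valid_file: Image.open is lazy and only parses the file header, so this filters out missing or unidentifiable files but will not catch images whose data is truncated past the header, and it leaves the file handle to be closed by the garbage collector. A stricter variant, as a sketch rather than what this commit implements, could call verify() inside a context manager:

    from PIL import Image

    def valid_file(path: str) -> bool:
        # verify() checks file integrity without decoding pixel data;
        # the context manager closes the handle promptly.
        try:
            with Image.open(path) as img:
                img.verify()
            return True
        except Exception as e:
            print(f'WARNING: Unable to open file: {path} ({e})')
            return False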
@@ -254,6 +267,9 @@ class AspectBucket:
     def get_batch_count(self):
         return sum(len(b) // self.batch_size for b in self.bucket_data.values())

+    def get_bucket_info(self):
+        return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
+
     def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
         """
         Generator that provides batches where the images in a batch fall on the same bucket
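
Because get_bucket_info returns a JSON string, the output printed under --output_bucket_info can be piped into other tooling. A sketch of consuming it (the field names come from the method above; the values shown are purely illustrative, since the shapes of self.buckets and self._bucket_ratios are defined elsewhere in the class):

    import json

    # e.g. captured from the script's stdout; values are made up
    info = '{"buckets": [[256, 1024], [320, 960]], "bucket_ratios": [0.25, 0.333]}'

    data = json.loads(info)
    for bucket, ratio in zip(data['buckets'], data['bucket_ratios']):
        print(f'bucket {bucket}: aspect ratio {ratio:.3f}')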
@@ -526,11 +542,15 @@ def main():
     store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
-    bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
+    bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
     sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)

     print(f'STORE_LEN: {len(store)}')

+    if args.output_bucket_info:
+        print(bucket.get_bucket_info())
+        exit(0)
+
     train_dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_sampler=sampler,
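
Taken together, these changes let bucket allocation be inspected without starting a run: the store is built, the buckets are computed, the JSON is printed, and the process exits before the dataloader is created. A hypothetical invocation (script name and paths are placeholders):

    python finetuner.py --model ./model --run_name bucket-test --dataset ./dataset --output_bucket_info True

Given the type=bool caveat noted above, any non-empty value enables the flag; simply omit it to train normally.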