Bucket Helpers
This commit is contained in:
parent
6b73cfe448
commit
305efa20f2
|
@ -19,6 +19,7 @@ import gc
|
||||||
import time
|
import time
|
||||||
import itertools
|
import itertools
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
|
@ -31,6 +32,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
|
||||||
from typing import Dict, List, Generator, Tuple
|
from typing import Dict, List, Generator, Tuple
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
@ -43,6 +45,7 @@ parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
|
||||||
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
|
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
|
||||||
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
|
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
|
||||||
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
|
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
|
||||||
|
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
|
||||||
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
|
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
|
||||||
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
|
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
|
||||||
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
|
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
|
||||||
|
@ -67,6 +70,7 @@ parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train
|
||||||
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
|
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
|
||||||
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
||||||
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
|
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
|
||||||
|
# NOTE: argparse's type=bool is a trap -- bool('False') is True, so any
# non-empty value (including "False" or "0") used to enable this flag.
# The value is now parsed as text; nargs='?' with const=True additionally
# accepts a bare `--output_bucket_info`, while `--output_bucket_info True`
# keeps working exactly as before.
parser.add_argument('--output_bucket_info', type=lambda v: str(v).strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on'), nargs='?', const=True, default=False, help='Outputs bucket information and exits')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
|
@ -143,15 +147,24 @@ class ImageStore:
|
||||||
|
|
||||||
self.image_files = []
|
self.image_files = []
|
||||||
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
|
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
|
||||||
|
self.image_files = [x for x in self.image_files if self.__valid_file(x)]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.image_files)
|
return len(self.image_files)
|
||||||
|
|
||||||
|
def __valid_file(self, f) -> bool:
    """Return True if ``f`` can be opened as an image, False otherwise.

    Used to filter the globbed file list so that unreadable files are
    dropped up front. A warning is printed for every rejected file.

    Only the open itself is attempted; pixel data is not decoded here.
    """
    try:
        # Image.open() keeps the underlying file handle open; the context
        # manager releases it so that validating a large dataset does not
        # leak one descriptor per image.
        with Image.open(f):
            pass
    except Exception:
        # `except Exception` instead of a bare `except:` so that
        # KeyboardInterrupt / SystemExit still abort the dataset scan.
        print(f'WARNING: Unable to open file: {f}')
        return False
    return True
|
||||||
|
|
||||||
# iterator returns images as PIL images and their index in the store
|
# iterator returns images as PIL images and their index in the store
|
||||||
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
||||||
for f in range(len(self)):
|
for f in range(len(self)):
|
||||||
yield Image.open(self.image_files[f]), f
|
yield Image.open(self.image_files[f]), f
|
||||||
|
|
||||||
# get image by index
|
# get image by index
|
||||||
def get_image(self, index: int) -> Image.Image:
|
def get_image(self, index: int) -> Image.Image:
|
||||||
return Image.open(self.image_files[index])
|
return Image.open(self.image_files[index])
|
||||||
|
@ -177,7 +190,7 @@ class AspectBucket:
|
||||||
bucket_side_increment: int = 64,
|
bucket_side_increment: int = 64,
|
||||||
max_image_area: int = 512 * 768,
|
max_image_area: int = 512 * 768,
|
||||||
max_ratio: float = 2):
|
max_ratio: float = 2):
|
||||||
|
|
||||||
self.requested_bucket_count = num_buckets
|
self.requested_bucket_count = num_buckets
|
||||||
self.bucket_length_min = bucket_side_min
|
self.bucket_length_min = bucket_side_min
|
||||||
self.bucket_length_max = bucket_side_max
|
self.bucket_length_max = bucket_side_max
|
||||||
|
@ -190,7 +203,7 @@ class AspectBucket:
|
||||||
self.max_ratio = float('inf')
|
self.max_ratio = float('inf')
|
||||||
else:
|
else:
|
||||||
self.max_ratio = max_ratio
|
self.max_ratio = max_ratio
|
||||||
|
|
||||||
self.store = store
|
self.store = store
|
||||||
self.buckets = []
|
self.buckets = []
|
||||||
self._bucket_ratios = []
|
self._bucket_ratios = []
|
||||||
|
@ -198,7 +211,7 @@ class AspectBucket:
|
||||||
self.bucket_data: Dict[tuple, List[int]] = dict()
|
self.bucket_data: Dict[tuple, List[int]] = dict()
|
||||||
self.init_buckets()
|
self.init_buckets()
|
||||||
self.fill_buckets()
|
self.fill_buckets()
|
||||||
|
|
||||||
def init_buckets(self):
|
def init_buckets(self):
|
||||||
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
|
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
|
||||||
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
|
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
|
||||||
|
@ -253,7 +266,10 @@ class AspectBucket:
|
||||||
|
|
||||||
def get_batch_count(self):
|
def get_batch_count(self):
|
||||||
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
|
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
|
||||||
|
|
||||||
|
def get_bucket_info(self):
    """Serialize the bucket layout as a JSON string.

    The resulting JSON object has two keys, in this order:
    "buckets" (the contents of ``self.buckets``) and
    "bucket_ratios" (the contents of ``self._bucket_ratios``).
    """
    info = {}
    info["buckets"] = self.buckets
    info["bucket_ratios"] = self._bucket_ratios
    return json.dumps(info)
|
||||||
|
|
||||||
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
||||||
"""
|
"""
|
||||||
Generator that provides batches where the images in a batch fall on the same bucket
|
Generator that provides batches where the images in a batch fall on the same bucket
|
||||||
|
@ -526,11 +542,15 @@ def main():
|
||||||
|
|
||||||
store = ImageStore(args.dataset)
|
store = ImageStore(args.dataset)
|
||||||
dataset = AspectDataset(store, tokenizer)
|
dataset = AspectDataset(store, tokenizer)
|
||||||
bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||||
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
||||||
|
|
||||||
print(f'STORE_LEN: {len(store)}')
|
print(f'STORE_LEN: {len(store)}')
|
||||||
|
|
||||||
|
if args.output_bucket_info:
|
||||||
|
print(bucket.get_bucket_info())
|
||||||
|
exit(0)
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_sampler=sampler,
|
batch_sampler=sampler,
|
||||||
|
|
Loading…
Reference in New Issue