Merge branch 'main' into main

This commit is contained in:
Anthony Mercurio 2022-11-05 08:55:29 -07:00 committed by GitHub
commit 7be7fc3b98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 22 additions and 2 deletions

View File

@ -23,6 +23,8 @@ import gc
import time
import itertools
import numpy as np
import json
import re
try:
pynvml.nvmlInit()
@ -47,6 +49,7 @@ parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
@ -71,6 +74,7 @@ parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
args = parser.parse_args()
def setup():
@ -147,10 +151,19 @@ class ImageStore:
self.image_files = []
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
self.image_files = [x for x in self.image_files if self.__valid_file(x)]
def __len__(self) -> int:
return len(self.image_files)
def __valid_file(self, f) -> bool:
try:
Image.open(f)
return True
except:
print(f'WARNING: Unable to open file: {f}')
return False
# iterator returns images as PIL images and their index in the store
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
for f in range(len(self)):
@ -162,7 +175,7 @@ class ImageStore:
# gets caption by removing the extension from the filename and replacing it with .txt
def get_caption(self, index: int) -> str:
filename = self.image_files[index].split('.')[0] + '.txt'
filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
with open(filename, 'r') as f:
return f.read()
@ -258,6 +271,9 @@ class AspectBucket:
def get_batch_count(self):
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
def get_bucket_info(self):
return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
"""
Generator that provides batches where the images in a batch fall on the same bucket
@ -529,11 +545,15 @@ def main():
store = ImageStore(args.dataset)
dataset = AspectDataset(store, tokenizer)
bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
print(f'STORE_LEN: {len(store)}')
if args.output_bucket_info:
print(bucket.get_bucket_info())
exit(0)
train_dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,