Merge pull request #14 from JanGerritsen/rated_dataset

Implemented a system to train on a subset of the dataset, favoring higher-rated images
This commit is contained in:
Victor Hall 2023-01-15 19:05:48 -08:00 committed by GitHub
commit 6ba710d6f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 150 additions and 63 deletions
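In outline: each image caption may carry an optional rating (defaulting to 1.0). Once per epoch the dataloader draws a subset of the dataset whose size shrinks linearly toward a configured target, and each draw lands on an image with probability proportional to its rating, so higher-rated images keep appearing in later epochs while lower-rated ones drop out more often. The short illustrative sketches interleaved below (marked as such; they are not part of the commit) walk through the individual pieces.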

View File

@ -13,7 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
import math
import os
import logging
@ -46,7 +47,6 @@ class DataLoaderMultiAspect():
self.log_folder = log_folder
self.seed = seed
self.batch_size = batch_size
self.runts = []
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
@ -56,16 +56,66 @@ class DataLoaderMultiAspect():
self.__recurse_data_root(self=self, recurse_root=data_root)
random.Random(seed).shuffle(self.image_paths)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
def shuffle(self):
self.runts = []
self.seed = self.seed + 1
random.Random(self.seed).shuffle(self.prepared_train_data)
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=self.batch_size, debug_level=0)
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
"""
returns the current list of images, including their captions, in randomized order,
sorted into buckets of same-sized images
if dropout_fraction < 1.0, only a subset of the images will be returned
:param dropout_fraction: must be between 0.0 and 1.0.
:return: randomized list of (image, caption) pairs, sorted into same-sized buckets
"""
"""
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
"""
# TODO: this is not terribly efficient but at least linear time
def unzip_all(self, path):
self.seed += 1
randomizer = random.Random(self.seed)
if dropout_fraction < 1.0:
picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
else:
picked_images = self.prepared_train_data
randomizer.shuffle(picked_images)
buckets = {}
batch_size = self.batch_size
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh
if (target_wh[0],target_wh[1]) not in buckets:
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
if len(buckets) > 1:
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
while len(runt_bucket) < batch_size:
runt_bucket.append(random.choice(runt_bucket))
current_bucket_size = len(buckets[bucket])
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
# flatten the buckets
image_caption_pairs = []
for bucket in buckets:
image_caption_pairs.extend(buckets[bucket])
return image_caption_pairs
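The bucketing above groups images by target resolution and, whenever a bucket's size is not a multiple of batch_size, pads the leftover "runt" batch by repeating random members so no image is silently dropped. A minimal standalone sketch of that step, using made-up item names rather than ImageTrainItem objects:

# --- illustrative sketch, not part of the commit ---
import random

batch_size = 4
# (name, target_wh) pairs standing in for ImageTrainItem objects
items = [("a", (512, 512)), ("b", (512, 512)), ("c", (512, 512)),
         ("d", (640, 384)), ("e", (640, 384)), ("f", (512, 512)),
         ("g", (512, 512))]

buckets = {}
for name, wh in items:
    buckets.setdefault(wh, []).append(name)

if len(buckets) > 1:  # as above, runts only matter with mixed aspect ratios
    for wh, bucket in buckets.items():
        truncate_count = len(bucket) % batch_size
        if truncate_count > 0:
            runt_batch = bucket[-truncate_count:]
            while len(runt_batch) < batch_size:
                runt_batch.append(random.choice(runt_batch))  # repeat a member
            buckets[wh] = bucket[:-truncate_count] + runt_batch

flattened = [name for bucket in buckets.values() for name in bucket]
print(flattened)  # every bucket now contributes whole batches of size batch_size
# --- end sketch ---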
@staticmethod
def unzip_all(path):
try:
for root, dirs, files in os.walk(path):
for file in files:
@ -76,8 +126,16 @@ class DataLoaderMultiAspect():
except Exception as e:
logging.error(f"Error unzipping files {e}")
def get_all_images(self):
return self.image_caption_pairs
def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
rating_overall_sum: float = 0.0
ratings_summed: list[float] = []
for image in self.prepared_train_data:
rating_overall_sum += image.caption.rating()
ratings_summed.append(rating_overall_sum)
return rating_overall_sum, ratings_summed
@staticmethod
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
@ -97,6 +155,7 @@ class DataLoaderMultiAspect():
try:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
rating = file_content.get("rating", 1.0)
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
@ -119,7 +178,7 @@ class DataLoaderMultiAspect():
last_weight = tag_weight
return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
@ -136,9 +195,9 @@ class DataLoaderMultiAspect():
for tag in split_caption:
tags.append(tag.strip())
return ImageCaption(main_prompt, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
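For .yaml caption sidecars, the rating is simply an optional top-level key with a 1.0 fallback, mirroring the file_content.get("rating", 1.0) call above. A minimal illustration (only these keys are shown; the tag schema is omitted here):

# --- illustrative sketch, not part of the commit ---
import yaml

sidecar = """
main_prompt: a photo of a cat
rating: 1.5
"""
content = yaml.safe_load(sidecar)
print(content.get("main_prompt", ""))  # a photo of a cat
print(content.get("rating", 1.0))      # 1.5; omitting the key yields 1.0
# --- end sketch ---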
def __prescan_images(self, image_paths: list, flip_p=0.0):
def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
"""
Create ImageTrainItem objects with metadata for hydration later
"""
@ -177,42 +236,42 @@ class DataLoaderMultiAspect():
return decorated_image_train_items
def __bucketize_images(self, prepared_train_data: list, batch_size=1, debug_level=0):
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
Picks a random subset of all images
- The size of the subset is limited by dropout_fraction
- The chance of an image being picked is proportional to its rating: double the rating -> double the chance
:param dropout_fraction: must be between 0.0 and 1.0
:param picker: seeded random picker
:return: list of picked ImageTrainItem
"""
# TODO: this is not terribly efficient but at least linear time
buckets = {}
for image_caption_pair in prepared_train_data:
image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh
prepared_train_data = self.prepared_train_data.copy()
ratings_summed = self.ratings_summed.copy()
rating_overall_sum = self.rating_overall_sum
if (target_wh[0],target_wh[1]) not in buckets:
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
num_images = len(prepared_train_data)
num_images_to_pick = math.ceil(num_images * dropout_fraction)
num_images_to_pick = max(min(num_images_to_pick, num_images), 0)
if len(buckets) > 1:
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
while len(runt_bucket) < batch_size:
runt_bucket.append(random.choice(runt_bucket))
# logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}")
current_bucket_size = len(buckets[bucket])
picked_images: list[ImageTrainItem] = []
while num_images_to_pick > len(picked_images):
# find random sample in dataset
point = picker.uniform(0.0, rating_overall_sum)
pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
# pick random sample
picked_image = prepared_train_data[pos]
picked_images.append(picked_image)
# flatten the buckets
image_caption_pairs = []
for bucket in buckets:
image_caption_pairs.extend(buckets[bucket])
# remove the picked item from the dataset so it cannot be picked again
rating_overall_sum = max(rating_overall_sum - picked_image.caption.rating(), 0.0)
ratings_summed.pop(pos)
prepared_train_data.pop(pos)
return image_caption_pairs
return picked_images
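The picker operates on the rating-sorted item list and its running rating sums: a uniform draw over the total rating mass, located with bisect, lands on an image with probability proportional to its rating, and the picked entry is removed so it cannot repeat. A self-contained sketch with toy names and ratings (not the module's data):

# --- illustrative sketch, not part of the commit ---
import bisect
import math
import random

# Toy stand-ins, pre-sorted by ascending rating as the loader sorts them.
items = ["low", "mid", "high", "top"]
ratings = [0.5, 1.0, 2.0, 4.5]

# Running sums, mirroring __sort_and_precalc_image_ratings().
ratings_summed, rating_overall_sum = [], 0.0
for r in ratings:
    rating_overall_sum += r
    ratings_summed.append(rating_overall_sum)

def pick_random_subset(dropout_fraction, picker):
    data, rates = items.copy(), ratings.copy()
    summed, overall = ratings_summed.copy(), rating_overall_sum
    num_to_pick = max(min(math.ceil(len(data) * dropout_fraction), len(data)), 0)
    picked = []
    while len(picked) < num_to_pick:
        point = picker.uniform(0.0, overall)  # land somewhere in the rating mass
        pos = min(bisect.bisect_left(summed, point), len(data) - 1)
        picked.append(data.pop(pos))          # remove so it cannot repeat
        overall = max(overall - rates.pop(pos), 0.0)
        summed.pop(pos)  # later sums are not rebuilt, matching the code above
    return picked

print(pick_random_subset(0.5, random.Random(555)))  # higher ratings dominate across seeds
# --- end sketch ---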
@staticmethod
def __recurse_data_root(self, recurse_root):

View File

@ -51,6 +51,8 @@ class EveryDreamBatch(Dataset):
retain_contrast=False,
write_schedule=False,
shuffle_tags=False,
rated_dataset=False,
rated_dataset_dropout_target=0.5
):
self.data_root = data_root
self.batch_size = batch_size
@ -66,6 +68,9 @@ class EveryDreamBatch(Dataset):
self.write_schedule = write_schedule
self.shuffle_tags = shuffle_tags
self.seed = seed
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
if seed == -1:
seed = random.randint(0, 99999)
@ -81,17 +86,15 @@ class EveryDreamBatch(Dataset):
log_folder=self.log_folder,
)
self.image_train_items = dls.shared_dataloader.get_all_images()
self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
self.num_images = len(self.image_train_items)
num_images = len(self.image_train_items)
self._length = self.num_images
logging.info(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}")
logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
if self.write_schedule:
self.write_batch_schedule(0)
self.__write_batch_schedule(0)
def write_batch_schedule(self, epoch_n):
def __write_batch_schedule(self, epoch_n):
with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
for i in range(len(self.image_train_items)):
try:
@ -102,19 +105,23 @@ class EveryDreamBatch(Dataset):
def get_runts():
return dls.shared_dataloader.runts
def shuffle(self, epoch_n):
def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1
if dls.shared_dataloader:
dls.shared_dataloader.shuffle()
self.image_train_items = dls.shared_dataloader.get_all_images()
if self.rated_dataset:
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
else:
dropout_fraction = 1.0
self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
else:
raise Exception("No dataloader singleton to shuffle")
if self.write_schedule:
self.write_batch_schedule(epoch_n)
self.__write_batch_schedule(epoch_n + 1)
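For intuition, with max_epochs = 10 and the default 50% target (rated_dataset_dropout_target = 1.0 - 50/100 = 0.5, as wired up in train.py below), the kept fraction decays linearly from 1.0 and sits just above the target on the last shuffled epoch:

# --- illustrative sketch, not part of the commit ---
max_epochs, rated_dataset_dropout_target = 10, 0.5
for epoch_n in range(max_epochs):
    dropout_fraction = (max_epochs - (epoch_n * rated_dataset_dropout_target)) / max_epochs
    print(epoch_n, dropout_fraction)  # 0 -> 1.0, 5 -> 0.75, 9 -> 0.55
# --- end sketch ---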
def __len__(self):
return self._length
return len(self.image_train_items)
def __getitem__(self, i):
example = {}

View File

@ -31,7 +31,7 @@ class ImageCaption:
Represents the various parts of an image caption
"""
def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
"""
:param main_prompt: The part of the caption which should always be included
:param tags: list of tags to pick from to fill the caption
@ -40,6 +40,7 @@ class ImageCaption:
:param use_weights: if true, weights are considered when shuffling tags
"""
self.__main_prompt = main_prompt
self.__rating = rating
self.__tags = tags
self.__tag_weights = tag_weights
self.__max_target_length = max_target_length
@ -50,6 +51,9 @@ class ImageCaption:
if use_weights and len(tag_weights) > len(tags):
self.__tag_weights = tag_weights[:len(tags)]
def rating(self) -> float:
return self.__rating
def get_shuffled_caption(self, seed: int) -> str:
"""
returns the caption as a string with a random selection of the tags in random order
@ -97,15 +101,15 @@ class ImageCaption:
return ", ".join(tags)
class ImageTrainItem():
class ImageTrainItem:
"""
image: PIL.Image
identifier: caption,
target_aspect: (width, height),
pathname: path to image file
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the image, measured in comparison to the other images in the set
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
self.caption = caption
self.target_wh = target_wh
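With the new signature, the rating sits between the main prompt and the tags. A hypothetical construction, assuming ImageCaption is imported from this module (the max-length literal is an arbitrary illustrative value):

# --- illustrative sketch, not part of the commit ---
caption = ImageCaption(
    main_prompt="a photo of a cat",
    rating=1.5,             # relative rating; 1.0 is the neutral default
    tags=["fluffy", "sitting"],
    tag_weights=[1.0, 1.0],
    max_target_length=128,  # arbitrary illustrative value
    use_weights=False,
)
print(caption.rating())     # 1.5
# --- end sketch ---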

View File

@ -34,5 +34,7 @@
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": false
}
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_rate": 50
}

View File

@ -269,6 +269,11 @@ def setup_args(args):
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
os.makedirs(args.save_ckpt_dir)
if args.rated_dataset:
args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
logging.info(f"{Fore.CYAN} * Activating rated image training with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")
return args
def main(args):
@ -509,6 +514,8 @@ def main(args):
log_folder=log_folder,
write_schedule=args.write_schedule,
shuffle_tags=args.shuffle_tags,
rated_dataset=args.rated_dataset,
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
)
torch.cuda.benchmark = False
@ -642,8 +649,8 @@ def main(args):
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
# steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
# steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
epoch_times = []
@ -669,13 +676,17 @@ def main(args):
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
loss_log_step = []
try:
for epoch in range(args.max_epochs):
loss_epoch = []
epoch_start_time = time.time()
steps_pbar.reset()
images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
@ -810,13 +821,15 @@ def main(args):
global_step += 1
# end of step
steps_pbar.close()
elapsed_epoch_time = (time.time() - epoch_start_time) / 60
epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
epoch_pbar.update(1)
if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch+1)
train_batch.shuffle(epoch_n=epoch, max_epochs=args.max_epochs)
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
@ -912,6 +925,8 @@ if __name__ == "__main__":
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
args = argparser.parse_args()
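With the new flags in place, a rated run might be launched as python train.py --rated_dataset --rated_dataset_target_dropout_percent 50, alongside the usual training arguments (data root, resolution, and so on), which are omitted here.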