Merge pull request #14 from JanGerritsen/rated_dataset
Implemented a system to train on a subset of the dataset each epoch, favoring higher-rated images
This commit is contained in:
commit 6ba710d6f1
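At its core, the PR gives every image a rating (read from its .yaml caption sidecar, defaulting to 1.0), keeps the images sorted by rating with a running prefix sum, and draws a shrinking, rating-weighted subset each epoch. A condensed, self-contained sketch of that selection scheme (hypothetical standalone names, not the PR's classes):

```python
import bisect
import math
import random

def pick_rated_subset(items: list, ratings: list, fraction: float, rng: random.Random) -> list:
    """Pick ~fraction of items without replacement; chance scales with rating."""
    items, ratings = list(items), list(ratings)
    ratings_summed, total = [], 0.0
    for r in ratings:                       # prefix sums over the rating mass
        total += r
        ratings_summed.append(total)

    picked = []
    for _ in range(math.ceil(len(items) * fraction)):
        point = rng.uniform(0.0, total)     # uniform point on [0, total]
        pos = min(bisect.bisect_left(ratings_summed, point), len(items) - 1)
        picked.append(items.pop(pos))
        # remove the picked item's mass; as in the PR, the remaining prefix
        # sums are not rebuilt, so later draws are only approximately weighted
        total = max(total - ratings.pop(pos), 0.0)
        ratings_summed.pop(pos)
    return picked

print(pick_rated_subset(["a", "b", "c", "d"], [1.0, 1.0, 2.0, 4.0], 0.5, random.Random(555)))
```

The same three ingredients (prefix sums, a seeded uniform draw, bisect) appear in `__sort_and_precalc_image_ratings` and `__pick_random_subset` in the diff below.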
@@ -13,7 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import bisect
 import math
 import os
 import logging
 
@@ -46,7 +47,6 @@ class DataLoaderMultiAspect():
         self.log_folder = log_folder
         self.seed = seed
         self.batch_size = batch_size
-        self.runts = []
 
         self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
         logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
@@ -56,16 +56,66 @@ class DataLoaderMultiAspect():
 
         self.__recurse_data_root(self=self, recurse_root=data_root)
         random.Random(seed).shuffle(self.image_paths)
-        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
-        self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
+        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
+        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
 
-    def shuffle(self):
-        self.runts = []
-        self.seed = self.seed + 1
-        random.Random(self.seed).shuffle(self.prepared_train_data)
-        self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=self.batch_size, debug_level=0)
+    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
+        """
+        returns the current list of images including their captions in a randomized order,
+        sorted into buckets with same sized images
+        if dropout_fraction < 1.0, only a subset of the images will be returned
+        :param dropout_fraction: must be between 0.0 and 1.0.
+        :return: randomized list of (image, caption) pairs, sorted into same sized buckets
+        """
+        """
+        Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
+        """
+        # TODO: this is not terribly efficient but at least linear time
 
-    def unzip_all(self, path):
+        self.seed += 1
+        randomizer = random.Random(self.seed)
+
+        if dropout_fraction < 1.0:
+            picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
+        else:
+            picked_images = self.prepared_train_data
+
+        randomizer.shuffle(picked_images)
+
+        buckets = {}
+        batch_size = self.batch_size
+        for image_caption_pair in picked_images:
+            image_caption_pair.runt_size = 0
+            target_wh = image_caption_pair.target_wh
+
+            if (target_wh[0], target_wh[1]) not in buckets:
+                buckets[(target_wh[0], target_wh[1])] = []
+            buckets[(target_wh[0], target_wh[1])].append(image_caption_pair)
+
+        if len(buckets) > 1:
+            for bucket in buckets:
+                truncate_count = len(buckets[bucket]) % batch_size
+                if truncate_count > 0:
+                    runt_bucket = buckets[bucket][-truncate_count:]
+                    for item in runt_bucket:
+                        item.runt_size = truncate_count
+                    while len(runt_bucket) < batch_size:
+                        runt_bucket.append(random.choice(runt_bucket))
+
+                    current_bucket_size = len(buckets[bucket])
+
+                    buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
+                    buckets[bucket].extend(runt_bucket)
+
+        # flatten the buckets
+        image_caption_pairs = []
+        for bucket in buckets:
+            image_caption_pairs.extend(buckets[bucket])
+
+        return image_caption_pairs
+
+    @staticmethod
+    def unzip_all(path):
         try:
             for root, dirs, files in os.walk(path):
                 for file in files:
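As an aside, the runt handling retained above pads a bucket's leftover images ("runts") up to a full batch by duplicating them instead of discarding them. A worked example with hypothetical numbers:

```python
import random

batch_size = 4
bucket = list(range(10))                      # 10 same-aspect images
truncate_count = len(bucket) % batch_size     # 2 leftover "runt" images
runt_bucket = bucket[-truncate_count:]        # [8, 9], each marked runt_size=2
while len(runt_bucket) < batch_size:          # pad by repeating random runts
    runt_bucket.append(random.choice(runt_bucket))
bucket = bucket[:len(bucket) - truncate_count] + runt_bucket
print(len(bucket))                            # 12 -> three full batches
```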
@@ -76,8 +126,16 @@ class DataLoaderMultiAspect():
         except Exception as e:
             logging.error(f"Error unzipping files {e}")
 
-    def get_all_images(self):
-        return self.image_caption_pairs
+    def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
+        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
+
+        rating_overall_sum: float = 0.0
+        ratings_summed: list[float] = []
+        for image in self.prepared_train_data:
+            rating_overall_sum += image.caption.rating()
+            ratings_summed.append(rating_overall_sum)
+
+        return rating_overall_sum, ratings_summed
 
     @staticmethod
     def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
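The precalculated list makes the weighted draw cheap: images are sorted ascending by rating, the ratings are accumulated, and a uniform point in [0, total) then selects an image via binary search. With three images rated 1.0, 1.0 and 2.0 (hypothetical values):

```python
import bisect

ratings_summed = [1.0, 2.0, 4.0]   # prefix sums of ratings [1.0, 1.0, 2.0]
# points in [0, 1) -> image 0, [1, 2) -> image 1, [2, 4) -> image 2,
# so the image rated 2.0 is picked twice as often as either other image
print(bisect.bisect_left(ratings_summed, 2.5))   # -> 2
```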
@@ -97,6 +155,7 @@ class DataLoaderMultiAspect():
         try:
             file_content = yaml.safe_load(stream)
             main_prompt = file_content.get("main_prompt", "")
+            rating = file_content.get("rating", 1.0)
             unparsed_tags = file_content.get("tags", [])
 
             max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
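So a caption sidecar only needs an optional rating key to participate. A minimal parsing sketch, assuming nothing beyond the `file_content.get(...)` keys visible in this hunk (real sidecars may also carry weighted tags and max_caption_length):

```python
import yaml

file_content = yaml.safe_load("""
main_prompt: a photo of a cat
rating: 2.0
""")
print(file_content.get("main_prompt", ""))   # a photo of a cat
print(file_content.get("rating", 1.0))       # 2.0
print(file_content.get("tags", []))          # [] when absent
```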
@@ -119,7 +178,7 @@ class DataLoaderMultiAspect():
 
                 last_weight = tag_weight
 
-            return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)
+            return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
 
         except:
             logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
@@ -136,9 +195,9 @@ class DataLoaderMultiAspect():
         for tag in split_caption:
             tags.append(tag.strip())
 
-        return ImageCaption(main_prompt, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
+        return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
 
-    def __prescan_images(self, image_paths: list, flip_p=0.0):
+    def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
         """
         Create ImageTrainItem objects with metadata for hydration later
         """
@@ -177,42 +236,42 @@ class DataLoaderMultiAspect():
 
         return decorated_image_train_items
 
-    def __bucketize_images(self, prepared_train_data: list, batch_size=1, debug_level=0):
+    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
         """
-        Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
+        Picks a random subset of all images
+        - The size of the subset is limited by dropout_fraction
+        - The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance
+        :param dropout_fraction: must be between 0.0 and 1.0
+        :param picker: seeded random picker
+        :return: list of picked ImageTrainItem
        """
-        # TODO: this is not terribly efficient but at least linear time
-        buckets = {}
+        prepared_train_data = self.prepared_train_data.copy()
+        ratings_summed = self.ratings_summed.copy()
+        rating_overall_sum = self.rating_overall_sum
 
-        for image_caption_pair in prepared_train_data:
-            image_caption_pair.runt_size = 0
-            target_wh = image_caption_pair.target_wh
+        num_images = len(prepared_train_data)
+        num_images_to_pick = math.ceil(num_images * dropout_fraction)
+        num_images_to_pick = max(min(num_images_to_pick, num_images), 0)
 
-            if (target_wh[0], target_wh[1]) not in buckets:
-                buckets[(target_wh[0], target_wh[1])] = []
-            buckets[(target_wh[0], target_wh[1])].append(image_caption_pair)
+        # logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}")
 
-        if len(buckets) > 1:
-            for bucket in buckets:
-                truncate_count = len(buckets[bucket]) % batch_size
-                if truncate_count > 0:
-                    runt_bucket = buckets[bucket][-truncate_count:]
-                    for item in runt_bucket:
-                        item.runt_size = truncate_count
-                    while len(runt_bucket) < batch_size:
-                        runt_bucket.append(random.choice(runt_bucket))
+        picked_images: list[ImageTrainItem] = []
+        while num_images_to_pick > len(picked_images):
+            # find random sample in dataset
+            point = picker.uniform(0.0, rating_overall_sum)
+            pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)
 
-                    current_bucket_size = len(buckets[bucket])
-
-                    buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
-                    buckets[bucket].extend(runt_bucket)
+            # pick random sample
+            picked_image = prepared_train_data[pos]
+            picked_images.append(picked_image)
 
-        # flatten the buckets
-        image_caption_pairs = []
-        for bucket in buckets:
-            image_caption_pairs.extend(buckets[bucket])
+            # kick picked item out of data set to not pick it again
+            rating_overall_sum = max(rating_overall_sum - picked_image.caption.rating(), 0.0)
+            ratings_summed.pop(pos)
+            prepared_train_data.pop(pos)
 
-        return image_caption_pairs
+        return picked_images
 
     @staticmethod
     def __recurse_data_root(self, recurse_root):
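A quick sanity check of the claimed bias ("double that rating -> double the chance"), simulating single draws over a hypothetical two-image dataset:

```python
import bisect
import random

rng = random.Random(0)
ratings_summed, total = [1.0, 3.0], 3.0   # ratings: image 0 -> 1.0, image 1 -> 2.0
counts = [0, 0]
for _ in range(30000):
    point = rng.uniform(0.0, total)
    counts[min(bisect.bisect_left(ratings_summed, point), 1)] += 1
print(counts)   # ~[10000, 20000]: twice the rating, twice the draws
```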
@@ -51,6 +51,8 @@ class EveryDreamBatch(Dataset):
                  retain_contrast=False,
                  write_schedule=False,
                  shuffle_tags=False,
+                 rated_dataset=False,
+                 rated_dataset_dropout_target=0.5
                  ):
         self.data_root = data_root
         self.batch_size = batch_size
@@ -66,6 +68,9 @@ class EveryDreamBatch(Dataset):
         self.write_schedule = write_schedule
         self.shuffle_tags = shuffle_tags
         self.seed = seed
+        self.rated_dataset = rated_dataset
+        self.rated_dataset_dropout_target = rated_dataset_dropout_target
+
 
         if seed == -1:
             seed = random.randint(0, 99999)
@@ -81,17 +86,15 @@ class EveryDreamBatch(Dataset):
             log_folder=self.log_folder,
         )
 
-        self.image_train_items = dls.shared_dataloader.get_all_images()
+        self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
 
-        self.num_images = len(self.image_train_items)
-
-        self._length = self.num_images
+        num_images = len(self.image_train_items)
 
-        logging.info(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}")
+        logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
         if self.write_schedule:
-            self.write_batch_schedule(0)
+            self.__write_batch_schedule(0)
 
-    def write_batch_schedule(self, epoch_n):
+    def __write_batch_schedule(self, epoch_n):
         with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
             for i in range(len(self.image_train_items)):
                 try:
@@ -102,19 +105,23 @@ class EveryDreamBatch(Dataset):
     def get_runts():
         return dls.shared_dataloader.runts
 
-    def shuffle(self, epoch_n):
+    def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1
         if dls.shared_dataloader:
-            dls.shared_dataloader.shuffle()
-            self.image_train_items = dls.shared_dataloader.get_all_images()
+            if self.rated_dataset:
+                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
+            else:
+                dropout_fraction = 1.0
+
+            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
         else:
            raise Exception("No dataloader singleton to shuffle")
 
         if self.write_schedule:
-            self.write_batch_schedule(epoch_n)
+            self.__write_batch_schedule(epoch_n + 1)
 
     def __len__(self):
-        return self._length
+        return len(self.image_train_items)
 
     def __getitem__(self, i):
         example = {}
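For intuition, the schedule above decays linearly from the full dataset toward the configured target. With max_epochs=10 and a dropout target of 0.5 (i.e. --rated_dataset_target_dropout_percent 50), using assumed values for illustration:

```python
max_epochs = 10
rated_dataset_dropout_target = 0.5   # 1.0 - (50 / 100.0)
for epoch_n in range(max_epochs):
    dropout_fraction = (max_epochs - (epoch_n * rated_dataset_dropout_target)) / max_epochs
    print(epoch_n, round(dropout_fraction, 2))   # 1.0, 0.95, 0.9, ..., 0.55
```

The first epoch always trains on the full set regardless, since `__init__` calls `get_shuffled_image_buckets(1.0)`.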
@@ -31,7 +31,7 @@ class ImageCaption:
     Represents the various parts of an image caption
     """
 
-    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
+    def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
         """
         :param main_prompt: The part of the caption which should always be included
         :param tags: list of tags to pick from to fill the caption
@@ -40,6 +40,7 @@ class ImageCaption:
         :param use_weights: if true, weights are considered when shuffling tags
         """
         self.__main_prompt = main_prompt
+        self.__rating = rating
         self.__tags = tags
         self.__tag_weights = tag_weights
         self.__max_target_length = max_target_length
@@ -50,6 +51,9 @@ class ImageCaption:
         if use_weights and len(tag_weights) > len(tags):
             self.__tag_weights = tag_weights[:len(tags)]
 
+    def rating(self) -> float:
+        return self.__rating
+
     def get_shuffled_caption(self, seed: int) -> str:
         """
         returns the caption as a string with a random selection of the tags in random order
@@ -97,15 +101,15 @@ class ImageCaption:
         return ", ".join(tags)
 
-
-class ImageTrainItem():
+class ImageTrainItem:
     """
     image: PIL.Image
     identifier: caption,
     target_aspect: (width, height),
     pathname: path to image file
     flip_p: probability of flipping image (0.0 to 1.0)
+    rating: the relative rating of the images. The rating is measured in comparison to the other images.
     """
 
     def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
         self.caption = caption
         self.target_wh = target_wh
@@ -34,5 +34,7 @@
     "shuffle_tags": false,
     "useadam8bit": true,
     "wandb": false,
-    "write_schedule": false
-}
+    "write_schedule": false,
+    "rated_dataset": false,
+    "rated_dataset_target_dropout_rate": 50
+}
train.py (23 changed lines)
@@ -269,6 +269,11 @@ def setup_args(args):
     if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
         os.makedirs(args.save_ckpt_dir)
 
+    if args.rated_dataset:
+        args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
+
+        logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")
+
     return args
 
 def main(args):
@@ -509,6 +514,8 @@ def main(args):
         log_folder=log_folder,
         write_schedule=args.write_schedule,
         shuffle_tags=args.shuffle_tags,
+        rated_dataset=args.rated_dataset,
+        rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
     )
 
     torch.cuda.benchmark = False
@@ -642,8 +649,8 @@ def main(args):
     epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
 
-    steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
-    steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
+    # steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
+    # steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
 
     epoch_times = []
@@ -669,13 +676,17 @@ def main(args):
     logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
 
     loss_log_step = []
 
     try:
         for epoch in range(args.max_epochs):
             loss_epoch = []
             epoch_start_time = time.time()
-            steps_pbar.reset()
             images_per_sec_log_step = []
 
+            epoch_len = math.ceil(len(train_batch) / args.batch_size)
+            steps_pbar = tqdm(range(epoch_len), position=1)
+            steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
+
             for step, batch in enumerate(train_dataloader):
                 step_start_time = time.time()
@@ -810,13 +821,15 @@ def main(args):
                 global_step += 1
                 # end of step
 
+            steps_pbar.close()
+
             elapsed_epoch_time = (time.time() - epoch_start_time) / 60
             epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
             log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
 
             epoch_pbar.update(1)
             if epoch < args.max_epochs - 1:
-                train_batch.shuffle(epoch_n=epoch+1)
+                train_batch.shuffle(epoch_n=epoch, max_epochs=args.max_epochs)
 
             loss_local = sum(loss_epoch) / len(loss_epoch)
             log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
@@ -912,6 +925,8 @@ if __name__ == "__main__":
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
     argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
     argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
+    argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
+    argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
 
     args = argparser.parse_args()
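Usage note: the feature is opt-in. Pass `--rated_dataset` (and optionally `--rated_dataset_target_dropout_percent`, default 50) alongside the usual training arguments, and give images a `rating:` entry in their .yaml caption files; any image without one defaults to a rating of 1.0.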