fix multiplier issues with validation and refactor validation logic
This commit is contained in:
parent
1068d2dd2a
commit
4e37200dda
|
@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
import bisect
|
||||
import logging
|
||||
from functools import reduce
|
||||
import math
|
||||
import copy
|
||||
|
@ -41,40 +42,27 @@ class DataLoaderMultiAspect():
|
|||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
self.__update_rating_sums()
|
||||
count_including_multipliers = sum([math.floor(max(i.multiplier, 1)) for i in self.prepared_train_data])
|
||||
if count_including_multipliers > len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} items ({count_including_multipliers} items total after applying multipliers)")
|
||||
else:
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} items")
|
||||
|
||||
|
||||
def __pick_multiplied_set(self, randomizer):
|
||||
"""
|
||||
Deals with multiply.txt whole and fractional numbers
|
||||
"""
|
||||
#print(f"Picking multiplied set from {len(self.prepared_train_data)}")
|
||||
data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
|
||||
epoch_size = len(self.prepared_train_data)
|
||||
picked_images = []
|
||||
|
||||
# add by whole number part first and decrement multiplier in copy
|
||||
for iti in data_copy:
|
||||
#print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
|
||||
while iti.multiplier >= 1.0:
|
||||
for iti in self.prepared_train_data:
|
||||
multiplier = iti.multiplier
|
||||
while multiplier >= 1:
|
||||
picked_images.append(iti)
|
||||
multiplier -= 1
|
||||
# deal with fractional remainder
|
||||
if multiplier > randomizer.uniform(0, 1):
|
||||
picked_images.append(iti)
|
||||
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
|
||||
iti.multiplier -= 1.0
|
||||
|
||||
remaining = epoch_size - len(picked_images)
|
||||
|
||||
assert remaining >= 0, "Something went wrong with the multiplier calculation"
|
||||
|
||||
# add by remaining fractional numbers by random chance
|
||||
while remaining > 0:
|
||||
for iti in data_copy:
|
||||
if randomizer.uniform(0.0, 1.0) < iti.multiplier:
|
||||
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
|
||||
picked_images.append(iti)
|
||||
remaining -= 1
|
||||
iti.multiplier = 0.0
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
del data_copy
|
||||
return picked_images
|
||||
|
||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
|
||||
|
|
|
@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset):
|
|||
num_images = len(self.image_train_items)
|
||||
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
|
||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
||||
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
|
||||
self.__update_image_train_items(1.0)
|
||||
return items
|
||||
|
||||
|
||||
def shuffle(self, epoch_n: int, max_epochs: int):
|
||||
self.seed += 1
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Callable, Any, Optional, Generator
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
|
@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch
|
|||
remaining_items = list(items_copy[split_item_count:])
|
||||
return split_items, remaining_items
|
||||
|
||||
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
|
||||
for i in items:
|
||||
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
|
||||
|
||||
class EveryDreamValidator:
|
||||
def __init__(self,
|
||||
val_config_path: Optional[str],
|
||||
default_batch_size: int,
|
||||
resolution: int,
|
||||
log_writer: SummaryWriter):
|
||||
self.val_dataloader = None
|
||||
self.train_overlapping_dataloader = None
|
||||
|
||||
self.log_writer = log_writer
|
||||
self.resolution = resolution
|
||||
|
||||
self.config = {
|
||||
'batch_size': default_batch_size,
|
||||
'every_n_epochs': 1,
|
||||
'seed': 555,
|
||||
|
||||
'validate_training': True,
|
||||
'val_split_mode': 'automatic',
|
||||
'val_split_proportion': 0.15,
|
||||
|
||||
|
@ -120,21 +128,24 @@ class EveryDreamValidator:
|
|||
|
||||
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||
val_split_mode = self.config['val_split_mode']
|
||||
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
|
||||
val_split_proportion = self.config['val_split_proportion']
|
||||
remaining_train_items = image_train_items
|
||||
if val_split_mode == 'none':
|
||||
if val_split_mode is None or val_split_mode == 'none':
|
||||
return None, image_train_items
|
||||
elif val_split_mode == 'automatic':
|
||||
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
||||
val_items = list(disable_multiplier_and_flip(val_items))
|
||||
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
|
||||
elif val_split_mode == 'manual':
|
||||
args = Namespace(
|
||||
aspects=aspects.get_aspect_buckets(512),
|
||||
aspects=aspects.get_aspect_buckets(self.resolution),
|
||||
flip_p=0.0,
|
||||
seed=self.seed,
|
||||
)
|
||||
val_data_root = self.config['val_data_root']
|
||||
val_items = resolver.resolve_root(val_data_root, args)
|
||||
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
|
||||
else:
|
||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||
|
@ -149,6 +160,7 @@ class EveryDreamValidator:
|
|||
|
||||
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
||||
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
||||
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
|
||||
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
||||
name='stabilize-train')
|
||||
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||
|
|
|
@ -263,7 +263,7 @@ class ImageTrainItem:
|
|||
self.multiplier = multiplier
|
||||
|
||||
self.image_size = None
|
||||
if image is None:
|
||||
if image is None or len(image) == 0:
|
||||
self.image = []
|
||||
else:
|
||||
self.image = image
|
||||
|
|
|
@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver):
|
|||
with open(multiply_txt_path, 'r') as f:
|
||||
val = float(f.read().strip())
|
||||
multipliers[current_dir] = val
|
||||
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
|
||||
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
|
||||
except Exception as e:
|
||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||
multipliers[current_dir] = 1.0
|
||||
|
@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver):
|
|||
|
||||
caption = ImageCaption.resolve(pathname)
|
||||
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
||||
|
||||
cur_file_multiplier = multipliers[current_dir]
|
||||
items.append(item)
|
||||
|
||||
while cur_file_multiplier >= 1.0:
|
||||
items.append(item)
|
||||
cur_file_multiplier -= 1
|
||||
|
||||
if cur_file_multiplier > 0:
|
||||
if randomizer.random() < cur_file_multiplier:
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
|
|
17
train.py
17
train.py
|
@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
|
|||
from data.image_train_item import ImageTrainItem
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.gpu import GPU
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
import data.aspects as aspects
|
||||
import data.resolver as resolver
|
||||
|
||||
|
@ -326,8 +327,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
|
|||
|
||||
# Remove erroneous items
|
||||
image_train_items = [item for item in resolved_items if item.error is None]
|
||||
|
||||
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
|
||||
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
|
||||
|
||||
return image_train_items
|
||||
|
||||
|
@ -620,9 +620,13 @@ def main(args):
|
|||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||
validator = EveryDreamValidator(args.validation_config,
|
||||
default_batch_size=args.batch_size,
|
||||
resolution=args.resolution,
|
||||
log_writer=log_writer,
|
||||
)
|
||||
# the validation dataset may need to steal some items from image_train_items
|
||||
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
|
@ -940,7 +944,7 @@ def main(args):
|
|||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
# validate
|
||||
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
@ -1021,6 +1025,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
|
|
Loading…
Reference in New Issue