Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer into main
This commit is contained in:
commit
f275f20412
|
@ -101,7 +101,7 @@
|
||||||
"!pip install -q protobuf==3.20.1\n",
|
"!pip install -q protobuf==3.20.1\n",
|
||||||
"!pip install -q wandb==0.13.6\n",
|
"!pip install -q wandb==0.13.6\n",
|
||||||
"!pip install -q pyre-extensions==0.0.23\n",
|
"!pip install -q pyre-extensions==0.0.23\n",
|
||||||
"!pip install -q xformers==0.0.17.dev435\n",
|
"!pip install -q xformers==0.0.16\n",
|
||||||
"!pip install -q pytorch-lightning==1.6.5\n",
|
"!pip install -q pytorch-lightning==1.6.5\n",
|
||||||
"!pip install -q OmegaConf==2.2.3\n",
|
"!pip install -q OmegaConf==2.2.3\n",
|
||||||
"!pip install -q numpy==1.23.5\n",
|
"!pip install -q numpy==1.23.5\n",
|
||||||
|
|
|
@ -14,13 +14,14 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
import bisect
|
import bisect
|
||||||
from functools import reduce
|
import logging
|
||||||
|
import os.path
|
||||||
|
from collections import defaultdict
|
||||||
import math
|
import math
|
||||||
import copy
|
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
from data.image_train_item import ImageTrainItem
|
||||||
import PIL
|
import PIL.Image
|
||||||
|
|
||||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||||
|
|
||||||
|
@ -38,43 +39,40 @@ class DataLoaderMultiAspect():
|
||||||
self.prepared_train_data = image_train_items
|
self.prepared_train_data = image_train_items
|
||||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||||
|
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||||
|
if expected_epoch_size != len(self.prepared_train_data):
|
||||||
|
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
|
||||||
|
else:
|
||||||
|
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
|
||||||
|
|
||||||
self.rating_overall_sum: float = 0.0
|
self.rating_overall_sum: float = 0.0
|
||||||
self.ratings_summed: list[float] = []
|
self.ratings_summed: list[float] = []
|
||||||
self.__update_rating_sums()
|
self.__update_rating_sums()
|
||||||
|
|
||||||
def __pick_multiplied_set(self, randomizer):
|
|
||||||
|
def __pick_multiplied_set(self, randomizer: random.Random):
|
||||||
"""
|
"""
|
||||||
Deals with multiply.txt whole and fractional numbers
|
Deals with multiply.txt whole and fractional numbers
|
||||||
"""
|
"""
|
||||||
#print(f"Picking multiplied set from {len(self.prepared_train_data)}")
|
|
||||||
data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
|
|
||||||
epoch_size = len(self.prepared_train_data)
|
|
||||||
picked_images = []
|
picked_images = []
|
||||||
|
fractional_images_per_directory = defaultdict(list[ImageTrainItem])
|
||||||
# add by whole number part first and decrement multiplier in copy
|
for iti in self.prepared_train_data:
|
||||||
for iti in data_copy:
|
multiplier = iti.multiplier
|
||||||
#print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
|
while multiplier >= 1:
|
||||||
while iti.multiplier >= 1.0:
|
|
||||||
picked_images.append(iti)
|
picked_images.append(iti)
|
||||||
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
|
multiplier -= 1
|
||||||
iti.multiplier -= 1.0
|
# fractional remainders must be dealt with separately
|
||||||
|
if multiplier > 0:
|
||||||
|
directory = os.path.dirname(iti.pathname)
|
||||||
|
fractional_images_per_directory[directory].append(iti)
|
||||||
|
|
||||||
remaining = epoch_size - len(picked_images)
|
# resolve fractional parts per-directory
|
||||||
|
for _, fractional_items in fractional_images_per_directory.items():
|
||||||
|
randomizer.shuffle(fractional_items)
|
||||||
|
multiplier = fractional_items[0].multiplier % 1.0
|
||||||
|
count_to_take = math.ceil(multiplier * len(fractional_items))
|
||||||
|
picked_images.extend(fractional_items[:count_to_take])
|
||||||
|
|
||||||
assert remaining >= 0, "Something went wrong with the multiplier calculation"
|
|
||||||
|
|
||||||
# add by remaining fractional numbers by random chance
|
|
||||||
while remaining > 0:
|
|
||||||
for iti in data_copy:
|
|
||||||
if randomizer.uniform(0.0, 1.0) < iti.multiplier:
|
|
||||||
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
|
|
||||||
picked_images.append(iti)
|
|
||||||
remaining -= 1
|
|
||||||
iti.multiplier = 0.0
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
del data_copy
|
|
||||||
return picked_images
|
return picked_images
|
||||||
|
|
||||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
|
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
|
||||||
|
@ -110,20 +108,19 @@ class DataLoaderMultiAspect():
|
||||||
buckets[(target_wh[0],target_wh[1])] = []
|
buckets[(target_wh[0],target_wh[1])] = []
|
||||||
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
||||||
|
|
||||||
if len(buckets) > 1:
|
for bucket in buckets:
|
||||||
for bucket in buckets:
|
truncate_count = len(buckets[bucket]) % batch_size
|
||||||
truncate_count = len(buckets[bucket]) % batch_size
|
if truncate_count > 0:
|
||||||
if truncate_count > 0:
|
runt_bucket = buckets[bucket][-truncate_count:]
|
||||||
runt_bucket = buckets[bucket][-truncate_count:]
|
for item in runt_bucket:
|
||||||
for item in runt_bucket:
|
item.runt_size = truncate_count
|
||||||
item.runt_size = truncate_count
|
while len(runt_bucket) < batch_size:
|
||||||
while len(runt_bucket) < batch_size:
|
runt_bucket.append(random.choice(runt_bucket))
|
||||||
runt_bucket.append(random.choice(runt_bucket))
|
|
||||||
|
|
||||||
current_bucket_size = len(buckets[bucket])
|
current_bucket_size = len(buckets[bucket])
|
||||||
|
|
||||||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||||
buckets[bucket].extend(runt_bucket)
|
buckets[bucket].extend(runt_bucket)
|
||||||
|
|
||||||
# flatten the buckets
|
# flatten the buckets
|
||||||
items: list[ImageTrainItem] = []
|
items: list[ImageTrainItem] = []
|
||||||
|
|
|
@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset):
|
||||||
num_images = len(self.image_train_items)
|
num_images = len(self.image_train_items)
|
||||||
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
|
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||||
|
|
||||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
|
||||||
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
|
|
||||||
self.__update_image_train_items(1.0)
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def shuffle(self, epoch_n: int, max_epochs: int):
|
def shuffle(self, epoch_n: int, max_epochs: int):
|
||||||
self.seed += 1
|
self.seed += 1
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Callable, Any, Optional
|
from typing import Callable, Any, Optional, Generator
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch
|
||||||
remaining_items = list(items_copy[split_item_count:])
|
remaining_items = list(items_copy[split_item_count:])
|
||||||
return split_items, remaining_items
|
return split_items, remaining_items
|
||||||
|
|
||||||
|
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
|
||||||
|
for i in items:
|
||||||
|
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
|
||||||
|
|
||||||
class EveryDreamValidator:
|
class EveryDreamValidator:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
val_config_path: Optional[str],
|
val_config_path: Optional[str],
|
||||||
default_batch_size: int,
|
default_batch_size: int,
|
||||||
|
resolution: int,
|
||||||
log_writer: SummaryWriter):
|
log_writer: SummaryWriter):
|
||||||
self.val_dataloader = None
|
self.val_dataloader = None
|
||||||
self.train_overlapping_dataloader = None
|
self.train_overlapping_dataloader = None
|
||||||
|
|
||||||
self.log_writer = log_writer
|
self.log_writer = log_writer
|
||||||
|
self.resolution = resolution
|
||||||
|
|
||||||
self.config = {
|
self.config = {
|
||||||
'batch_size': default_batch_size,
|
'batch_size': default_batch_size,
|
||||||
'every_n_epochs': 1,
|
'every_n_epochs': 1,
|
||||||
'seed': 555,
|
'seed': 555,
|
||||||
|
|
||||||
|
'validate_training': True,
|
||||||
'val_split_mode': 'automatic',
|
'val_split_mode': 'automatic',
|
||||||
'val_split_proportion': 0.15,
|
'val_split_proportion': 0.15,
|
||||||
|
|
||||||
|
@ -120,21 +128,24 @@ class EveryDreamValidator:
|
||||||
|
|
||||||
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||||
val_split_mode = self.config['val_split_mode']
|
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
|
||||||
val_split_proportion = self.config['val_split_proportion']
|
val_split_proportion = self.config['val_split_proportion']
|
||||||
remaining_train_items = image_train_items
|
remaining_train_items = image_train_items
|
||||||
if val_split_mode == 'none':
|
if val_split_mode is None or val_split_mode == 'none':
|
||||||
return None, image_train_items
|
return None, image_train_items
|
||||||
elif val_split_mode == 'automatic':
|
elif val_split_mode == 'automatic':
|
||||||
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
||||||
|
val_items = list(disable_multiplier_and_flip(val_items))
|
||||||
|
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
|
||||||
elif val_split_mode == 'manual':
|
elif val_split_mode == 'manual':
|
||||||
args = Namespace(
|
args = Namespace(
|
||||||
aspects=aspects.get_aspect_buckets(512),
|
aspects=aspects.get_aspect_buckets(self.resolution),
|
||||||
flip_p=0.0,
|
flip_p=0.0,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
)
|
)
|
||||||
val_data_root = self.config['val_data_root']
|
val_data_root = self.config['val_data_root']
|
||||||
val_items = resolver.resolve_root(val_data_root, args)
|
val_items = resolver.resolve_root(val_data_root, args)
|
||||||
|
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||||
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||||
|
@ -149,6 +160,7 @@ class EveryDreamValidator:
|
||||||
|
|
||||||
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
||||||
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
||||||
|
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
|
||||||
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
||||||
name='stabilize-train')
|
name='stabilize-train')
|
||||||
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||||
|
|
|
@ -263,7 +263,7 @@ class ImageTrainItem:
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
|
|
||||||
self.image_size = None
|
self.image_size = None
|
||||||
if image is None:
|
if image is None or len(image) == 0:
|
||||||
self.image = []
|
self.image = []
|
||||||
else:
|
else:
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
|
@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver):
|
||||||
with open(multiply_txt_path, 'r') as f:
|
with open(multiply_txt_path, 'r') as f:
|
||||||
val = float(f.read().strip())
|
val = float(f.read().strip())
|
||||||
multipliers[current_dir] = val
|
multipliers[current_dir] = val
|
||||||
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
|
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||||
multipliers[current_dir] = 1.0
|
multipliers[current_dir] = 1.0
|
||||||
|
@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver):
|
||||||
|
|
||||||
caption = ImageCaption.resolve(pathname)
|
caption = ImageCaption.resolve(pathname)
|
||||||
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
||||||
|
items.append(item)
|
||||||
cur_file_multiplier = multipliers[current_dir]
|
|
||||||
|
|
||||||
while cur_file_multiplier >= 1.0:
|
|
||||||
items.append(item)
|
|
||||||
cur_file_multiplier -= 1
|
|
||||||
|
|
||||||
if cur_file_multiplier > 0:
|
|
||||||
if randomizer.random() < cur_file_multiplier:
|
|
||||||
items.append(item)
|
|
||||||
return items
|
return items
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
51
train.py
51
train.py
|
@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator
|
||||||
from data.image_train_item import ImageTrainItem
|
from data.image_train_item import ImageTrainItem
|
||||||
from utils.huggingface_downloader import try_download_model_from_hf
|
from utils.huggingface_downloader import try_download_model_from_hf
|
||||||
from utils.convert_diff_to_ckpt import convert as converter
|
from utils.convert_diff_to_ckpt import convert as converter
|
||||||
from utils.gpu import GPU
|
if torch.cuda.is_available():
|
||||||
|
from utils.gpu import GPU
|
||||||
import data.aspects as aspects
|
import data.aspects as aspects
|
||||||
import data.resolver as resolver
|
import data.resolver as resolver
|
||||||
|
|
||||||
|
@ -159,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
||||||
"""
|
"""
|
||||||
updates the vram usage for the epoch
|
updates the vram usage for the epoch
|
||||||
"""
|
"""
|
||||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
if gpu is not None:
|
||||||
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
|
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||||
epoch_mem_color = Style.RESET_ALL
|
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
|
||||||
if gpu_used_mem > 0.93 * gpu_total_mem:
|
epoch_mem_color = Style.RESET_ALL
|
||||||
epoch_mem_color = Fore.LIGHTRED_EX
|
if gpu_used_mem > 0.93 * gpu_total_mem:
|
||||||
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
epoch_mem_color = Fore.LIGHTRED_EX
|
||||||
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
||||||
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
||||||
epoch_mem_color = Fore.LIGHTGREEN_EX
|
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
||||||
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
epoch_mem_color = Fore.LIGHTGREEN_EX
|
||||||
epoch_mem_color = Fore.LIGHTBLUE_EX
|
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
||||||
|
epoch_mem_color = Fore.LIGHTBLUE_EX
|
||||||
|
|
||||||
if logs is not None:
|
if logs is not None:
|
||||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||||
|
|
||||||
|
|
||||||
def set_args_12gb(args):
|
def set_args_12gb(args):
|
||||||
|
@ -326,8 +328,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list
|
||||||
|
|
||||||
# Remove erroneous items
|
# Remove erroneous items
|
||||||
image_train_items = [item for item in resolved_items if item.error is None]
|
image_train_items = [item for item in resolved_items if item.error is None]
|
||||||
|
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
|
||||||
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
|
|
||||||
|
|
||||||
return image_train_items
|
return image_train_items
|
||||||
|
|
||||||
|
@ -372,6 +373,7 @@ def main(args):
|
||||||
else:
|
else:
|
||||||
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||||
device = 'cpu'
|
device = 'cpu'
|
||||||
|
gpu = None
|
||||||
|
|
||||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||||
|
|
||||||
|
@ -548,6 +550,7 @@ def main(args):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logging.error(" * Failed to load checkpoint *")
|
logging.error(" * Failed to load checkpoint *")
|
||||||
|
raise
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
@ -620,9 +623,13 @@ def main(args):
|
||||||
|
|
||||||
image_train_items = resolve_image_train_items(args, log_folder)
|
image_train_items = resolve_image_train_items(args, log_folder)
|
||||||
|
|
||||||
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
validator = EveryDreamValidator(args.validation_config,
|
||||||
|
default_batch_size=args.batch_size,
|
||||||
|
resolution=args.resolution,
|
||||||
|
log_writer=log_writer,
|
||||||
|
)
|
||||||
# the validation dataset may need to steal some items from image_train_items
|
# the validation dataset may need to steal some items from image_train_items
|
||||||
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||||
|
|
||||||
data_loader = DataLoaderMultiAspect(
|
data_loader = DataLoaderMultiAspect(
|
||||||
image_train_items=image_train_items,
|
image_train_items=image_train_items,
|
||||||
|
@ -710,8 +717,9 @@ def main(args):
|
||||||
if not os.path.exists(f"{log_folder}/samples/"):
|
if not os.path.exists(f"{log_folder}/samples/"):
|
||||||
os.makedirs(f"{log_folder}/samples/")
|
os.makedirs(f"{log_folder}/samples/")
|
||||||
|
|
||||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
if gpu is not None:
|
||||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||||
|
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||||
|
|
||||||
|
@ -940,7 +948,7 @@ def main(args):
|
||||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||||
|
|
||||||
# validate
|
# validate
|
||||||
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
@ -1021,6 +1029,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||||
|
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.")
|
||||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||||
|
|
Loading…
Reference in New Issue