Forgot to add train.py earlier 🤦; move write_batch_schedule to train.py
This commit is contained in:
parent
12a0cb6286
commit
56f130c027
|
@ -38,9 +38,7 @@ class EveryDreamBatch(Dataset):
|
|||
crop_jitter=20,
|
||||
seed=555,
|
||||
tokenizer=None,
|
||||
log_folder=None,
|
||||
retain_contrast=False,
|
||||
write_schedule=False,
|
||||
shuffle_tags=False,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5
|
||||
|
@ -52,10 +50,8 @@ class EveryDreamBatch(Dataset):
|
|||
self.crop_jitter = crop_jitter
|
||||
self.unloaded_to_idx = 0
|
||||
self.tokenizer = tokenizer
|
||||
self.log_folder = log_folder
|
||||
self.max_token_length = self.tokenizer.model_max_length
|
||||
self.retain_contrast = retain_contrast
|
||||
self.write_schedule = write_schedule
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
|
@ -64,18 +60,7 @@ class EveryDreamBatch(Dataset):
|
|||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
|
||||
|
||||
num_images = len(self.image_train_items)
|
||||
|
||||
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
if self.write_schedule:
|
||||
self.__write_batch_schedule(0)
|
||||
|
||||
def __write_batch_schedule(self, epoch_n):
    """
    Write the batch schedule for one epoch to
    {self.log_folder}/ep{epoch_n}_batch_schedule.txt (utf-8).

    One line is written per entry of self.image_train_items, holding the step
    index (item position divided by self.batch_size), the item's target_wh,
    runt_size and pathname.

    Writing is best-effort: a failure on one item is logged and the remaining
    items are still written.

    :param epoch_n: epoch number used in the schedule file name
    """
    schedule_path = f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt"
    with open(schedule_path, "w", encoding='utf-8') as f:
        for i, item in enumerate(self.image_train_items):
            try:
                f.write(f"step:{int(i / self.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
            except Exception as e:
                # getattr keeps the handler itself from raising when the item is malformed
                logging.error(f" * Error writing to batch schedule for file path: {getattr(item, 'pathname', '<unknown>')}")
                # report the cause as well, otherwise the failure cannot be diagnosed
                logging.error(f" *** exception: {e}")
|
||||
|
||||
def shuffle(self, epoch_n: int, max_epochs: int):
|
||||
self.seed += 1
|
||||
|
@ -87,9 +72,6 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||
|
||||
if self.write_schedule:
|
||||
self.__write_batch_schedule(epoch_n + 1)
|
||||
|
||||
def __len__(self):
    """
    Return the number of entries currently held in self.image_train_items
    """
    scheduled_items = self.image_train_items
    return len(scheduled_items)
|
||||
|
||||
|
|
78
train.py
78
train.py
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import math
|
||||
import signal
|
||||
|
@ -48,11 +49,15 @@ from accelerate.utils import set_seed
|
|||
|
||||
import wandb
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
|
||||
from data.every_dream import EveryDreamBatch
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.gpu import GPU
|
||||
import data.aspects as aspects
|
||||
import data.resolver as resolver
|
||||
|
||||
_SIGTERM_EXIT_CODE = 130
|
||||
_VERY_LARGE_NUMBER = 1e9
|
||||
|
@ -265,6 +270,8 @@ def setup_args(args):
|
|||
|
||||
logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}"))
|
||||
|
||||
args.aspects = aspects.get_aspect_buckets(args.resolution)
|
||||
|
||||
return args
|
||||
|
||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||
|
@ -288,6 +295,35 @@ def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
|||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
|
||||
def report_image_train_item_problems(log_folder, items: list[ImageTrainItem]) -> None:
    """
    Log problems found on the given image train items.

    Every item whose `error` is not None is reported through logging.error.
    If any item has `is_undersized` set, a warning is logged and the list of
    undersized items is written to {log_folder}/undersized_images.txt.

    :param log_folder: folder that receives undersized_images.txt
    :param items: items to inspect; reads error, pathname, is_undersized,
                  image_size and target_wh from each
    """
    for item in items:
        if item.error is not None:
            logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
            logging.error(f" *** exception: {item.error}")

    undersized_items = [item for item in items if item.is_undersized]
    if not undersized_items:
        return

    undersized_log_path = os.path.join(log_folder, "undersized_images.txt")
    logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
    logging.warning(f"{Fore.LIGHTRED_EX} ** Check {undersized_log_path} for more information.{Style.RESET_ALL}")
    # utf-8 so that non-ASCII path names do not fail on platforms with a narrower default encoding
    with open(undersized_log_path, "w", encoding="utf-8") as undersized_images_file:
        # the header needs its own newline, otherwise the first entry is appended to it
        undersized_images_file.write(" The following images are smaller than the target size, consider removing or sourcing a larger copy:\n")
        for undersized_item in undersized_items:
            message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
            undersized_images_file.write(message)
|
||||
|
||||
def write_batch_schedule(log_folder, train_batch, epoch):
    """
    Write the batch schedule for one epoch to
    {log_folder}/ep{epoch}_batch_schedule.txt (utf-8).

    Does nothing unless args.write_schedule is truthy.
    NOTE: `args` is not a parameter; it is read from module scope, so it must
    already be defined there when this function is called.

    One line is written per entry of train_batch.image_train_items, holding
    the step index (item position divided by train_batch.batch_size), the
    item's target_wh, runt_size and pathname. Writing is best-effort: a
    failure on one item is logged and the remaining items are still written.

    :param log_folder: folder the schedule file is written into
    :param train_batch: object exposing image_train_items and batch_size
    :param epoch: epoch number used in the schedule file name
    """
    if not args.write_schedule:
        return

    schedule_path = f"{log_folder}/ep{epoch}_batch_schedule.txt"
    with open(schedule_path, "w", encoding='utf-8') as f:
        # enumerate binds `item` before the try, so the handler never sees an unbound or stale item
        for i, item in enumerate(train_batch.image_train_items):
            try:
                f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
            except Exception as e:
                # getattr keeps the handler itself from raising when the item is malformed
                logging.error(f" * Error writing to batch schedule for file path: {getattr(item, 'pathname', '<unknown>')}")
                # report the cause as well, otherwise the failure cannot be diagnosed
                logging.error(f" *** exception: {e}")
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
|
@ -313,6 +349,8 @@ def main(args):
|
|||
|
||||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
||||
args.log_folder = log_folder
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
|
||||
|
@ -547,23 +585,34 @@ def main(args):
|
|||
)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon)
|
||||
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
image_train_items = resolver.resolve(args.data_root, args)
|
||||
report_image_train_item_problems(log_folder, image_train_items)
|
||||
image_paths = set(map(lambda item: item.pathname, image_train_items))
|
||||
# Remove erroneous items
|
||||
image_train_items = [item for item in image_train_items if item.error is None]
|
||||
print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files")
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
|
||||
train_batch = EveryDreamBatch(
|
||||
data_root=args.data_root,
|
||||
flip_p=args.flip_p,
|
||||
data_loader=data_loader,
|
||||
debug_level=1,
|
||||
batch_size=args.batch_size,
|
||||
conditional_dropout=args.cond_dropout,
|
||||
resolution=args.resolution,
|
||||
tokenizer=tokenizer,
|
||||
seed = seed,
|
||||
log_folder=log_folder,
|
||||
write_schedule=args.write_schedule,
|
||||
shuffle_tags=args.shuffle_tags,
|
||||
rated_dataset=args.rated_dataset,
|
||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||
)
|
||||
|
||||
|
||||
torch.cuda.benchmark = False
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
@ -589,10 +638,11 @@ def main(args):
|
|||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||
|
||||
log_writer = SummaryWriter(log_dir=log_folder,
|
||||
flush_secs=5,
|
||||
comment="EveryDream2FineTunes",
|
||||
)
|
||||
log_writer = SummaryWriter(
|
||||
log_dir=log_folder,
|
||||
flush_secs=5,
|
||||
comment="EveryDream2FineTunes",
|
||||
)
|
||||
|
||||
def log_args(log_writer, args):
|
||||
arglog = "args:\n"
|
||||
|
@ -729,6 +779,8 @@ def main(args):
|
|||
# # discard the grads, just want to pin memory
|
||||
# optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
write_batch_schedule(log_folder, train_batch, 0)
|
||||
|
||||
for epoch in range(args.max_epochs):
|
||||
loss_epoch = []
|
||||
epoch_start_time = time.time()
|
||||
|
@ -879,6 +931,7 @@ def main(args):
|
|||
epoch_pbar.update(1)
|
||||
if epoch < args.max_epochs - 1:
|
||||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||
write_batch_schedule(log_folder, train_batch, epoch + 1)
|
||||
|
||||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
@ -943,7 +996,6 @@ if __name__ == "__main__":
|
|||
t_args = argparse.Namespace()
|
||||
t_args.__dict__.update(json.load(f))
|
||||
update_old_args(t_args) # update args to support older configs
|
||||
print(f" args: \n{t_args.__dict__}")
|
||||
args = argparser.parse_args(namespace=t_args)
|
||||
else:
|
||||
print("No config file specified, using command line args")
|
||||
|
@ -992,4 +1044,6 @@ if __name__ == "__main__":
|
|||
|
||||
args, _ = argparser.parse_known_args()
|
||||
|
||||
print(f" Args:")
|
||||
pprint.pprint(args.__dict__)
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue