fix mem leak on huge data, rework optimizer to separate json, add lion optimizer

This commit is contained in:
Victor Hall 2023-02-25 15:05:22 -05:00
parent a2479cfe1f
commit a9b0189947
7 changed files with 157 additions and 41 deletions

cfgs/train_coco.json (new file, 44 lines)

@@ -0,0 +1,44 @@
{
"amp": true,
"batch_size": 10,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 1.00,
"data_root": "Q:/mldata/coco2017/train2017",
"disable_textenc_training": false,
"disable_xformers": false,
"flip_p": 0.0,
"gpuid": 0,
"gradient_checkpointing": false,
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 1.0e-06,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 30,
"notebook": false,
"optimizer_config": "optimizer.json",
"project_name": "coco_test",
"resolution": 512,
"resume_ckpt": "SD15",
"run_name": null,
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_ckpts_from_n_epochs": 0,
"save_every_n_epochs": 20,
"save_optimizer": false,
"scale_lr": false,
"seed": 555,
"shuffle_tags": false,
"validation_config": "cfgs/validation_coco.json",
"wandb": true,
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.04
}

cfgs/validation_coco.json (new file, 20 lines)

@@ -0,0 +1,20 @@
{
"documentation": {
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
"seed": "The seed to use when running validation and stabilization passes."
},
"validate_training":"true",
"val_split_mode": "manual",
"val_data_root": "Q:/mldata/coco2017/val2017",
"val_split_proportion": 0.15,
"stabilize_training_loss": true,
"stabilize_split_proportion": 0.15,
"every_n_epochs": 1,
"seed": 555
}
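
The 'automatic' val_split_mode documented above amounts to carving a seeded random subset out of the training items and holding it back to compute loss/val. A minimal sketch of that idea (the helper name split_list and its arguments are illustrative, not the validator's actual API):

    import random

    def split_list(items, split_proportion=0.15, seed=555):
        # Shuffle a copy with a fixed seed, then hold back the first chunk for validation.
        rng = random.Random(seed)
        shuffled = list(items)
        rng.shuffle(shuffled)
        n_val = int(len(shuffled) * split_proportion)
        return shuffled[n_val:], shuffled[:n_val]

    # e.g. train_items, val_items = split_list(image_train_items, 0.15, seed=555)

With 'manual' mode the trainer instead reads a separate folder of image/caption pairs from val_data_root, as in the config above.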


@@ -129,9 +129,11 @@ class EveryDreamBatch(Dataset):
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
example["image"] = image_train_tmp.image
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
image_train_tmp.image = None # hack for now to avoid memory leak
example["caption"] = image_train_tmp.caption
example["runt_size"] = image_train_tmp.runt_size
return example
def __update_image_train_items(self, dropout_fraction: float):

optimizer.json (new file, 15 lines)

@@ -0,0 +1,15 @@
{
"doc": {
"optimizer": "adamw, adamw8bit, lion",
"optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
"lr": "learning rate, if null wil use CLI or main JSON config value",
"betas": "exponential decay rates for the moment estimates",
"epsilon": "value added to denominator for numerical stability, unused for lion",
"weight_decay": "weight decay (L2 penalty)"
},
"optimizer": "adamw8bit",
"lr:": null,
"betas": [0.9, 0.999],
"epsilon": 1e-8,
"weight_decay": 0.01
}
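
As the train.py hunk further down shows, these fields map almost directly onto the optimizer constructors: Lion takes only lr, betas, and weight_decay, while the AdamW variants also take epsilon. A condensed sketch of that mapping (assuming bitsandbytes and lion-pytorch are installed; build_optimizer is an illustrative name, not a function in the repo):

    import json
    import torch

    def build_optimizer(params, config_path="optimizer.json", default_lr=1e-6):
        with open(config_path, "r") as f:
            cfg = json.load(f)
        lr = cfg.get("lr") or default_lr   # null in the JSON falls back to the default
        betas = tuple(cfg["betas"])
        if cfg["optimizer"] == "lion":
            from lion_pytorch import Lion
            # Lion has no epsilon term, so only betas and weight decay apply
            return Lion(params, lr=lr, betas=betas, weight_decay=cfg["weight_decay"])
        if cfg["optimizer"] in ("adam8bit", "adamw8bit"):
            import bitsandbytes as bnb
            opt_class = bnb.optim.AdamW8bit
        else:
            opt_class = torch.optim.AdamW
        return opt_class(params, lr=lr, betas=betas, eps=cfg["epsilon"],
                         weight_decay=cfg["weight_decay"])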


@@ -21,9 +21,11 @@
"lr_warmup_steps": null,
"max_epochs": 30,
"notebook": false,
"optimizer_config": "optimizer.json",
"project_name": "project_abc",
"resolution": 512,
"resume_ckpt": "sd_v1-5_vae",
"run_name": null,
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
@@ -33,7 +35,6 @@
"scale_lr": false,
"seed": 555,
"shuffle_tags": false,
"useadam8bit": true,
"validation_config": null,
"wandb": false,
"write_schedule": false,


@@ -125,12 +125,12 @@ def setup_local_logger(args):
return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay):
"""
logs the optimizer settings
"""
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
"""
@@ -173,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
def set_args_12gb(args):
logging.info(" Setting args to 12GB mode")
if not args.gradient_checkpointing:
@@ -182,15 +181,10 @@ def set_args_12gb(args):
if args.batch_size != 1:
logging.info(" - Overiding batch size to 1")
args.batch_size = 1
# if args.grad_accum != 1:
# logging.info(" Overiding grad accum to 1")
args.grad_accum = 1
if args.resolution > 512:
logging.info(" - Overiding resolution to 512")
args.resolution = 512
if not args.useadam8bit:
logging.info(" - Overiding adam8bit to True")
args.useadam8bit = True
def find_last_checkpoint(logdir):
"""
@@ -236,6 +230,9 @@ def setup_args(args):
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.useadam8bit:
logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20
@@ -450,7 +447,7 @@ def main(args):
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not args.disable_xformers:
if (args.amp and is_sd1attn) or (not is_sd1attn):
try:
@@ -464,10 +461,6 @@
logging.info("xformers disabled, using attention slicing instead")
unet.set_attention_slice("auto")
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
unet = unet.to(device, dtype=torch.float32)
if args.disable_textenc_training and args.amp:
@@ -485,36 +478,63 @@
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args, name=args.run_name)
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
)
flush_secs=5,
comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
)
betas = (0.9, 0.999)
betas = [0.9, 0.999]
epsilon = 1e-8
if args.amp:
epsilon = 1e-8
weight_decay = 0.01
if args.useadam8bit:
import bitsandbytes as bnb
opt_class = bnb.optim.AdamW8bit
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
else:
opt_class = torch.optim.AdamW
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
opt_class = torch.optim.AdamW
optimizer = None
optimizer = opt_class(
default_lr = 1e-6
curr_lr = args.lr
# open optimizer.json and override optimizer args
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
optimizer_config = json.load(f)
betas = optimizer_config["betas"]
epsilon = optimizer_config["epsilon"]
weight_decay = optimizer_config["weight_decay"]
optimizer_name = optimizer_config["optimizer"]
curr_lr = optimizer_config.get("lr", curr_lr)
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
if curr_lr is None:
curr_lr = default_lr
if optimizer_name:
if optimizer_name == "lion":
from lion_pytorch import Lion
opt_class = Lion
optimizer = opt_class(
itertools.chain(params_to_train),
lr=curr_lr,
betas=(betas[0], betas[1]),
weight_decay=weight_decay,
)
elif optimizer_name in ["adam8bit","adamw8bit"]:
import bitsandbytes as bnb
opt_class = bnb.optim.AdamW8bit
if not optimizer:
optimizer = opt_class(
itertools.chain(params_to_train),
lr=curr_lr,
betas=betas,
betas=(betas[0], betas[1]),
eps=epsilon,
weight_decay=weight_decay,
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon)
log_optimizer(optimizer, betas, epsilon, weight_decay)
image_train_items = resolve_image_train_items(args, log_folder)
@@ -561,10 +581,6 @@ def main(args):
num_training_steps=args.lr_decay_steps,
)
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args)
def log_args(log_writer, args):
arglog = "args:\n"
for arg, value in sorted(vars(args).items()):
@@ -920,9 +936,11 @@ if __name__ == "__main__":
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
argparser.add_argument("--run_name", type=str, required=False, default=None, help="Run name for wandb (child of project name), and comment for tensorboard, (def: None)")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
@@ -933,7 +951,7 @@ if __name__ == "__main__":
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")


@@ -13,13 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pynvml.smi import nvidia_smi
from pynvml.smi import nvidia_smi as smi
import pynvml
import torch
class GPU:
def __init__(self, device: torch.device):
self.nvsmi = nvidia_smi.getInstance()
self.nvsmi = smi.getInstance()
self.device = device
def __querythis(self, query):
gpu_query = self.nvsmi.DeviceQuery(query)
return gpu_query['gpu'][self.device.index]
def get_gpu_memory(self):
"""
@@ -29,4 +33,16 @@ class GPU:
#print(gpu_query)
gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
return gpu_used_mem, gpu_total_mem
return gpu_used_mem, gpu_total_mem
def supports_bfloat16(self):
# bfloat16 requires CUDA compute capability 8.0 or higher (Ampere and newer)
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
compute_compatibility = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
return compute_compatibility[0] >= 8
def driver_version(self):
gpu_query = self.nvsmi.DeviceQuery('driver_version')
driver_version = gpu_query['gpu'][self.device.index]['driver_version']
return driver_version
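
The new supports_bfloat16 helper simply reports whether the GPU's CUDA compute capability is 8.0 or higher, which is what bfloat16 math requires. A short, hypothetical usage sketch for choosing an autocast dtype (the utils.gpu import path is assumed from this repo's layout):

    import torch
    from utils.gpu import GPU  # path assumed; adjust to wherever this module lives

    device = torch.device("cuda", 0)
    gpu = GPU(device)

    # Prefer bfloat16 where the hardware supports it, otherwise fall back to float16.
    amp_dtype = torch.bfloat16 if gpu.supports_bfloat16() else torch.float16
    print(f"driver {gpu.driver_version()}, autocast dtype: {amp_dtype}")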