fix mem leak on huge data, rework optimizer to separate json, add lion optimizer
This commit is contained in:
parent
a2479cfe1f
commit
a9b0189947
|
@ -0,0 +1,44 @@
|
|||
{
|
||||
"amp": true,
|
||||
"batch_size": 10,
|
||||
"ckpt_every_n_minutes": null,
|
||||
"clip_grad_norm": null,
|
||||
"clip_skip": 0,
|
||||
"cond_dropout": 1.00,
|
||||
"data_root": "Q:/mldata/coco2017/train2017",
|
||||
"disable_textenc_training": false,
|
||||
"disable_xformers": false,
|
||||
"flip_p": 0.0,
|
||||
"gpuid": 0,
|
||||
"gradient_checkpointing": false,
|
||||
"grad_accum": 1,
|
||||
"logdir": "logs",
|
||||
"log_step": 25,
|
||||
"lowvram": false,
|
||||
"lr": 1.0e-06,
|
||||
"lr_decay_steps": 0,
|
||||
"lr_scheduler": "constant",
|
||||
"lr_warmup_steps": null,
|
||||
"max_epochs": 30,
|
||||
"notebook": false,
|
||||
"optimizer_config": "optimizer.json",
|
||||
"project_name": "coco_test",
|
||||
"resolution": 512,
|
||||
"resume_ckpt": "SD15",
|
||||
"run_name": null,
|
||||
"sample_prompts": "sample_prompts.txt",
|
||||
"sample_steps": 300,
|
||||
"save_ckpt_dir": null,
|
||||
"save_ckpts_from_n_epochs": 0,
|
||||
"save_every_n_epochs": 20,
|
||||
"save_optimizer": false,
|
||||
"scale_lr": false,
|
||||
"seed": 555,
|
||||
"shuffle_tags": false,
|
||||
"validation_config": "cfgs/validation_coco.json",
|
||||
"wandb": true,
|
||||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
"rated_dataset_target_dropout_percent": 50,
|
||||
"zero_frequency_noise_ratio": 0.04
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"documentation": {
|
||||
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
|
||||
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
|
||||
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
|
||||
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
|
||||
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"every_n_epochs": "How often to run validation (1=every epoch).",
|
||||
"seed": "The seed to use when running validation and stabilization passes."
|
||||
},
|
||||
"validate_training":"true",
|
||||
"val_split_mode": "manual",
|
||||
"val_data_root": "Q:/mldata/coco2017/val2017",
|
||||
"val_split_proportion": 0.15,
|
||||
"stabilize_training_loss": true,
|
||||
"stabilize_split_proportion": 0.15,
|
||||
"every_n_epochs": 1,
|
||||
"seed": 555
|
||||
}
|
|
@ -129,9 +129,11 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
|
||||
|
||||
example["image"] = image_train_tmp.image
|
||||
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
|
||||
image_train_tmp.image = None # hack for now to avoid memory leak
|
||||
example["caption"] = image_train_tmp.caption
|
||||
example["runt_size"] = image_train_tmp.runt_size
|
||||
|
||||
return example
|
||||
|
||||
def __update_image_train_items(self, dropout_fraction: float):
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"doc": {
|
||||
"optimizer": "adamw, adamw8bit, lion",
|
||||
"optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
|
||||
"lr": "learning rate, if null wil use CLI or main JSON config value",
|
||||
"betas": "exponential decay rates for the moment estimates",
|
||||
"epsilon": "value added to denominator for numerical stability, unused for lion",
|
||||
"weight_decay": "weight decay (L2 penalty)"
|
||||
},
|
||||
"optimizer": "adamw8bit",
|
||||
"lr:": null,
|
||||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.01
|
||||
}
|
|
@ -21,9 +21,11 @@
|
|||
"lr_warmup_steps": null,
|
||||
"max_epochs": 30,
|
||||
"notebook": false,
|
||||
"optimizer_config": "optimizer.json",
|
||||
"project_name": "project_abc",
|
||||
"resolution": 512,
|
||||
"resume_ckpt": "sd_v1-5_vae",
|
||||
"run_name": null,
|
||||
"sample_prompts": "sample_prompts.txt",
|
||||
"sample_steps": 300,
|
||||
"save_ckpt_dir": null,
|
||||
|
@ -33,7 +35,6 @@
|
|||
"scale_lr": false,
|
||||
"seed": 555,
|
||||
"shuffle_tags": false,
|
||||
"useadam8bit": true,
|
||||
"validation_config": null,
|
||||
"wandb": false,
|
||||
"write_schedule": false,
|
||||
|
|
90
train.py
90
train.py
|
@ -125,12 +125,12 @@ def setup_local_logger(args):
|
|||
|
||||
return datetimestamp
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
||||
|
||||
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
|
@ -173,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
|||
if logs is not None:
|
||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||
|
||||
|
||||
def set_args_12gb(args):
|
||||
logging.info(" Setting args to 12GB mode")
|
||||
if not args.gradient_checkpointing:
|
||||
|
@ -182,15 +181,10 @@ def set_args_12gb(args):
|
|||
if args.batch_size != 1:
|
||||
logging.info(" - Overiding batch size to 1")
|
||||
args.batch_size = 1
|
||||
# if args.grad_accum != 1:
|
||||
# logging.info(" Overiding grad accum to 1")
|
||||
args.grad_accum = 1
|
||||
if args.resolution > 512:
|
||||
logging.info(" - Overiding resolution to 512")
|
||||
args.resolution = 512
|
||||
if not args.useadam8bit:
|
||||
logging.info(" - Overiding adam8bit to True")
|
||||
args.useadam8bit = True
|
||||
|
||||
def find_last_checkpoint(logdir):
|
||||
"""
|
||||
|
@ -236,6 +230,9 @@ def setup_args(args):
|
|||
|
||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||
|
||||
if args.useadam8bit:
|
||||
logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
|
||||
|
||||
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
||||
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
|
||||
args.ckpt_every_n_minutes = 20
|
||||
|
@ -450,7 +447,7 @@ def main(args):
|
|||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
|
||||
if not args.disable_xformers:
|
||||
if (args.amp and is_sd1attn) or (not is_sd1attn):
|
||||
try:
|
||||
|
@ -464,10 +461,6 @@ def main(args):
|
|||
logging.info("xformers disabled, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
|
||||
default_lr = 2e-6
|
||||
curr_lr = args.lr if args.lr is not None else default_lr
|
||||
|
||||
|
||||
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
if args.disable_textenc_training and args.amp:
|
||||
|
@ -485,36 +478,63 @@ def main(args):
|
|||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
|
||||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args, name=args.run_name)
|
||||
|
||||
log_writer = SummaryWriter(log_dir=log_folder,
|
||||
flush_secs=5,
|
||||
comment="EveryDream2FineTunes",
|
||||
)
|
||||
flush_secs=5,
|
||||
comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
|
||||
)
|
||||
|
||||
betas = (0.9, 0.999)
|
||||
betas = [0.9, 0.999]
|
||||
epsilon = 1e-8
|
||||
if args.amp:
|
||||
epsilon = 1e-8
|
||||
|
||||
weight_decay = 0.01
|
||||
if args.useadam8bit:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
|
||||
else:
|
||||
opt_class = torch.optim.AdamW
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
|
||||
opt_class = torch.optim.AdamW
|
||||
optimizer = None
|
||||
|
||||
optimizer = opt_class(
|
||||
default_lr = 1e-6
|
||||
curr_lr = args.lr
|
||||
|
||||
# open optimizer.json and override optimizer args
|
||||
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
|
||||
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
|
||||
with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
|
||||
optimizer_config = json.load(f)
|
||||
betas = optimizer_config["betas"]
|
||||
epsilon = optimizer_config["epsilon"]
|
||||
weight_decay = optimizer_config["weight_decay"]
|
||||
optimizer_name = optimizer_config["optimizer"]
|
||||
curr_lr = optimizer_config.get("lr", curr_lr)
|
||||
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
|
||||
|
||||
if curr_lr is None:
|
||||
curr_lr = default_lr
|
||||
|
||||
if optimizer_name:
|
||||
if optimizer_name == "lion":
|
||||
from lion_pytorch import Lion
|
||||
opt_class = Lion
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
elif optimizer_name in ["adam8bit","adamw8bit"]:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
|
||||
if not optimizer:
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=betas,
|
||||
betas=(betas[0], betas[1]),
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=False,
|
||||
)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay)
|
||||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
|
@ -561,10 +581,6 @@ def main(args):
|
|||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args)
|
||||
|
||||
|
||||
def log_args(log_writer, args):
|
||||
arglog = "args:\n"
|
||||
for arg, value in sorted(vars(args).items()):
|
||||
|
@ -920,9 +936,11 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
|
||||
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
|
||||
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
||||
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
|
||||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||
argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
|
||||
argparser.add_argument("--run_name", type=str, required=False, default=None, help="Run name for wandb (child of project name), and comment for tensorboard, (def: None)")
|
||||
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
|
||||
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
||||
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
||||
|
@ -933,7 +951,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
|
|
22
utils/gpu.py
22
utils/gpu.py
|
@ -13,13 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from pynvml.smi import nvidia_smi
|
||||
from pynvml.smi import nvidia_smi as smi
|
||||
import pynvml
|
||||
import torch
|
||||
|
||||
class GPU:
|
||||
def __init__(self, device: torch.device):
|
||||
self.nvsmi = nvidia_smi.getInstance()
|
||||
self.nvsmi = smi.getInstance()
|
||||
self.device = device
|
||||
|
||||
def __querythis(self, query):
|
||||
return gpu_query['gpu'][self.device.index]
|
||||
|
||||
def get_gpu_memory(self):
|
||||
"""
|
||||
|
@ -29,4 +33,16 @@ class GPU:
|
|||
#print(gpu_query)
|
||||
gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
|
||||
gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
|
||||
return gpu_used_mem, gpu_total_mem
|
||||
return gpu_used_mem, gpu_total_mem
|
||||
|
||||
def supports_bfloat16(self):
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
|
||||
compute_compatibility = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
return compute_compatibility[0] >= 8
|
||||
|
||||
def driver_version(self):
|
||||
gpu_query = self.nvsmi.DeviceQuery('driver_version')
|
||||
driver_version = gpu_query['gpu'][self.device.index]['driver_version']
|
||||
return driver_version
|
||||
|
Loading…
Reference in New Issue