From 6f4bfdc557e83be804fd27c0dad824bdd0de2c13 Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Tue, 7 Feb 2023 18:43:16 -0500
Subject: [PATCH] temporarily disable val, issues

---
 train.py     | 21 ++++++++++++---------
 utils/gpu.py |  8 +++++---
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/train.py b/train.py
index 737e5b6..84c9ae9 100644
--- a/train.py
+++ b/train.py
@@ -277,22 +277,22 @@ def setup_args(args):
     return args
 
 def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
-    if global_step == 250 or (epoch >= 4 and step == 1):
+    if global_step == 250 or (epoch == 4 and step == 1):
         factor = 1.8
         scaler.set_growth_factor(factor)
         scaler.set_backoff_factor(1/factor)
         scaler.set_growth_interval(50)
-    if global_step == 500 or (epoch >= 8 and step == 1):
+    if global_step == 500 or (epoch == 8 and step == 1):
         factor = 1.6
         scaler.set_growth_factor(factor)
         scaler.set_backoff_factor(1/factor)
         scaler.set_growth_interval(50)
-    if global_step == 1000 or (epoch >= 10 and step == 1):
+    if global_step == 1000 or (epoch == 10 and step == 1):
         factor = 1.3
         scaler.set_growth_factor(factor)
         scaler.set_backoff_factor(1/factor)
         scaler.set_growth_interval(100)
-    if global_step == 3000 or (epoch >= 20 and step == 1):
+    if global_step == 3000 or (epoch == 20 and step == 1):
         factor = 1.15
         scaler.set_growth_factor(factor)
         scaler.set_backoff_factor(1/factor)
@@ -366,8 +366,8 @@ def main(args):
     logging.info(f" Seed: {seed}")
     set_seed(seed)
     if torch.cuda.is_available():
-        gpu = GPU()
         device = torch.device(f"cuda:{args.gpuid}")
+        gpu = GPU(device)
         torch.backends.cudnn.benchmark = True
     else:
         logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
@@ -620,9 +620,9 @@ def main(args):
 
     image_train_items = resolve_image_train_items(args, log_folder)
 
-    validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
+    #validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
     # the validation dataset may need to steal some items from image_train_items
-    image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
+    #image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
 
     data_loader = DataLoaderMultiAspect(
         image_train_items=image_train_items,
@@ -940,7 +940,7 @@ def main(args):
         log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
 
         # validate
-        validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+        #validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
 
         gc.collect()
         # end of epoch
@@ -987,7 +987,9 @@ def update_old_args(t_args):
     if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
         print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
         t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
-
+    if not hasattr(t_args, "validation_config"):
+        print(f" Config json is missing 'validation_config'")
+        t_args.__dict__["validation_config"] = None
 
 if __name__ == "__main__":
     supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
@@ -1043,6 +1045,7 @@ if __name__ == "__main__":
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
+    argparser.add_argument("--validation_config", type=str, default="validation_default.json", help="validation config file (def: validation_config.json)")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
     argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
diff --git a/utils/gpu.py b/utils/gpu.py
index c37e08c..8bed611 100644
--- a/utils/gpu.py
+++ b/utils/gpu.py
@@ -14,10 +14,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 from pynvml.smi import nvidia_smi
+import torch
 
 class GPU:
-    def __init__(self):
+    def __init__(self, device: torch.device):
         self.nvsmi = nvidia_smi.getInstance()
+        self.device = device
 
     def get_gpu_memory(self):
         """
@@ -25,6 +27,6 @@ class GPU:
         """
         gpu_query = self.nvsmi.DeviceQuery('memory.used, memory.total')
         #print(gpu_query)
-        gpu_used_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['used'])
-        gpu_total_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['total'])
+        gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
+        gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
         return gpu_used_mem, gpu_total_mem
\ No newline at end of file