temporarily disable val, issues

This commit is contained in:
Victor Hall 2023-02-07 18:43:16 -05:00
parent 165525a71c
commit 6f4bfdc557
2 changed files with 17 additions and 12 deletions

View File

@ -277,22 +277,22 @@ def setup_args(args):
return args
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
if global_step == 250 or (epoch >= 4 and step == 1):
if global_step == 250 or (epoch == 4 and step == 1):
factor = 1.8
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50)
if global_step == 500 or (epoch >= 8 and step == 1):
if global_step == 500 or (epoch == 8 and step == 1):
factor = 1.6
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(50)
if global_step == 1000 or (epoch >= 10 and step == 1):
if global_step == 1000 or (epoch == 10 and step == 1):
factor = 1.3
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
scaler.set_growth_interval(100)
if global_step == 3000 or (epoch >= 20 and step == 1):
if global_step == 3000 or (epoch == 20 and step == 1):
factor = 1.15
scaler.set_growth_factor(factor)
scaler.set_backoff_factor(1/factor)
@ -366,8 +366,8 @@ def main(args):
logging.info(f" Seed: {seed}")
set_seed(seed)
if torch.cuda.is_available():
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
gpu = GPU(device)
torch.backends.cudnn.benchmark = True
else:
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
@ -620,9 +620,9 @@ def main(args):
image_train_items = resolve_image_train_items(args, log_folder)
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
# the validation dataset may need to steal some items from image_train_items
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items,
@ -940,7 +940,7 @@ def main(args):
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# validate
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch
@ -987,7 +987,9 @@ def update_old_args(t_args):
if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
if not hasattr(t_args, "validation_config"):
print(f" Config json is missing 'validation_config'")
t_args.__dict__["validation_config"] = None
if __name__ == "__main__":
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
@ -1043,6 +1045,7 @@ if __name__ == "__main__":
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--validation_config", type=str, default="validation_default.json", help="validation config file (def: validation_config.json)")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")

View File

@ -14,10 +14,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from pynvml.smi import nvidia_smi
import torch
class GPU:
def __init__(self):
def __init__(self, device: torch.device):
self.nvsmi = nvidia_smi.getInstance()
self.device = device
def get_gpu_memory(self):
"""
@ -25,6 +27,6 @@ class GPU:
"""
gpu_query = self.nvsmi.DeviceQuery('memory.used, memory.total')
#print(gpu_query)
gpu_used_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['used'])
gpu_total_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['total'])
gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
return gpu_used_mem, gpu_total_mem