temporarily disable val, issues
This commit is contained in:
parent
165525a71c
commit
6f4bfdc557
21
train.py
21
train.py
|
@ -277,22 +277,22 @@ def setup_args(args):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||||
if global_step == 250 or (epoch >= 4 and step == 1):
|
if global_step == 250 or (epoch == 4 and step == 1):
|
||||||
factor = 1.8
|
factor = 1.8
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(50)
|
scaler.set_growth_interval(50)
|
||||||
if global_step == 500 or (epoch >= 8 and step == 1):
|
if global_step == 500 or (epoch == 8 and step == 1):
|
||||||
factor = 1.6
|
factor = 1.6
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(50)
|
scaler.set_growth_interval(50)
|
||||||
if global_step == 1000 or (epoch >= 10 and step == 1):
|
if global_step == 1000 or (epoch == 10 and step == 1):
|
||||||
factor = 1.3
|
factor = 1.3
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
scaler.set_growth_interval(100)
|
scaler.set_growth_interval(100)
|
||||||
if global_step == 3000 or (epoch >= 20 and step == 1):
|
if global_step == 3000 or (epoch == 20 and step == 1):
|
||||||
factor = 1.15
|
factor = 1.15
|
||||||
scaler.set_growth_factor(factor)
|
scaler.set_growth_factor(factor)
|
||||||
scaler.set_backoff_factor(1/factor)
|
scaler.set_backoff_factor(1/factor)
|
||||||
|
@ -366,8 +366,8 @@ def main(args):
|
||||||
logging.info(f" Seed: {seed}")
|
logging.info(f" Seed: {seed}")
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
gpu = GPU()
|
|
||||||
device = torch.device(f"cuda:{args.gpuid}")
|
device = torch.device(f"cuda:{args.gpuid}")
|
||||||
|
gpu = GPU(device)
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
else:
|
else:
|
||||||
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||||
|
@ -620,9 +620,9 @@ def main(args):
|
||||||
|
|
||||||
image_train_items = resolve_image_train_items(args, log_folder)
|
image_train_items = resolve_image_train_items(args, log_folder)
|
||||||
|
|
||||||
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||||
# the validation dataset may need to steal some items from image_train_items
|
# the validation dataset may need to steal some items from image_train_items
|
||||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||||
|
|
||||||
data_loader = DataLoaderMultiAspect(
|
data_loader = DataLoaderMultiAspect(
|
||||||
image_train_items=image_train_items,
|
image_train_items=image_train_items,
|
||||||
|
@ -940,7 +940,7 @@ def main(args):
|
||||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||||
|
|
||||||
# validate
|
# validate
|
||||||
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
@ -987,7 +987,9 @@ def update_old_args(t_args):
|
||||||
if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
|
if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
|
||||||
print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
|
print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
|
||||||
t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
|
t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
|
||||||
|
if not hasattr(t_args, "validation_config"):
|
||||||
|
print(f" Config json is missing 'validation_config'")
|
||||||
|
t_args.__dict__["validation_config"] = None
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
||||||
|
@ -1043,6 +1045,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||||
|
argparser.add_argument("--validation_config", type=str, default="validation_default.json", help="validation config file (def: validation_config.json)")
|
||||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||||
|
|
|
@ -14,10 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
from pynvml.smi import nvidia_smi
|
from pynvml.smi import nvidia_smi
|
||||||
|
import torch
|
||||||
|
|
||||||
class GPU:
|
class GPU:
|
||||||
def __init__(self):
|
def __init__(self, device: torch.device):
|
||||||
self.nvsmi = nvidia_smi.getInstance()
|
self.nvsmi = nvidia_smi.getInstance()
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def get_gpu_memory(self):
|
def get_gpu_memory(self):
|
||||||
"""
|
"""
|
||||||
|
@ -25,6 +27,6 @@ class GPU:
|
||||||
"""
|
"""
|
||||||
gpu_query = self.nvsmi.DeviceQuery('memory.used, memory.total')
|
gpu_query = self.nvsmi.DeviceQuery('memory.used, memory.total')
|
||||||
#print(gpu_query)
|
#print(gpu_query)
|
||||||
gpu_used_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['used'])
|
gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
|
||||||
gpu_total_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['total'])
|
gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
|
||||||
return gpu_used_mem, gpu_total_mem
|
return gpu_used_mem, gpu_total_mem
|
Loading…
Reference in New Issue