temporarily disable val, issues
This commit is contained in:
parent
165525a71c
commit
6f4bfdc557
21
train.py
21
train.py
|
@ -277,22 +277,22 @@ def setup_args(args):
|
|||
return args
|
||||
|
||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||
if global_step == 250 or (epoch >= 4 and step == 1):
|
||||
if global_step == 250 or (epoch == 4 and step == 1):
|
||||
factor = 1.8
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 500 or (epoch >= 8 and step == 1):
|
||||
if global_step == 500 or (epoch == 8 and step == 1):
|
||||
factor = 1.6
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 1000 or (epoch >= 10 and step == 1):
|
||||
if global_step == 1000 or (epoch == 10 and step == 1):
|
||||
factor = 1.3
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
if global_step == 3000 or (epoch >= 20 and step == 1):
|
||||
if global_step == 3000 or (epoch == 20 and step == 1):
|
||||
factor = 1.15
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
|
@ -366,8 +366,8 @@ def main(args):
|
|||
logging.info(f" Seed: {seed}")
|
||||
set_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
gpu = GPU()
|
||||
device = torch.device(f"cuda:{args.gpuid}")
|
||||
gpu = GPU(device)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||
|
@ -620,9 +620,9 @@ def main(args):
|
|||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||
#validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
|
||||
# the validation dataset may need to steal some items from image_train_items
|
||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
#image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
|
@ -940,7 +940,7 @@ def main(args):
|
|||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
# validate
|
||||
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
#validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
|
||||
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
@ -987,7 +987,9 @@ def update_old_args(t_args):
|
|||
if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
|
||||
print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
|
||||
t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
|
||||
|
||||
if not hasattr(t_args, "validation_config"):
|
||||
print(f" Config json is missing 'validation_config'")
|
||||
t_args.__dict__["validation_config"] = None
|
||||
|
||||
if __name__ == "__main__":
|
||||
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
||||
|
@ -1043,6 +1045,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
|
||||
argparser.add_argument("--validation_config", type=str, default="validation_default.json", help="validation config file (def: validation_config.json)")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
|
|
|
@ -14,10 +14,12 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
from pynvml.smi import nvidia_smi
|
||||
import torch
|
||||
|
||||
class GPU:
|
||||
def __init__(self):
|
||||
def __init__(self, device: torch.device):
|
||||
self.nvsmi = nvidia_smi.getInstance()
|
||||
self.device = device
|
||||
|
||||
def get_gpu_memory(self):
|
||||
"""
|
||||
|
@ -25,6 +27,6 @@ class GPU:
|
|||
"""
|
||||
gpu_query = self.nvsmi.DeviceQuery('memory.used, memory.total')
|
||||
#print(gpu_query)
|
||||
gpu_used_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['used'])
|
||||
gpu_total_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['total'])
|
||||
gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
|
||||
gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
|
||||
return gpu_used_mem, gpu_total_mem
|
Loading…
Reference in New Issue