wip for bf16 testing

Victor Hall 2023-06-15 13:54:46 -04:00
parent 2525bb1c5f
commit 3cfecf8729
4 changed files with 48 additions and 23 deletions

@@ -74,12 +74,18 @@ class EveryDreamOptimizer():
         self.load(args.resume_ckpt)
+        use_bf16 = torch.cuda.is_bf16_supported()
+        if use_bf16:
+            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        init_scale = 2**17.5
         self.scaler = GradScaler(
-            enabled=args.amp,
-            init_scale=2**17.5,
+            enabled=args.amp and not use_bf16, # bfloat16 does not need grad scaler, same dynamic range as fp32
+            init_scale=init_scale,
             growth_factor=2,
             backoff_factor=1.0/2,
-            growth_interval=25,
+            growth_interval=50,
         )
         logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")
@@ -369,7 +375,7 @@ class EveryDreamOptimizer():
             amsgrad=False,
         )
-        log_optimizer(label, optimizer, betas, epsilon, weight_decay, curr_lr)
+        log_optimizer(label, optimizer)
         return optimizer
 
     def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
@@ -399,18 +405,34 @@ class EveryDreamOptimizer():
         return parameters
 
-def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+def log_optimizer(label: str, optimizer: torch.optim.Optimizer):
     """
     logs the optimizer settings
     """
+    # , betas, epsilon, weight_decay, curr_lr
     all_params = sum([g['params'] for g in optimizer.param_groups], [])
     frozen_parameter_count = len([p for p in all_params if not p.requires_grad])
     total_parameter_count = len(all_params)
     if frozen_parameter_count > 0:
-        param_info = f"({total_parameter_count} parameters, {frozen_parameter_count} frozen)"
+        param_info = f"({total_parameter_count} param_groups, {frozen_parameter_count} frozen)"
     else:
-        param_info = f"({total_parameter_count} parameters)"
+        param_info = f"({total_parameter_count} param_groups)"
+    #try get lr
+    lr = optimizer.param_groups[0].get('lr', None)
+    betas = optimizer.param_groups[0].get('betas', None)
+    epsilon = optimizer.param_groups[0].get('eps', None)
+    weight_decay = optimizer.param_groups[0].get('weight_decay', None)
+    d0 = optimizer.param_groups[0].get('d0', None)
+    string_empty = ""
+    log_line = f"{Fore.CYAN} lr: {lr}"
+    log_line += f", betas: {betas}" if betas else string_empty
+    log_line += f", epsilon: {epsilon}" if epsilon else string_empty
+    log_line += f", weight_decay: {weight_decay}" if weight_decay else string_empty
+    log_line += f", d0: {d0}" if d0 else string_empty
+    log_line += f" *{Style.RESET_ALL}"
     logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
-    logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+    logging.info(log_line)
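Note: the rewritten log_optimizer reads hyperparameters back out of optimizer.param_groups instead of threading them through as arguments; every torch.optim.Optimizer mirrors its settings there, and dict.get returns None for keys a given optimizer does not define (such as d0, used by D-Adaptation-style optimizers). A quick self-contained illustration:

    import torch

    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(2))], lr=1e-4)
    group = opt.param_groups[0]
    print(group.get('lr'), group.get('betas'), group.get('eps'), group.get('d0'))
    # 0.0001 (0.9, 0.999) 1e-08 None  ('d0' exists only on optimizers that define it)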

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
+from functools import partial
 import os
 import pprint
 import sys
@@ -494,10 +495,13 @@ def main(args):
         logging.info("xformers disabled via arg, using attention slicing instead")
         unet.set_attention_slice("auto")
 
-    vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
+    use_bf16 = torch.cuda.is_bf16_supported()
+    amp_precision = torch.bfloat16 if use_bf16 else torch.float16
+    vae = vae.to(device, dtype=amp_precision if args.amp else torch.float32)
     unet = unet.to(device, dtype=torch.float32)
 
     if args.disable_textenc_training and args.amp:
-        text_encoder = text_encoder.to(device, dtype=torch.float16)
+        text_encoder = text_encoder.to(device, dtype=torch.float16 if not use_bf16 else torch.bfloat16)
     else:
         text_encoder = text_encoder.to(device, dtype=torch.float32)
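Note: only the frozen modules move to half precision here; the trained unet keeps fp32 master weights and gets its speedup from autocast instead. Storing a frozen module in bf16 halves its VRAM footprint, e.g. (toy module standing in for the VAE):

    import torch

    frozen = torch.nn.Linear(512, 512)  # stand-in for a frozen inference-only module
    fp32_bytes = sum(p.numel() * p.element_size() for p in frozen.parameters())
    frozen = frozen.to(torch.bfloat16)
    bf16_bytes = sum(p.numel() * p.element_size() for p in frozen.parameters())
    print(fp32_bytes, bf16_bytes)  # bf16 weights take half the fp32 bytes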
@@ -664,9 +668,9 @@ def main(args):
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
 
     # actual prediction function - shared between train and validate
-    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
+    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, dtype=torch.float16):
         with torch.no_grad():
-            with autocast(enabled=args.amp):
+            with autocast(enabled=args.amp, dtype=dtype):
                 pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
                 latents = vae.encode(pixel_values, return_dict=False)
             del pixel_values
@@ -704,7 +708,7 @@ def main(args):
             raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
         del noise, latents, cuda_caption
 
-        with autocast(enabled=args.amp):
+        with autocast(enabled=args.amp, dtype=dtype):
             #print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
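Note: torch.cuda.amp.autocast accepts an optional dtype argument (torch.float16 by default on CUDA), which is what lets the same forward pass run under either half format. Self-contained check, assuming a CUDA device is available:

    import torch
    from torch.cuda.amp import autocast

    x = torch.randn(4, 4, device="cuda")
    w = torch.randn(4, 4, device="cuda")
    with autocast(enabled=True, dtype=torch.bfloat16):
        y = x @ w
    print(y.dtype)  # torch.bfloat16: eligible ops ran in the requested dtype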
@@ -761,7 +765,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()
 
-            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
+            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, dtype=amp_precision)
 
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
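Note: the loss math stays in float32 regardless of the autocast dtype (hence the .float() casts); bf16 carries only 2-3 significant decimal digits, so computing a squared-error mean in it would add avoidable rounding noise. For a feel of the rounding:

    import torch

    print(torch.tensor(0.001, dtype=torch.bfloat16).item())  # ~0.0009995, bf16 rounds 0.001
    print(torch.tensor(0.001, dtype=torch.float32).item())   # 0.001 at fp32 precision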
@@ -806,7 +810,8 @@ def main(args):
                 torch.cuda.empty_cache()
 
             if validator and step in validation_steps:
-                validator.do_validation(global_step, get_model_prediction_and_target)
+                fn = partial(get_model_prediction_and_target, dtype=amp_precision)
+                validator.do_validation(global_step, fn)
 
             if (global_step + 1) % sample_generator.sample_steps == 0:
                 generate_samples(global_step=global_step, batch=batch)
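Note: functools.partial pre-binds the new dtype parameter so the validator can keep calling the prediction function with its original arity. Minimal illustration with a stand-in function:

    import torch
    from functools import partial

    def predict(image, tokens, zero_frequency_noise_ratio=0.0, dtype=torch.float16):
        return dtype  # stand-in for the real prediction logic

    fn = partial(predict, dtype=torch.bfloat16)
    print(fn("img", "tok"))  # torch.bfloat16; call sites stay two-argument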
@@ -930,5 +935,6 @@ if __name__ == "__main__":
     # load CLI args to overwrite existing config args
     args = argparser.parse_args(args=argv, namespace=args)
 
+    import multiprocessing
+    multiprocessing.set_start_method('spawn')
     main(args)
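Note: processes forked after CUDA is initialized inherit a broken CUDA context, so spawn is the safe start method once DataLoader workers (or any other child process) touch the GPU. It must be set once in the entry point, before any worker exists; a second call raises RuntimeError unless force=True is passed. Pattern:

    import multiprocessing

    if __name__ == "__main__":
        # set once, before creating any CUDA-using child process
        multiprocessing.set_start_method('spawn')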

@@ -21,9 +21,6 @@ class GPU:
     def __init__(self, device: torch.device):
         self.nvsmi = smi.getInstance()
         self.device = device
 
-    def __querythis(self, query):
-        return gpu_query['gpu'][self.device.index]
-
     def get_gpu_memory(self):
         """
@@ -39,10 +36,10 @@ class GPU:
         pynvml.nvmlInit()
         handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
         compute_compatibility = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+        pynvml.nvmlShutdown()
         return compute_compatibility[0] >= 8
 
     def driver_version(self):
         gpu_query = self.nvsmi.DeviceQuery('driver_version')
         driver_version = gpu_query['gpu'][self.device.index]['driver_version']
         return driver_version
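Note: compute capability major version 8 (Ampere) is the first generation with hardware bfloat16 support, which is what the >= 8 check encodes, and nvmlShutdown now releases the handle that nvmlInit acquired. A standalone version of the same query, assuming pynvml is installed and device 0:

    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
    pynvml.nvmlShutdown()
    print(f"sm_{major}{minor}, bf16 capable: {major >= 8}")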

@@ -3,7 +3,7 @@ import logging
 import os.path
 from dataclasses import dataclass
 import random
-from typing import Generator, Callable, Any
+from typing import Generator, Callable, Any, List
 
 import torch
 from PIL import Image, ImageDraw, ImageFont
@@ -74,7 +74,7 @@ class SampleGenerator:
     num_inference_steps: int = 30
     random_captions = False
-    sample_requests: [str]
+    sample_requests: List[str]
 
     log_folder: str
     log_writer: SummaryWriter
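Note: the old annotation [str] was never a valid type hint; it evaluates to a literal one-element list containing the str class, which type checkers reject. typing.List[str] (or the built-in list[str] on Python 3.9+) is the correct spelling:

    from typing import List

    sample_requests: List[str]   # correct
    # sample_requests: [str]     # wrong: the annotation object is the list [str]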
@@ -270,7 +270,7 @@ class SampleGenerator:
                 with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
                     f.write(str(batch[prompt_idx]))
 
-                tfimage = transforms.ToTensor()(result)
+                tfimage = transforms.ToTensor()(result).float()
                 if batch[prompt_idx].wants_random_caption:
                     self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
                 else:
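Note: SummaryWriter.add_image converts the tensor to a NumPy array, and NumPy has no bfloat16 dtype, so the explicit .float() plausibly guards against a bf16 tensor reaching the writer (with a PIL input, ToTensor already returns float32, making the cast a no-op otherwise). Sketch, assuming result is a PIL image:

    from torchvision import transforms
    from torch.utils.tensorboard import SummaryWriter

    tfimage = transforms.ToTensor()(result).float()  # force float32 before logging
    SummaryWriter("logs").add_image(tag="sample", img_tensor=tfimage, global_step=0)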