wip for bf16 testing
commit 3cfecf8729
parent 2525bb1c5f
@@ -74,12 +74,18 @@ class EveryDreamOptimizer():
         self.load(args.resume_ckpt)

+        use_bf16 = torch.cuda.is_bf16_supported()
+        if use_bf16:
+            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+
+        init_scale = 2**17.5
+
         self.scaler = GradScaler(
-            enabled=args.amp,
-            init_scale=2**17.5,
+            enabled=args.amp and not use_bf16, # bfloat16 does not need grad scaler, same dynamic range as fp32
+            init_scale=init_scale,
             growth_factor=2,
             backoff_factor=1.0/2,
-            growth_interval=25,
+            growth_interval=50,
         )

         logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")
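bfloat16 keeps float32's 8-bit exponent, so gradients do not underflow the way float16 gradients can and loss scaling buys nothing; with enabled=False the GradScaler becomes a pass-through, so the rest of the training loop can keep calling scale()/step()/update() unchanged. A minimal standalone sketch of that pass-through behavior (illustration only, not part of this commit; assumes the usual torch.cuda.amp.GradScaler):

import torch
from torch.cuda.amp import GradScaler

# what the code above selects once bf16 is available
scaler = GradScaler(enabled=False)

loss = torch.tensor(2.0, requires_grad=True)
# a disabled scaler returns its input unmodified, and step()/update()
# degrade to a plain optimizer.step() / no-op, so call sites need no changes
assert scaler.scale(loss) is loss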
@@ -369,7 +375,7 @@ class EveryDreamOptimizer():
             amsgrad=False,
         )

-        log_optimizer(label, optimizer, betas, epsilon, weight_decay, curr_lr)
+        log_optimizer(label, optimizer)
         return optimizer

     def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
@@ -399,18 +405,34 @@ class EveryDreamOptimizer():
         return parameters


-def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+def log_optimizer(label: str, optimizer: torch.optim.Optimizer):
     """
     logs the optimizer settings
     """
+    # , betas, epsilon, weight_decay, curr_lr
     all_params = sum([g['params'] for g in optimizer.param_groups], [])
     frozen_parameter_count = len([p for p in all_params if not p.requires_grad])
     total_parameter_count = len(all_params)
     if frozen_parameter_count > 0:
-        param_info = f"({total_parameter_count} parameters, {frozen_parameter_count} frozen)"
+        param_info = f"({total_parameter_count} param_groups, {frozen_parameter_count} frozen)"
     else:
-        param_info = f"({total_parameter_count} parameters)"
+        param_info = f"({total_parameter_count} param_groups)"

+    #try get lr
+    lr = optimizer.param_groups[0].get('lr', None)
+    betas = optimizer.param_groups[0].get('betas', None)
+    epsilon = optimizer.param_groups[0].get('eps', None)
+    weight_decay = optimizer.param_groups[0].get('weight_decay', None)
+    d0 = optimizer.param_groups[0].get('d0', None)
+
+    string_empty = ""
+    log_line = f"{Fore.CYAN} lr: {lr}"
+    log_line += f", betas: {betas}" if betas else string_empty
+    log_line += f", epsilon: {epsilon}" if epsilon else string_empty
+    log_line += f", weight_decay: {weight_decay}" if weight_decay else string_empty
+    log_line += f", d0: {d0}" if d0 else string_empty
+    log_line += f" *{Style.RESET_ALL}"
+
     logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
-    logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+    logging.info(log_line)
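The reworked log_optimizer pulls lr, betas, eps, weight_decay and d0 straight out of optimizer.param_groups[0] instead of having the caller pass them in, which is why the call site above shrinks to log_optimizer(label, optimizer). A standalone sketch of that introspection (illustrative values, not from this commit):

import torch

params = [torch.nn.Parameter(torch.zeros(2, 2))]
opt = torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

group = opt.param_groups[0]
# the same keys the new log_optimizer reads; .get() tolerates optimizers that
# only define a subset of them (d0, for example, is a D-Adaptation-style hyperparameter)
print(group.get('lr'), group.get('betas'), group.get('eps'), group.get('weight_decay'), group.get('d0'))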

train.py (22 changed lines)

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from functools import partial
 import os
 import pprint
 import sys
@@ -494,10 +495,13 @@ def main(args):
         logging.info("xformers disabled via arg, using attention slicing instead")
         unet.set_attention_slice("auto")

-    vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
+    use_bf16 = torch.cuda.is_bf16_supported()
+    amp_precision = torch.bfloat16 if use_bf16 else torch.float16
+
+    vae = vae.to(device, dtype=amp_precision if args.amp else torch.float32)
     unet = unet.to(device, dtype=torch.float32)
     if args.disable_textenc_training and args.amp:
-        text_encoder = text_encoder.to(device, dtype=torch.float16)
+        text_encoder = text_encoder.to(device, dtype=torch.float16 if not use_bf16 else torch.bfloat16)
     else:
         text_encoder = text_encoder.to(device, dtype=torch.float32)
@@ -664,9 +668,9 @@ def main(args):
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"

     # actual prediction function - shared between train and validate
-    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
+    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, dtype=torch.float16):
         with torch.no_grad():
-            with autocast(enabled=args.amp):
+            with autocast(enabled=args.amp, dtype=dtype):
                 pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
                 latents = vae.encode(pixel_values, return_dict=False)
                 del pixel_values
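torch.cuda.amp.autocast accepts the target dtype directly, so the same prediction code serves both fp16 and bf16; pairing dtype=torch.bfloat16 with the disabled GradScaler above is the usual bf16 recipe. A small standalone sketch of the pattern (toy model and shapes, assumes a CUDA device; not part of this commit):

import torch
from torch.cuda.amp import autocast

model = torch.nn.Linear(16, 4).cuda()
x = torch.randn(8, 16, device="cuda")

dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
with autocast(enabled=True, dtype=dtype):
    y = model(x)
print(y.dtype)  # torch.bfloat16 on Ampere or newer, torch.float16 otherwise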
@@ -704,7 +708,7 @@ def main(args):
                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
             del noise, latents, cuda_caption

-        with autocast(enabled=args.amp):
+        with autocast(enabled=args.amp, dtype=dtype):
             #print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -761,7 +765,7 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             step_start_time = time.time()

-            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
+            model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, dtype=amp_precision)

             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

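Note that the loss line (unchanged here) already casts both tensors to float32 before the MSE reduction, so the loss is accumulated at full precision regardless of the autocast dtype chosen above. A tiny standalone illustration (synthetic tensors, not from this commit):

import torch
import torch.nn.functional as F

pred = torch.randn(4, 4, dtype=torch.bfloat16)
target = torch.randn(4, 4, dtype=torch.bfloat16)
loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
print(loss.dtype)  # torch.float32, even though the inputs were bf16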
@@ -806,7 +810,8 @@ def main(args):
                 torch.cuda.empty_cache()

             if validator and step in validation_steps:
-                validator.do_validation(global_step, get_model_prediction_and_target)
+                fn = partial(get_model_prediction_and_target, dtype=amp_precision)
+                validator.do_validation(global_step, fn)

             if (global_step + 1) % sample_generator.sample_steps == 0:
                 generate_samples(global_step=global_step, batch=batch)
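functools.partial lets the dtype chosen at startup be baked into the callback while the validator keeps calling it with its existing (image, tokens, ...) signature. A minimal sketch of the same pattern with a hypothetical callback (names are illustrative, not the trainer's):

from functools import partial
import torch

def predict(image, tokens, zero_frequency_noise_ratio=0.0, dtype=torch.float16):
    return f"predicting with {dtype}"

fn = partial(predict, dtype=torch.bfloat16)   # dtype is fixed up front
print(fn("img", "tok"))                       # the caller's positional args still work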
@@ -930,5 +935,6 @@ if __name__ == "__main__":

     # load CLI args to overwrite existing config args
     args = argparser.parse_args(args=argv, namespace=args)
-
+    import multiprocessing
+    multiprocessing.set_start_method('spawn')
     main(args)
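The 'spawn' start method is the usual choice when worker processes touch CUDA, since CUDA cannot be re-initialized in a forked child. A slightly more defensive variant (hypothetical, not part of this commit) guards against the start method having been set already:

import multiprocessing

if __name__ == "__main__":
    if multiprocessing.get_start_method(allow_none=True) != 'spawn':
        multiprocessing.set_start_method('spawn', force=True)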
@@ -21,9 +21,6 @@ class GPU:
     def __init__(self, device: torch.device):
         self.nvsmi = smi.getInstance()
         self.device = device

-    def __querythis(self, query):
-        return gpu_query['gpu'][self.device.index]
-
     def get_gpu_memory(self):
         """
@@ -39,10 +36,10 @@ class GPU:
         pynvml.nvmlInit()
         handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
         compute_compatibility = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
         pynvml.nvmlShutdown()
         return compute_compatibility[0] >= 8

     def driver_version(self):
         gpu_query = self.nvsmi.DeviceQuery('driver_version')
         driver_version = gpu_query['gpu'][self.device.index]['driver_version']
         return driver_version
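The capability check above maps to bf16 support: compute capability 8.0 (Ampere) and newer GPUs have native bfloat16, and torch exposes the same information through torch.cuda.is_bf16_supported(), which is what train.py and the optimizer now call. A standalone comparison of the two checks (assumes pynvml is installed and device 0 is the GPU in use; not part of this commit):

import torch
import pynvml

print(torch.cuda.is_bf16_supported())        # the check the rest of this commit relies on

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
pynvml.nvmlShutdown()
print(major >= 8)                            # roughly equivalent capability test (Ampere+)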
@@ -3,7 +3,7 @@ import logging
 import os.path
 from dataclasses import dataclass
 import random
-from typing import Generator, Callable, Any
+from typing import Generator, Callable, Any, List

 import torch
 from PIL import Image, ImageDraw, ImageFont
@@ -74,7 +74,7 @@ class SampleGenerator:
     num_inference_steps: int = 30
     random_captions = False

-    sample_requests: [str]
+    sample_requests: List[str]
     log_folder: str
     log_writer: SummaryWriter

@@ -270,7 +270,7 @@ class SampleGenerator:
                    with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
                        f.write(str(batch[prompt_idx]))

-                tfimage = transforms.ToTensor()(result)
+                tfimage = transforms.ToTensor()(result).float()
                 if batch[prompt_idx].wants_random_caption:
                     self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
                 else:
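The trailing .float() guards the TensorBoard logging path: once the default dtype has been switched to bfloat16 (as the optimizer does when bf16 is available), torchvision's ToTensor emits tensors in that default float dtype, and SummaryWriter.add_image ultimately converts to numpy, which has no bfloat16 type. A tiny standalone illustration (synthetic tensor, not from this commit):

import torch

t = torch.rand(3, 64, 64, dtype=torch.bfloat16)
# t.numpy() would raise because numpy has no bfloat16 dtype;
# casting to float32 first, as above, keeps add_image happy
img = t.float().cpu().numpy()
print(img.dtype)  # float32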