Merge pull request #224 from a-l-e-x-d-s-9/main
EMA decay, min-SNR-gamma, settings loading every epoch, zero terminal SNR separation from ZFN.
This commit is contained in:
commit
49de395df1
|
@ -151,10 +151,6 @@ Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/z
|
|||
|
||||
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time.
|
||||
|
||||
## Zero terminal SNR
|
||||
|
||||
Set `zero_frequency_noise_ratio` to -1.
|
||||
|
||||
## Keeping images together (custom batching)
|
||||
|
||||
If you have a subset of your dataset that expresses the same style or concept, training quality may be improved by putting all of these images through the trainer together in the same batch or batches, instead of the default behaviour (which is to shuffle them randomly throughout the entire dataset).
|
||||
|
@ -205,3 +201,37 @@ Clips the gradient normals to a maximum value. Default is None (no clipping).
|
|||
|
||||
Default is no gradient normal clipping. There are also other ways to deal with gradient explosion, such as increasing optimizer epsilon.
|
||||
|
||||
## Zero Terminal SNR
|
||||
**Parameter:** `--enable_zero_terminal_snr`
|
||||
**Default:** `False`
|
||||
To enable zero terminal SNR.
|
||||
|
||||
## Dynamic Configuration Loading
|
||||
**Parameter:** `--load_settings_every_epoch`
|
||||
**Default:** `False`
|
||||
Most of the parameters in the train.json file CANNOT be modified during training. Activate this to have the `train.json` configuration file reloaded at the start of each epoch. The following parameter can be changed and will be applied after the start of a new epoch:
|
||||
- `--save_every_n_epochs`
|
||||
- `--save_ckpts_from_n_epochs`
|
||||
- `--save_full_precision`
|
||||
- `--save_optimizer`
|
||||
- `--zero_frequency_noise_ratio`
|
||||
- `--min_snr_gamma`
|
||||
- `--clip_skip`
|
||||
|
||||
## Min-SNR-Gamma Parameter
|
||||
**Parameter:** `--min_snr_gamma`
|
||||
**Recommended Values:** 5, 1, 20
|
||||
**Default:** `None`
|
||||
To enable min-SNR-Gamma. For an in-depth understanding, consult this [research paper](https://arxiv.org/abs/2303.09556).
|
||||
|
||||
## EMA Decay Features
|
||||
The Exponential Moving Average (EMA) model is copied from the base model at the start and is updated every interval of steps by a small contribution from training.
|
||||
In this mode, the EMA model will be saved alongside the regular checkpoint from training. Normal training checkpoint can be loaded with `--resume_ckpt`, and the EMA model can be loaded with `--ema_decay_resume_model`.
|
||||
**Parameters:**
|
||||
- `--ema_decay_rate`: Determines the EMA decay rate. It defines how much the EMA model is updated from training at each update. Values should be close to 1 but not exceed it. Activating this parameter triggers the EMA decay feature.
|
||||
- `--ema_decay_target`: Set the EMA decay target value within the (0,1) range. The `ema_decay_rate` is computed based on the relation: decay_rate to the power of (total_steps/decay_interval) equals decay_target. Enabling this parameter will override `ema_decay_rate` and will enable EMA decay feature.
|
||||
- `--ema_decay_interval`: Set the interval in steps between EMA decay updates. The update occurs at each `global_steps` modulo `decay_interval`.
|
||||
- `--ema_decay_device`: Choose between `cpu` and `cuda` for EMA decay. Opting for 'cpu' takes around 4 seconds per update and uses approximately 3.2GB RAM, while 'cuda' is much faster but requires a similar amount of VRAM.
|
||||
- `--ema_decay_sample_raw_training`: Activate to display samples from the trained model, mirroring conventional training. They will not be presented by default with EMA decay enabled.
|
||||
- `--ema_decay_sample_ema_model`: Turn on to exhibit samples from the EMA model. EMA models will be used for samples generations by default with EMA decay enabled, unless disabled.
|
||||
- `--ema_decay_resume_model`: Indicate the EMA decay checkpoint to continue from, working like `--resume_ckpt` but will load EMA model. Using `findlast` will only load EMA version and not regular training.
|
12
train.json
12
train.json
|
@ -40,5 +40,15 @@
|
|||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
"rated_dataset_target_dropout_percent": 50,
|
||||
"zero_frequency_noise_ratio": 0.02
|
||||
"zero_frequency_noise_ratio": 0.02,
|
||||
"enable_zero_terminal_snr": false,
|
||||
"load_settings_every_epoch": false,
|
||||
"min_snr_gamma": null,
|
||||
"ema_decay_rate": null,
|
||||
"ema_decay_target": null,
|
||||
"ema_decay_interval": null,
|
||||
"ema_decay_device": null,
|
||||
"ema_decay_sample_raw_training": false,
|
||||
"ema_decay_sample_ema_model": false,
|
||||
"ema_decay_resume_model" : null
|
||||
}
|
||||
|
|
425
train.py
425
train.py
|
@ -61,6 +61,7 @@ from utils.convert_diff_to_ckpt import convert as converter
|
|||
from utils.isolate_rng import isolate_rng
|
||||
from utils.check_git import check_git
|
||||
from optimizer.optimizers import EveryDreamOptimizer
|
||||
from copy import deepcopy
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
|
@ -186,7 +187,7 @@ def set_args_12gb(args):
|
|||
logging.info(" - Overiding resolution to max 512")
|
||||
args.resolution = 512
|
||||
|
||||
def find_last_checkpoint(logdir):
|
||||
def find_last_checkpoint(logdir, is_ema=False):
|
||||
"""
|
||||
Finds the last checkpoint in the logdir, recursively
|
||||
"""
|
||||
|
@ -196,6 +197,12 @@ def find_last_checkpoint(logdir):
|
|||
for root, dirs, files in os.walk(logdir):
|
||||
for file in files:
|
||||
if os.path.basename(file) == "model_index.json":
|
||||
|
||||
if is_ema and (not root.endswith("_ema")):
|
||||
continue
|
||||
elif (not is_ema) and root.endswith("_ema"):
|
||||
continue
|
||||
|
||||
curr_date = os.path.getmtime(os.path.join(root,file))
|
||||
|
||||
if last_date is None or curr_date > last_date:
|
||||
|
@ -228,6 +235,11 @@ def setup_args(args):
|
|||
# find the last checkpoint in the logdir
|
||||
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
||||
|
||||
if (args.ema_decay_resume_model != None) and (args.ema_decay_resume_model == "findlast"):
|
||||
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
|
||||
|
||||
args.ema_decay_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
|
||||
|
||||
if args.lowvram:
|
||||
set_args_12gb(args)
|
||||
|
||||
|
@ -356,6 +368,75 @@ def log_args(log_writer, args):
|
|||
arglog += f"{arg}={value}, "
|
||||
log_writer.add_text("config", arglog)
|
||||
|
||||
def update_ema(model, ema_model, decay, default_device, ema_device):
|
||||
|
||||
with torch.no_grad():
|
||||
original_model_on_proper_device = model
|
||||
need_to_delete_original = False
|
||||
if ema_device != default_device:
|
||||
original_model_on_other_device = deepcopy(model)
|
||||
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
|
||||
del original_model_on_other_device
|
||||
need_to_delete_original = True
|
||||
|
||||
params = dict(original_model_on_proper_device.named_parameters())
|
||||
ema_params = dict(ema_model.named_parameters())
|
||||
|
||||
for name in ema_params:
|
||||
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
|
||||
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
|
||||
|
||||
if need_to_delete_original:
|
||||
del(original_model_on_proper_device)
|
||||
|
||||
def compute_snr(timesteps, noise_scheduler):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
minimal_value = 1e-9
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
# Use .any() to check if any elements in the tensor are zero
|
||||
if (alphas_cumprod[:-1] == 0).any():
|
||||
logging.warning(
|
||||
f"Alphas cumprod has zero elements! Resetting to {minimal_value}.."
|
||||
)
|
||||
alphas_cumprod[alphas_cumprod[:-1] == 0] = minimal_value
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
|
||||
timesteps
|
||||
].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
|
||||
device=timesteps.device
|
||||
)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR, first without epsilon
|
||||
snr = (alpha / sigma) ** 2
|
||||
# Check if the first element in SNR tensor is zero
|
||||
if torch.any(snr == 0):
|
||||
snr[snr == 0] = minimal_value
|
||||
return snr
|
||||
|
||||
def load_train_json_from_file(args, report_load = False):
|
||||
try:
|
||||
if report_load:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
|
||||
with open(args.config, 'rt') as f:
|
||||
read_json = json.load(f)
|
||||
|
||||
args.__dict__.update(read_json)
|
||||
except Exception as config_read:
|
||||
print(f"Error on loading training config from {args.config}.")
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
|
@ -384,22 +465,86 @@ def main(args):
|
|||
device = 'cpu'
|
||||
gpu = None
|
||||
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
|
||||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
||||
def release_memory(model_to_delete, original_device):
|
||||
del model_to_delete
|
||||
gc.collect()
|
||||
|
||||
if 'cuda' in original_device.type:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
||||
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||
|
||||
nonlocal unet
|
||||
nonlocal text_encoder
|
||||
nonlocal unet_ema
|
||||
nonlocal text_encoder_ema
|
||||
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
|
||||
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
|
||||
nonlocal save_ckpt_dir
|
||||
nonlocal save_full_precision
|
||||
nonlocal yaml_name
|
||||
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
|
||||
half = not save_full_precision
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
|
||||
global global_step
|
||||
|
||||
if global_step is None or global_step == 0:
|
||||
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
||||
return
|
||||
logging.info(f" * Saving diffusers model to {save_path}")
|
||||
|
||||
|
||||
if args.ema_decay_rate != None:
|
||||
|
||||
|
||||
pipeline_ema = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder_ema,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet_ema,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
|
||||
diffusers_model_path = save_path + "_ema"
|
||||
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
|
||||
pipeline_ema.save_pretrained(diffusers_model_path)
|
||||
|
||||
if save_ckpt:
|
||||
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
|
||||
|
||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
||||
|
||||
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -410,30 +555,28 @@ def main(args):
|
|||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
pipeline.save_pretrained(save_path)
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
|
||||
diffusers_model_path = save_path
|
||||
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
|
||||
pipeline.save_pretrained(diffusers_model_path)
|
||||
|
||||
if save_ckpt:
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
|
||||
half = not save_full_precision
|
||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
if save_optimizer_flag:
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
ed_optimizer.save(save_path)
|
||||
|
||||
|
||||
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
|
||||
ema_decay_model_loaded_from_file = False
|
||||
|
||||
if use_ema_dacay_training:
|
||||
ema_device = torch.device(args.ema_decay_device)
|
||||
|
||||
optimizer_state_path = None
|
||||
try:
|
||||
# check for a local file
|
||||
|
@ -458,16 +601,51 @@ def main(args):
|
|||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
# leave the inference scheduler alone
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
if use_ema_dacay_training and args.ema_decay_resume_model:
|
||||
print(f"Loading EMA model: {args.ema_decay_resume_model}")
|
||||
ema_decay_model_loaded_from_file=True
|
||||
hf_cache_path = get_hf_ckpt_cache_path(args.ema_decay_resume_model)
|
||||
|
||||
if args.zero_frequency_noise_ratio == -1.0:
|
||||
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
|
||||
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_decay_resume_model):
|
||||
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.resume_ckpt)
|
||||
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
|
||||
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
|
||||
|
||||
else:
|
||||
# try to download from HF using ema_decay_resume_model as a repo id
|
||||
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_decay_resume_model)
|
||||
if ema_downloaded is None:
|
||||
raise ValueError(
|
||||
f"No local file/folder for ema_decay_resume_model {args.ema_decay_resume_model}, and no matching huggingface.co repo could be downloaded")
|
||||
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
|
||||
text_encoder_ema = ema_pipe.text_encoder
|
||||
unet_ema = ema_pipe.unet
|
||||
del ema_pipe
|
||||
|
||||
# Make sure EMA model is on proper device, and memory released if moved
|
||||
unet_ema_current_device = next(unet_ema.parameters()).device
|
||||
if ema_device != unet_ema_current_device:
|
||||
unet_ema_on_wrong_device = unet_ema
|
||||
unet_ema = unet_ema.to(ema_device)
|
||||
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
|
||||
|
||||
# Make sure EMA model is on proper device, and memory released if moved
|
||||
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
|
||||
if ema_device != text_encoder_ema_current_device:
|
||||
text_encoder_ema_on_wrong_device = text_encoder_ema
|
||||
text_encoder_ema = text_encoder_ema.to(ema_device)
|
||||
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
|
||||
|
||||
|
||||
if args.enable_zero_terminal_snr:
|
||||
# Use zero terminal SNR
|
||||
from utils.unet_utils import enforce_zero_terminal_snr
|
||||
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
else:
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
|
||||
|
@ -504,6 +682,30 @@ def main(args):
|
|||
else:
|
||||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
|
||||
if use_ema_dacay_training:
|
||||
if not ema_decay_model_loaded_from_file:
|
||||
logging.info(f"EMA decay enabled, creating EMA model.")
|
||||
|
||||
with torch.no_grad():
|
||||
if args.ema_decay_device == device:
|
||||
unet_ema = deepcopy(unet)
|
||||
text_encoder_ema = deepcopy(text_encoder)
|
||||
else:
|
||||
unet_ema_first = deepcopy(unet)
|
||||
text_encoder_ema_first = deepcopy(text_encoder)
|
||||
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
|
||||
del unet_ema_first
|
||||
del text_encoder_ema_first
|
||||
else:
|
||||
# Make sure correct types are used for models
|
||||
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
|
||||
|
||||
|
||||
try:
|
||||
#unet = torch.compile(unet)
|
||||
#text_encoder = torch.compile(text_encoder)
|
||||
|
@ -575,6 +777,20 @@ def main(args):
|
|||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
|
||||
if use_ema_dacay_training:
|
||||
if args.ema_decay_target != None:
|
||||
# Value in range (0,1). decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.
|
||||
total_number_of_steps: float = epoch_len * args.max_epochs
|
||||
total_number_of_ema_update: float = total_number_of_steps / args.ema_decay_interval
|
||||
args.ema_decay_rate = args.ema_decay_target ** (1 / total_number_of_ema_update)
|
||||
|
||||
logging.info(f"ema_decay_target is {args.ema_decay_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
|
||||
|
||||
logging.info(
|
||||
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_decay_interval: {args.ema_decay_interval}, ema_decay_device: {args.ema_decay_device}.")
|
||||
|
||||
|
||||
ed_optimizer = EveryDreamOptimizer(args,
|
||||
optimizer_config,
|
||||
text_encoder,
|
||||
|
@ -590,7 +806,7 @@ def main(args):
|
|||
default_sample_steps=args.sample_steps,
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers,
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2),
|
||||
guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
|
||||
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
|
||||
)
|
||||
|
||||
"""
|
||||
|
@ -622,7 +838,9 @@ def main(args):
|
|||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(interrupted_checkpoint_path, tokenizer, noise_scheduler, vae,
|
||||
ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer,
|
||||
save_ckpt=not args.no_save_ckpt)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
else:
|
||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||
|
@ -670,7 +888,7 @@ def main(args):
|
|||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||
|
||||
# actual prediction function - shared between train and validate
|
||||
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
|
||||
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
|
||||
with torch.no_grad():
|
||||
with autocast(enabled=args.amp):
|
||||
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
|
@ -678,12 +896,13 @@ def main(args):
|
|||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
if zero_frequency_noise_ratio > 0.0:
|
||||
if zero_frequency_noise_ratio != None:
|
||||
if zero_frequency_noise_ratio < 0:
|
||||
zero_frequency_noise_ratio = 0
|
||||
|
||||
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
||||
noise = torch.randn_like(latents) + zero_frequency_noise
|
||||
else:
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
bsz = latents.shape[0]
|
||||
|
||||
|
@ -714,9 +933,35 @@ def main(args):
|
|||
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if return_loss:
|
||||
if args.min_snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
else:
|
||||
snr = compute_snr(timesteps, noise_scheduler)
|
||||
|
||||
mse_loss_weights = (
|
||||
torch.stack(
|
||||
[snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1
|
||||
).min(dim=1)[0]
|
||||
/ snr
|
||||
)
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
return model_pred, target, loss
|
||||
|
||||
else:
|
||||
return model_pred, target
|
||||
|
||||
def generate_samples(global_step: int, batch):
|
||||
nonlocal unet
|
||||
nonlocal text_encoder
|
||||
nonlocal unet_ema
|
||||
nonlocal text_encoder_ema
|
||||
|
||||
with isolate_rng():
|
||||
prev_sample_steps = sample_generator.sample_steps
|
||||
sample_generator.reload_config()
|
||||
|
@ -725,20 +970,76 @@ def main(args):
|
|||
print(f" * SampleGenerator config changed, now generating images samples every " +
|
||||
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
models_info = []
|
||||
|
||||
if (args.ema_decay_rate is None) or args.ema_decay_sample_raw_training:
|
||||
models_info.append({"is_ema": False, "swap_required": False})
|
||||
|
||||
if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
|
||||
models_info.append({"is_ema": True, "swap_required": ema_device != device})
|
||||
|
||||
for model_info in models_info:
|
||||
|
||||
extra_info: str = ""
|
||||
|
||||
if model_info["is_ema"]:
|
||||
current_unet, current_text_encoder = unet_ema, text_encoder_ema
|
||||
extra_info = "_ema"
|
||||
else:
|
||||
current_unet, current_text_encoder = unet, text_encoder
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet_unloaded = unet.to(ema_device)
|
||||
del unet
|
||||
text_encoder_unloaded = text_encoder.to(ema_device)
|
||||
del text_encoder
|
||||
|
||||
current_unet = unet_ema.to(device)
|
||||
del unet_ema
|
||||
current_text_encoder = text_encoder_ema.to(device)
|
||||
del text_encoder_ema
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
|
||||
text_encoder=current_text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=inference_scheduler.config,
|
||||
diffusers_scheduler_config=inference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
|
||||
|
||||
# Cleanup
|
||||
del inference_pipe
|
||||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet = unet_unloaded.to(device)
|
||||
del unet_unloaded
|
||||
text_encoder = text_encoder_unloaded.to(device)
|
||||
del text_encoder_unloaded
|
||||
|
||||
unet_ema = current_unet.to(ema_device)
|
||||
del current_unet
|
||||
text_encoder_ema = current_text_encoder.to(ema_device)
|
||||
del current_text_encoder
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def make_save_path(epoch, global_step, prepend=""):
|
||||
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
|
||||
|
||||
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
validator.do_validation(global_step=0,
|
||||
|
@ -763,6 +1064,11 @@ def main(args):
|
|||
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
||||
|
||||
for epoch in range(args.max_epochs):
|
||||
|
||||
if args.load_settings_every_epoch:
|
||||
load_train_json_from_file(args)
|
||||
|
||||
|
||||
plugin_runner.run_on_epoch_start(epoch=epoch,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
|
@ -790,9 +1096,7 @@ def main(args):
|
|||
log_folder=log_folder,
|
||||
batch=batch)
|
||||
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
|
||||
|
||||
del target, model_pred
|
||||
|
||||
|
@ -802,6 +1106,21 @@ def main(args):
|
|||
|
||||
ed_optimizer.step(loss, step, global_step)
|
||||
|
||||
if args.ema_decay_rate != None:
|
||||
if ((global_step + 1) % args.ema_decay_interval) == 0:
|
||||
# debug_start_time = time.time() # Measure time
|
||||
|
||||
if args.disable_unet_training != True:
|
||||
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
|
||||
|
||||
if args.disable_textenc_training != True:
|
||||
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
|
||||
|
||||
# debug_end_time = time.time() # Measure time
|
||||
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
|
||||
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
|
||||
|
||||
|
||||
loss_step = loss.detach().item()
|
||||
|
||||
steps_pbar.set_postfix({"loss/step": loss_step}, {"gs": global_step})
|
||||
|
@ -846,12 +1165,16 @@ def main(args):
|
|||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
|
||||
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
|
||||
save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
|
||||
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
|
||||
save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
plugin_runner.run_on_step_end(epoch=epoch,
|
||||
global_step=global_step,
|
||||
|
@ -890,7 +1213,8 @@ def main(args):
|
|||
# end of training
|
||||
epoch = args.max_epochs
|
||||
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
|
||||
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -900,7 +1224,8 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = make_save_path(epoch, global_step, prepend="errored-")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
|
||||
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
|
||||
raise ex
|
||||
|
||||
|
@ -916,14 +1241,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||
args, argv = argparser.parse_known_args()
|
||||
|
||||
if args.config is not None:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
with open(args.config, 'rt') as f:
|
||||
args.__dict__.update(json.load(f))
|
||||
if len(argv) > 0:
|
||||
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
|
||||
else:
|
||||
print("No config file specified, using command line args")
|
||||
load_train_json_from_file(args, report_load=True)
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
|
@ -972,7 +1290,18 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
|
||||
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
|
||||
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
|
||||
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
|
||||
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
|
||||
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide ema_decay_rate.")
|
||||
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
|
||||
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
|
||||
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
|
||||
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
|
||||
argparser.add_argument("--ema_decay_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
|
||||
|
||||
|
||||
# load CLI args to overwrite existing config args
|
||||
args = argparser.parse_args(args=argv, namespace=args)
|
||||
|
|
|
@ -39,5 +39,15 @@
|
|||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
"rated_dataset_target_dropout_percent": 50,
|
||||
"zero_frequency_noise_ratio": 0.02
|
||||
"zero_frequency_noise_ratio": 0.02,
|
||||
"enable_zero_terminal_snr": false,
|
||||
"load_settings_every_epoch": false,
|
||||
"min_snr_gamma": null,
|
||||
"ema_decay_rate": null,
|
||||
"ema_decay_target": null,
|
||||
"ema_decay_interval": null,
|
||||
"ema_decay_device": null,
|
||||
"ema_decay_sample_raw_training": false,
|
||||
"ema_decay_sample_ema_model": false,
|
||||
"ema_decay_resume_model" : null
|
||||
}
|
||||
|
|
|
@ -181,7 +181,7 @@ class SampleGenerator:
|
|||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
|
@ -269,15 +269,15 @@ class SampleGenerator:
|
|||
prompt = prompts[prompt_idx]
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
f.write(str(batch[prompt_idx]))
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if batch[prompt_idx].wants_random_caption:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
|
||||
else:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
sample_index += 1
|
||||
|
||||
del result
|
||||
|
|
Loading…
Reference in New Issue