updating reqs

This commit is contained in:
Victor Hall 2023-11-02 21:54:29 -04:00
parent 3150f7d299
commit 4fb64fed66
4 changed files with 75 additions and 20 deletions

View File

@ -1,7 +1,7 @@
###################
# Builder Stage
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS builder
LABEL org.opencontainers.image.licenses="AGPL-1.0-only"
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS builder
LABEL org.opencontainers.image.licenses="AGPL-3.0-only"
ARG DEBIAN_FRONTEND=noninteractive
# Don't write .pyc bytecode
@ -32,13 +32,13 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ADD requirements-build.txt /build
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m venv ${VIRTUAL_ENV} && \
pip install -U -I torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118" && \
pip install -U -I torch==2.1.0+cu121 torchvision==0.16.0+cu121 --extra-index-url "https://download.pytorch.org/whl/cu121" && \
pip install -r requirements-build.txt && \
pip install --no-deps xformers==0.0.18
pip install --no-deps xformers==0.0.20.post7
###################
# Runtime Stage
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 as runtime
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 as runtime
# Use bash shell
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
@ -77,9 +77,9 @@ RUN echo "source ${VIRTUAL_ENV}/bin/activate" >> /root/.bashrc
# Workaround for:
# https://github.com/TimDettmers/bitsandbytes/issues/62
# https://github.com/TimDettmers/bitsandbytes/issues/73
ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudart.so
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libnvrtc.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvrtc.so
ENV LD_LIBRARY_PATH="/usr/local/cuda-12.1/targets/x86_64-linux/lib/"
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12.1.55 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudart.so
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libnvrtc.so.12.1.55 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libnvrtc.so
# Vast.ai SSH ignores ENV vars unless fully exported
# Exporting anything with an _ should cover the bases

View File

@ -42,11 +42,13 @@ class EveryDreamOptimizer():
text_encoder: text encoder model parameters
unet: unet model parameters
"""
def __init__(self, args, optimizer_config, text_encoder, unet, epoch_len):
def __init__(self, args, optimizer_config, text_encoder, unet, epoch_len, log_writer=None):
del optimizer_config["doc"]
print(f"\n raw optimizer_config:")
pprint.pprint(optimizer_config)
self.epoch_len = epoch_len
self.unet = unet # needed for weight norm logging, unet.parameters() has to be called again, Diffusers quirk
self.log_writer = log_writer
self.te_config, self.base_config = self.get_final_optimizer_configs(args, optimizer_config)
self.te_freeze_config = optimizer_config.get("text_encoder_freezing", {})
print(f" Final unet optimizer config:")
@ -57,10 +59,16 @@ class EveryDreamOptimizer():
self.grad_accum = args.grad_accum
self.clip_grad_norm = args.clip_grad_norm
self.apply_grad_scaler_step_tweaks = optimizer_config.get("apply_grad_scaler_step_tweaks", True)
self.log_grad_norm = optimizer_config.get("log_grad_norm", True)
self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
self.unet_params = unet.parameters()
with torch.no_grad():
log_action = lambda n, label: logging.info(f"{Fore.LIGHTBLUE_EX} {label} weight normal: {n}{Style.RESET_ALL}")
self._log_weight_normal(text_encoder.text_model.encoder.layers.parameters(), "text encoder", log_action)
self._log_weight_normal(unet.parameters(), "unet", log_action)
self.optimizers = []
self.optimizer_te, self.optimizer_unet = self.create_optimizers(args,
self.text_encoder_params,
@ -71,7 +79,6 @@ class EveryDreamOptimizer():
self.lr_schedulers = []
schedulers = self.create_lr_schedulers(args, optimizer_config)
self.lr_schedulers.extend(schedulers)
#print(self.lr_schedulers)
self.load(args.resume_ckpt)
@ -85,6 +92,34 @@ class EveryDreamOptimizer():
logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")
def _log_gradient_normal(self, parameters: Generator, label: str, log_action=None):
total_norm = self._get_norm(parameters, lambda p: p.grad.data)
log_action(total_norm, label)
def _log_weight_normal(self, parameters: Generator, label: str, log_action=None):
total_norm = self._get_norm(parameters, lambda p: p.data)
log_action(total_norm, label)
def _calculate_normal(param, param_type):
if param_type(param) is not None:
return param_type(param).norm(2).item() ** 2
else:
return 0.0
def _get_norm(self, parameters: Generator, param_type):
total_norm = 0
for p in parameters:
param = param_type(p)
total_norm += self._calculate_norm(param, p)
total_norm = total_norm ** (1. / 2)
return total_norm
def _calculate_norm(self, param, p):
if param is not None:
return param.norm(2).item() ** 2
else:
return 0.0
def step(self, loss, step, global_step):
self.scaler.scale(loss).backward()
@ -92,13 +127,31 @@ class EveryDreamOptimizer():
if self.clip_grad_norm is not None:
for optimizer in self.optimizers:
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
if self.log_grad_norm:
pre_clip_norm = torch.nn.utils.clip_grad_norm_(parameters=self.unet.parameters(), max_norm=float('inf'))
self.log_writer.add_scalar("optimizer/unet_pre_clip_norm", pre_clip_norm, global_step)
pre_clip_norm = torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=float('inf'))
self.log_writer.add_scalar("optimizer/te_pre_clip_norm", pre_clip_norm, global_step)
unet_grad_norm = torch.nn.utils.clip_grad_norm_(parameters=self.unet.parameters(), max_norm=self.clip_grad_norm)
self.log_writer.add_scalar("optimizer/unet_grad_norm", unet_grad_norm, global_step)
te_grad_norm = torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
self.log_writer.add_scalar("optimizer/te_grad_norm", te_grad_norm, global_step)
for optimizer in self.optimizers:
self.scaler.step(optimizer)
self.scaler.update()
if self.log_grad_norm and self.log_writer:
log_info_unet_fn = lambda n, label: self.log_writer.add_scalar(label, n, global_step)
log_info_te_fn = lambda n, label: self.log_writer.add_scalar(label, n, global_step)
with torch.no_grad():
self._log_gradient_normal(self.unet_params, "optimizer/unet_grad_norm", log_info_unet_fn)
self._log_gradient_normal(self.text_encoder_params, "optimizer/te_grad_norm", log_info_te_fn)
self._zero_grad(set_to_none=True)
for scheduler in self.lr_schedulers:
@ -278,6 +331,7 @@ class EveryDreamOptimizer():
no_prox = False # ????, dadapt_adan
use_bias_correction = True # suggest by prodigy github
growth_rate=float("inf") # dadapt various, no idea what a sane default is
safeguard_warmup = True # per recommendation from prodigy documentation
if local_optimizer_config is not None:
betas = local_optimizer_config.get("betas", betas)
@ -290,6 +344,7 @@ class EveryDreamOptimizer():
decouple = local_optimizer_config.get("decouple", decouple)
momentum = local_optimizer_config.get("momentum", momentum)
growth_rate = local_optimizer_config.get("growth_rate", growth_rate)
safeguard_warmup = local_optimizer_config.get("safeguard_warmup", safeguard_warmup)
if args.lr is not None:
curr_lr = args.lr
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
@ -322,7 +377,6 @@ class EveryDreamOptimizer():
elif optimizer_name == "prodigy":
from prodigyopt import Prodigy
opt_class = Prodigy
safeguard_warmup = True # per recommendation from prodigy documentation
optimizer = opt_class(
itertools.chain(parameters),
lr=curr_lr,
@ -449,6 +503,7 @@ class EveryDreamOptimizer():
unfreeze_final_layer_norm = not self.te_freeze_config["freeze_final_layer_norm"]
parameters = itertools.chain([])
if unfreeze_embeddings:
parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
else:

View File

@ -20,7 +20,7 @@
},
"base": {
"optimizer": "adamw8bit",
"lr": 2e-6,
"lr": 1e-6,
"lr_scheduler": "constant",
"lr_decay_steps": null,
"lr_warmup_steps": null,
@ -30,7 +30,7 @@
},
"text_encoder_overrides": {
"optimizer": null,
"lr": null,
"lr": 3e-7,
"lr_scheduler": null,
"lr_decay_steps": null,
"lr_warmup_steps": null,

View File

@ -1,7 +1,7 @@
torch==2.0.1
torchvision==0.15.2
transformers==4.29.2
diffusers[torch]==0.18.0
torch==2.1.0
torchvision==0.16.0
transformers==4.35.0
diffusers[torch]==0.21.4
pynvml==11.4.1
bitsandbytes==0.41.1
ftfy==6.1.1