Added CoordinateDoWG and ScalarDoWG

This commit is contained in:
SargeZT 2023-06-08 10:39:46 -05:00
parent 7e09b6dc29
commit 4861d96ec2
2 changed files with 9 additions and 0 deletions

View File

@ -303,6 +303,14 @@ class EveryDreamOptimizer():
)
elif optimizer_name == "adamw":
opt_class = torch.optim.AdamW
if "dowg" in optimizer_name:
from dowg import CoordinateDoWG, ScalarDoWG
if optimizer_name == "coordinate_dowg":
opt_class = CoordinateDoWG
elif optimizer_name == "scalar_dowg":
opt_class = ScalarDoWG
else:
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are coordinate_dowg and scalar_dowg")
elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
import dadaptation

View File

@ -13,6 +13,7 @@ xformers==0.0.20
pytorch-lightning==1.6.5
OmegaConf==2.2.3
numpy==1.23.5
dowg
lion-pytorch
compel~=1.1.3
OmegaConf==2.2.3