Added CoordinateDoWG and ScalarDoWG
This commit is contained in:
parent
7e09b6dc29
commit
4861d96ec2
|
@ -303,6 +303,14 @@ class EveryDreamOptimizer():
|
|||
)
|
||||
elif optimizer_name == "adamw":
|
||||
opt_class = torch.optim.AdamW
|
||||
if "dowg" in optimizer_name:
|
||||
from dowg import CoordinateDoWG, ScalarDoWG
|
||||
if optimizer_name == "coordinate_dowg":
|
||||
opt_class = CoordinateDoWG
|
||||
elif optimizer_name == "scalar_dowg":
|
||||
opt_class = ScalarDoWG
|
||||
else:
|
||||
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are coordinate_dowg and scalar_dowg")
|
||||
elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
|
||||
import dadaptation
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ xformers==0.0.20
|
|||
pytorch-lightning==1.6.5
|
||||
OmegaConf==2.2.3
|
||||
numpy==1.23.5
|
||||
dowg
|
||||
lion-pytorch
|
||||
compel~=1.1.3
|
||||
OmegaConf==2.2.3
|
||||
|
|
Loading…
Reference in New Issue