added some resolutions, option for val-loss pos-neg, fix wandb

Victor Hall 2023-03-25 20:09:06 -04:00
parent 3744bc0dc9
commit 35d52b56e0
7 changed files with 110 additions and 38 deletions

View File

@@ -13,6 +13,48 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Notes:
this is generated from an excel sheet and the actual ratios are hand picked to
spread the ratios out evenly and avoid super-finely defined buckets.
Too many buckets means more "runt" steps, where images are repeated to fill out a batch, than necessary.
ex. we do not need both 1.0:1 and 1.125:1; they're almost identical ratios.
Try to keep around <20 ratio buckets per resolution; that should be plenty of coverage for everything between 1:1 and 4:1.
More finely defined buckets reduce cropping at the expense of more runt steps.
"""
ASPECTS_1536 = [[1536,1536], # 2359296 1:1
[1728,1344],[1344,1728], # 2322432 1.286:1
[1792,1280],[1280,1792], # 2293760 1.4:1
[2048,1152],[1152,2048], # 2359296 1.778:1
[2304,1024],[1024,2304], # 2359296 2.25:1
[2432,960],[960,2432], # 2334720 2.53:1
[2624,896],[896,2624], # 2351104 2.929:1
[2816,832],[832,2816], # 2342912 3.385:1
[3072,768],[768,3072], # 2359296 4:1
]
ASPECTS_1408 = [[1408,1408], # 1982464 1:1
[1536,1280],[1280,1536], # 1966080 1.2:1
[1664,1152],[1152,1664], # 1916928 1.444:1
[1920,1024],[1024,1920], # 1966080 1.875:1
[2048,960],[960,2048], # 1966080 2.133:1
[2368,832],[832,2368], # 1970176 2.846:1
[2560,768],[768,2560], # 1966080 3.333:1
[2816,704],[704,2816], # 1982464 4:1
]
ASPECTS_1280 = [[1280,1280], # 1638400 1:1
[1408,1152],[1152,1408], # 1622016 1.222:1
[1600,1024],[1024,1600], # 1638400 1.563:1
[1792,896],[896,1792], # 1605632 2:1
[1920,832],[832,1920], # 1597440 2.308:1
[2112,768],[768,2112], # 1622016 2.75:1
[2304,704],[704,2304], # 1622016 3.27:1
[2560,640],[640,2560], # 1638400 4:1
]
ASPECTS_1152 = [[1152,1152], # 1327104 1:1
#[1216,1088],[1088,1216], # 1323008 1.118:1
[1280,1024],[1024,1280], # 1310720 1.25:1
@@ -48,7 +90,7 @@ ASPECTS_1024 = [[1024,1024], # 1048576 1:1
]
ASPECTS_960 = [[960,960], # 921600 1:1
[1024,896],[896,1024], # 917504 1.143:1
#[1024,896],[896,1024], # 917504 1.143:1
[1088,832],[832,1088], # 905216 1.308:1
[1152,768],[768,1152], # 884736 1.5:1
[1280,704],[704,1280], # 901120 1.818:1
@@ -56,11 +98,11 @@ ASPECTS_960 = [[960,960], # 921600 1:1
[1600,576],[576,1600], # 921600 2.778:1
#[1728,512],[512,1728], # 884736 3.375:1
[1792,512],[512,1792], # 917504 3.5:1
[2048,448],[448,2048], # 917504 4.714:1
[2048,448],[448,2048], # 917504 4.57:1
]
ASPECTS_896 = [[896,896], # 802816 1:1
[960,832],[832,960], # 798720 1.153:1
#[960,832],[832,960], # 798720 1.153:1
[1024,768],[768,1024], # 786432 1.333:1
[1088,704],[704,1088], # 765952 1.545:1
[1216,640],[640,1216], # 778240 1.9:1
@@ -155,7 +197,7 @@ ASPECTS_384 = [[384,384], # 147456 1:1
ASPECTS_256 = [[256,256], # 65536 1:1
[384,192],[192,384], # 73728 2:1
[512,128],[128,512], # 65536 4:1
]
] # very few buckets available for 256 with 64 pixel increments
def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
if resolution < 256:
@@ -173,6 +215,10 @@ def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
print(f" *** Unsupported resolution of {resolution}, check your resolution config")
print(f" *** Value must be between 512 and 1024")
raise e
def get_supported_resolutions():
all_image_sizes = __get_all_aspects()
return list(map(lambda sizes: sizes[0][0], all_image_sizes))
def __get_all_aspects():
return [ASPECTS_256,
@@ -188,5 +234,7 @@ def __get_all_aspects():
ASPECTS_960,
ASPECTS_1024,
ASPECTS_1088,
ASPECTS_1152
ASPECTS_1152,
ASPECTS_1280,
ASPECTS_1536,
]
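Editor's note: as a rough illustration of how these bucket tables get consumed at load time, here is a minimal sketch of picking the bucket whose ratio best matches a source image. Only get_aspect_buckets() above is from this file; the `data.aspects` import path and the `find_closest_bucket` helper name are assumptions for illustration, not part of this commit.
# Illustrative sketch only -- not part of this commit.
from data import aspects  # assumed import path for the file shown above

def find_closest_bucket(width: int, height: int, resolution: int = 512) -> list[int]:
    """Return the [w, h] bucket whose aspect ratio best matches the source image."""
    target = width / height
    # list of [w, h] pairs for this training resolution, per the tables above
    buckets = aspects.get_aspect_buckets(resolution)
    return min(buckets, key=lambda wh: abs(wh[0] / wh[1] - target))

# ex. a 1920x1080 (1.778:1) photo maps to the bucket nearest 16:9 at the chosen resolution
print(find_closest_bucket(1920, 1080, resolution=512))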

View File

@@ -40,12 +40,12 @@ class EveryDreamValidator:
val_config_path: Optional[str],
default_batch_size: int,
resolution: int,
log_writer: SummaryWriter):
log_writer: SummaryWriter,
):
self.val_dataloader = None
self.train_overlapping_dataloader = None
self.log_writer = log_writer
self.resolution = resolution
self.log_writer = log_writer
self.config = {
'batch_size': default_batch_size,
@@ -57,7 +59,9 @@ class EveryDreamValidator:
'val_split_proportion': 0.15,
'stabilize_training_loss': False,
'stabilize_split_proportion': 0.15
'stabilize_split_proportion': 0.15,
'use_relative_loss': False,
}
if val_config_path is not None:
with open(val_config_path, 'rt') as f:
@@ -67,7 +69,7 @@ class EveryDreamValidator:
self.val_loss_offset = None
self.loss_val_history = []
self.val_loss_window_size = 4 # todo: arg for this?
self.val_loss_window_size = 5 # todo: arg for this?
@property
def batch_size(self):
@@ -80,6 +82,10 @@ class EveryDreamValidator:
@property
def seed(self):
return self.config['seed']
@property
def use_relative_loss(self):
return self.config['use_relative_loss']
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
"""
@@ -117,14 +123,19 @@ class EveryDreamValidator:
if self.val_loss_offset is None:
self.val_loss_offset = -mean_loss
self.log_writer.add_scalar(tag=f"loss/val",
scalar_value=self.val_loss_offset + mean_loss,
scalar_value=mean_loss if not self.use_relative_loss else self.val_loss_offset + mean_loss,
global_step=global_step)
self.loss_val_history.append(mean_loss)
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
if np.average(dy) > 0:
logging.warning(f"Validation loss shows diverging")
# todo: signal stop?
self.track_loss_trend(mean_loss)
def track_loss_trend(self, mean_loss):
self.loss_val_history.append(mean_loss)
if len(self.loss_val_history) > ((self.val_loss_window_size * 2) + 1):
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
if np.average(dy) > 0:
logging.warning(f"Validation loss shows diverging. Check your val/loss graph.")
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
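Editor's note: a standalone sketch of the divergence check added in track_loss_trend() above. The window size mirrors the commit; the loss history values are made up for the example.
# Illustrative numbers only; mirrors the np.diff window check added above.
import numpy as np

val_loss_window_size = 5
loss_val_history = [0.120, 0.118, 0.117, 0.119, 0.121, 0.122,
                    0.124, 0.125, 0.127, 0.129, 0.131, 0.132]

if len(loss_val_history) > ((val_loss_window_size * 2) + 1):
    dy = np.diff(loss_val_history[-val_loss_window_size:])  # step-to-step change over the window
    if np.average(dy) > 0:                                  # average change is upward -> diverging
        print("Validation loss appears to be diverging. Check your val/loss graph.")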

View File

@@ -1,6 +1,6 @@
aiohttp==3.8.4
colorama==0.4.6
diffusers[torch]>=0.13.0
diffusers[torch]>=0.14.0
ftfy==6.1.1
ipyevents
ipywidgets
@@ -14,7 +14,7 @@ pyfakefs
pynvml==11.5.0
pyre-extensions==0.0.30
pytorch-lightning==1.9.2
tensorboard==2.11.0
tensorboard==2.12.0
transformers==4.25.1
triton>=2.0.0a2
wandb

View File

@@ -13,5 +13,5 @@
"betas": [0.9, 0.999],
"epsilon": 1e-8,
"weight_decay": 0.010,
"text_encoder_lr_scale": 1.0
"text_encoder_lr_scale": 0.50
}

View File

@@ -49,6 +49,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from accelerate.utils import set_seed
import wandb
import webbrowser
from torch.utils.tensorboard import SummaryWriter
from data.data_loader import DataLoaderMultiAspect
@@ -120,8 +121,12 @@ def setup_local_logger(args):
format="%(asctime)s %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
console_handler = logging.StreamHandler(sys.stdout)
# drop the noisy PIL "Palette images with Transparency" message from console output
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
logging.getLogger().addHandler(console_handler)
import warnings
warnings.filterwarnings("ignore", message="Palette images with Transparency expressed in bytes should be converted to RGBA images")
#from PIL import Image
return datetimestamp
@@ -138,7 +143,7 @@ def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
"""
torch.save(optimizer.state_dict(), path)
def load_optimizer(optimizer, path: str):
def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
"""
Loads the optimizer state
"""
@@ -491,17 +496,24 @@ def main(args):
with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
optimizer_config = json.load(f)
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name,
sync_tensorboard=True,
dir=args.logdir,
config=args,
name=args.run_name,
)
if args.wandb:
wandb.tensorboard.patch(root_logdir=log_folder, pytorch=False, tensorboard_x=False, save=False)
wandb_run = wandb.init(
project=args.project_name,
config={"main_cfg": vars(args), "optimizer_cfg": optimizer_config},
name=args.run_name,
#sync_tensorboard=True, # broken?
#dir=log_folder, # only for save, just duplicates the TB log to /{log_folder}/wandb ...
)
try:
if webbrowser.get():
webbrowser.open(wandb_run.url, new=2)
except Exception:
pass
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=10,
comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
flush_secs=20,
comment=args.run_name if args.run_name is not None else log_time,
)
betas = [0.9, 0.999]
@@ -929,8 +941,7 @@ def main(args):
if __name__ == "__main__":
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
supported_precisions = ['fp16', 'fp32']
supported_resolutions = aspects.get_supported_resolutions()
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, argv = argparser.parse_known_args()

View File

@@ -7,7 +7,8 @@
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
"seed": "The seed to use when running validation and stabilization passes."
"seed": "The seed to use when running validation and stabilization passes.",
"use_relative_loss": "logs val/loss as negative relative to first pre-train val/loss value"
},
"validate_training": true,
"val_split_mode": "automatic",
@@ -16,5 +17,6 @@
"stabilize_training_loss": false,
"stabilize_split_proportion": 0.15,
"every_n_epochs": 1,
"seed": 555
"seed": 555,
"use_relative_loss": false
}
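Editor's note: a small sketch of what use_relative_loss changes in the value logged to loss/val, matching the EveryDreamValidator change above. The loss numbers here are made up for the example.
# Illustrative values only; mirrors the val_loss_offset handling in EveryDreamValidator.
first_val_loss = 0.142               # mean val loss from the first validation pass
val_loss_offset = -first_val_loss    # captured once, before training moves the weights

def logged_value(mean_loss, use_relative_loss):
    # absolute loss when False, loss relative to the first pass when True
    return val_loss_offset + mean_loss if use_relative_loss else mean_loss

print(logged_value(0.131, use_relative_loss=False))  # 0.131
print(logged_value(0.131, use_relative_loss=True))   # about -0.011 (lower than the first pass)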

View File

@@ -4,7 +4,7 @@ echo should be in venv here
cd .
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
pip install transformers==4.25.1
pip install transformers==4.27.1
pip install diffusers[torch]==0.13.0
pip install pynvml==11.4.1
pip install bitsandbytes==0.35.0
@@ -13,7 +13,7 @@ pip install ftfy==6.1.1
pip install aiohttp==3.8.3
pip install tensorboard>=2.11.0
pip install protobuf==3.20.1
pip install wandb==0.13.6
pip install wandb==0.14.0
pip install pyre-extensions==0.0.23
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall