Merge remote-tracking branch 'upstream/main' into hf_model_download
commit d24dd681c0

@@ -17,6 +17,10 @@ Covers install, setup of base models, starting training, basic tweaking, and lo

Behind-the-scenes look at how the trainer handles multiaspect and crop jitter

### Tools repo

Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream); it has a grab bag of scripts to help with your data curation prior to training: automatic bulk captioning with BLIP, a script to web-scrape images based on Laion data files, a script to rename generic pronouns to proper names or append artist tags to your captions, and more.

## Docs

[Setup and installation](doc/SETUP.md)
@@ -0,0 +1,439 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "blaLMSbkPHhG"
   },
   "source": [
    "# EveryDream2 Colab Edition\n",
    "\n",
    "Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
    "\n",
    "And join the discord: https://discord.gg/uheqxU6sXN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "WsYIcz9HY9lx"
   },
   "outputs": [],
   "source": [
    "#@title # Install python 3.10\n",
    "import os\n",
    "from IPython.display import clear_output\n",
    "!wget https://github.com/korakot/kora/releases/download/v0.10/py310.sh\n",
    "!bash ./py310.sh -b -f -p /usr/local\n",
    "!python -m ipykernel install --name \"py310\" --user\n",
    "clear_output()\n",
    "os.kill(os.getpid(), 9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "f2cdMtCt9Wb6"
   },
   "outputs": [],
   "source": [
    "#@title Verify python version, should be 3.10.something\n",
    "!python --version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "d1di4EC6ygw1"
   },
   "outputs": [],
   "source": [
    "#@title Optional: connect Gdrive\n",
    "#@markdown # but strongly recommended\n",
    "#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
    "\n",
    "#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "\n",
    "!mkdir -p /content/drive/MyDrive/everydreamlogs/ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "hAuBbtSvGpau"
   },
   "outputs": [],
   "source": [
    "#@markdown # Install Dependencies\n",
    "#@markdown This will take a couple minutes; be patient and watch the output for \"DONE!\"\n",
    "from IPython.display import clear_output\n",
    "from subprocess import getoutput\n",
    "s = getoutput('nvidia-smi')\n",
    "!pip install -q torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url \"https://download.pytorch.org/whl/cu116\"\n",
    "!pip install -q transformers==4.25.1\n",
    "!pip install -q diffusers[torch]==0.10.2\n",
    "!pip install -q pynvml==11.4.1\n",
    "!pip install -q bitsandbytes==0.35.0\n",
    "!pip install -q ftfy==6.1.1\n",
    "!pip install -q aiohttp==3.8.3\n",
    "!pip install -q \"tensorboard>=2.11.0\"\n",
    "!pip install -q protobuf==3.20.1\n",
    "!pip install -q wandb==0.13.6\n",
    "!pip install -q pyre-extensions==0.0.23\n",
    "if \"A100\" in s:\n",
    "    !pip install -q https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/A100_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
    "else:\n",
    "    !pip install -q https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/T4_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
    "!pip install -q pytorch-lightning==1.6.5\n",
    "!pip install -q OmegaConf==2.2.3\n",
    "!pip install -q numpy==1.23.5\n",
    "!pip install -q colorama\n",
    "!pip install -q keyboard\n",
    "clear_output()\n",
    "!git clone https://github.com/victorchall/EveryDream2trainer.git\n",
    "%cd /content/EveryDream2trainer\n",
    "!wget \"https://raw.githubusercontent.com/nawnie/EveryDream2trainer/main/train_colab.py\"\n",
    "clear_output()\n",
    "print(\"DONE!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "unaffeqGP_0A"
   },
   "outputs": [],
   "source": [
    "#@title Get A Base Model\n",
    "#@markdown Choose SD1.5 or Waifu Diffusion 1.3 from the dropdown, or paste your own URL in the box\n",
    "\n",
    "#@markdown If you already did this once with Gdrive connected, you can skip this step as the cached copy is on your gdrive\n",
    "from IPython.display import clear_output\n",
    "!mkdir input\n",
    "%cd /content/EveryDream2trainer\n",
    "!python utils/get_yamls.py\n",
    "MODEL_URL = \"https://huggingface.co/panopstor/EveryDream/resolve/main/sd_v1-5_vae.ckpt\" #@param [\"https://huggingface.co/panopstor/EveryDream/resolve/main/sd_v1-5_vae.ckpt\", \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt\"] {allow-input: true}\n",
    "print(\"Downloading\")\n",
    "!wget $MODEL_URL\n",
    "\n",
    "%cd /content/EveryDream2trainer\n",
    "\n",
    "clear_output()\n",
    "print(\"DONE!\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "nEzuEYH0536C"
   },
   "source": [
    "In order to train, you need a base model on which to train. This is a one-time setup per base model you want to use.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "tPvQSo6ScF2c"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "#@title Setup conversion\n",
    "\n",
    "#@markdown **If you already did this once with Gdrive connected, you can skip this step as the cached copy is on your gdrive.**\n",
    "#\n",
    "# If you are not sure, look in your Gdrive for `everydreamlogs/ckpt` and see if you have a folder with the `save_name` below.\n",
    "\n",
    "#@markdown Pick the `model_type` in the dropdown. This is the type of model you downloaded above and are converting; it determines the model architecture and the correct settings to use.\n",
    "\n",
    "#@markdown * `SD1x` is all SD1.x based models *(SD1.4, SD1.5, Waifu Diffusion 1.3, etc)*\n",
    "\n",
    "#@markdown * `SD2_512_base` is the SD2 512 base model\n",
    "\n",
    "#@markdown * `SD21` is all SD2 768 models *(ex. SD2.1 768, or trained models based on that)*\n",
    "\n",
    "#@markdown If you are not sure, double check the model author's page or ask for help on [Discord](https://discord.gg/uheqxU6sXN).\n",
    "model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
    "\n",
    "#@markdown This is the temporary ckpt file that was downloaded above. If you downloaded a different model, you can change this. *Hint: look at your file manager in the EveryDream2trainer folder for .ckpt files*.\n",
    "base_path = \"/content/EveryDream2trainer/sd_v1-5_vae.ckpt\" #@param {type:\"string\"}\n",
    "\n",
    "#@markdown The name that you will use when selecting this model in future training sessions.\n",
    "save_name = \"SD15\" #@param{type:\"string\"}\n",
    "\n",
    "#@markdown If you are using Gdrive, this will save the converted model to your Gdrive for future use so you can skip downloading and converting the model.\n",
    "cache_to_gdrive = True #@param{type:\"boolean\"}\n",
    "\n",
    "if cache_to_gdrive:\n",
    "    save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
    "\n",
    "img_size = 512\n",
    "upscale_attention = False\n",
    "if model_type == \"SD1x\":\n",
    "    inference_yaml = \"v1-inference.yaml\"\n",
    "elif model_type == \"SD2_512_base\":\n",
    "    upscale_attention = True\n",
    "    inference_yaml = \"v2-inference.yaml\"\n",
    "elif model_type == \"SD21\":\n",
    "    upscale_attention = True\n",
    "    inference_yaml = \"v2-inference-v.yaml\"\n",
    "    img_size = 768\n",
    "\n",
    "print(base_path)\n",
    "print(inference_yaml)\n",
    "\n",
    "!python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
    "--original_config_file {inference_yaml} \\\n",
    "--image_size {img_size} \\\n",
    "--checkpoint_path {base_path} \\\n",
    "--prediction_type epsilon \\\n",
    "--upcast_attn False \\\n",
    "--dump_path {save_name}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "bLpcvpGJB4Gu"
   },
   "outputs": [],
   "source": [
    "#@title Pick your base model from a diffusers model saved to your Gdrive (converted above)\n",
    "\n",
    "#@markdown Do not skip this cell.\n",
    "\n",
    "#@markdown * If you have previously saved diffusers on your drive you can select it here\n",
    "\n",
    "#@markdown ex. */content/drive/MyDrive/everydreamlogs/myproject_202208/ckpts/interrupted-gs023*\n",
    "\n",
    "#@markdown The default for SD1.5 converted above would be */content/drive/MyDrive/everydreamlogs/ckpt/SD15*\n",
    "Resume_Model = \"/content/drive/MyDrive/everydreamlogs/ckpt/SD15\" #@param{type:\"string\"}\n",
    "save_name = Resume_Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JXVu-W2lCjwX"
   },
   "source": [
    "For a more in-depth explanation of each of these parameters, check out /content/EveryDream2trainer/doc.\n",
    "\n",
    "\n",
    "After you've tried a few models, you will find /content/EveryDream2trainer/doc/ATWEAKING.md to be extremely helpful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "j9pEI69WXS9w"
   },
   "outputs": [],
   "source": [
    "#@title \n",
    "#@markdown # Run EveryDream 2\n",
    "#@markdown If you want to use a .json config or upload your own, skip this cell and run the cell below instead\n",
    "\n",
    "#@markdown * Save logs and output ckpts to Gdrive (strongly suggested)\n",
    "Save_to_Gdrive = True #@param{type:\"boolean\"}\n",
    "#@markdown * Use resume to continue the training you just ran; it will find the latest diffusers log in your Gdrive to continue from.\n",
    "resume = False #@param{type:\"boolean\"}\n",
    "#@markdown * Gradient checkpointing saves VRAM to allow larger batch sizes; it slightly slows down each step but can make room for a higher training resolution (suggested on the Colab free tier, turn off for A100)\n",
    "Gradient_checkpointing = True #@param{type:\"boolean\"}\n",
    "Disable_Xformers = False\n",
    "#@markdown * Tag shuffling, mainly for booru training. Best to read /content/EveryDream2trainer/doc/SHUFFLING_TAGS.md if you are interested in shuffling tags\n",
    "shuffle_tags = False #@param{type:\"boolean\"}\n",
    "#@markdown * You can turn off the text encoder training (generally not suggested)\n",
    "Disable_text_Encoder= False #@param{type:\"boolean\"}\n",
    "#@markdown * Name your project so you can find it in your logs\n",
    "Project_Name = \"my_project\" #@param{type: 'string'}\n",
    "\n",
    "#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data and not wreck the model.\n",
    "#@markdown Leave this default unless you are very comfortable with training and know what you are doing.\n",
    "\n",
    "Learning_Rate = 1e-6 #@param{type: 'number'}\n",
    "\n",
    "#@markdown * A learning rate scheduler can change your learning rate as training progresses.\n",
    "\n",
    "#@markdown I recommend sticking with constant until you are comfortable with general training.\n",
    "\n",
    "Schedule = \"constant\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
    "\n",
    "#@markdown * Resolution to train at (recommend 512). Higher resolution will require lower batch size (below).\n",
    "Resolution = 512 #@param {type:\"slider\", min:256, max:768, step:64}\n",
    "\n",
    "#@markdown * Batch size is itself another \"hyperparameter\" with tradeoffs. It is not always best to use the highest batch size possible. One of the primary reasons to change it is if you get \"CUDA out of memory\" errors, where lowering the value may help.\n",
    "\n",
    "#@markdown * Batch size impacts VRAM use. 4 should work on SD1.x models and 3 for SD2.x models at 512 resolution. Lower this if you get CUDA out of memory errors.\n",
    "\n",
    "Batch_Size = 4 #@param{type: 'number'}\n",
    "\n",
    "#@markdown * Gradient accumulation is sort of like a virtual batch size increase; use this to increase the effective batch size without increasing VRAM usage.\n",
    "#@markdown * Increasing this will not have much impact on VRAM use.\n",
    "#@markdown * On the Colab free tier you can expect the fastest performance from a batch size of 4 and a gradient step of 2, giving a total effective batch size of 8 at 512 resolution.\n",
    "#@markdown * Due to bucketing you may need to decrease batch size to 3.\n",
    "#@markdown * Remember, more gradient accumulation (or batch size) doesn't automatically mean better.\n",
    "\n",
    "Gradient_steps = 1 #@param{type:\"slider\", min:1, max:10, step:1}\n",
    "\n",
    "#@markdown * Location on your Gdrive where your training images are.\n",
    "Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
    "dataset = Dataset_Location\n",
    "model = save_name\n",
    "\n",
    "#@markdown * Max epochs to train for; this defines how many total times all your training data is used.\n",
    "\n",
    "Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:5}\n",
    "\n",
    "#@markdown * How often to save checkpoints.\n",
    "Save_every_N_epoch = 20 #@param{type:\"integer\"}\n",
    "\n",
    "#@markdown * Test sample generation steps: how often to generate samples during training.\n",
    "\n",
    "#@markdown You can set your own sample prompts by adding them, one line at a time, to `/content/EveryDream2trainer/sample_prompts.txt`. If left empty, it will use the captions from your training images.\n",
    "\n",
    "#@markdown Use Steps_between_samples to set how often the samples are generated.\n",
    "Steps_between_samples = 300 #@param{type:\"integer\"}\n",
    "\n",
    "#@markdown * That's it! Run the cell!\n",
    "\n",
    "Drive=\"\"\n",
    "if Save_to_Gdrive:\n",
    "    Drive = \"--logdir /content/drive/MyDrive/everydreamlogs --save_ckpt_dir /content/drive/MyDrive/everydreamlogs/ckpt\"\n",
    "\n",
    "if Max_Epochs == 0:\n",
    "    Max_Epochs = 1\n",
    "\n",
    "if resume:\n",
    "    model = \"findlast\"\n",
    "\n",
    "Gradient = \"\"\n",
    "if Gradient_checkpointing:\n",
    "    Gradient = \"--gradient_checkpointing \"\n",
    "if \"A100\" in s:\n",
    "    Gradient = \"\"\n",
    "\n",
    "DX = \"\"\n",
    "if Disable_Xformers:\n",
    "    DX = \"--disable_xformers \"\n",
    "\n",
    "shuffle = \"\"\n",
    "if shuffle_tags:\n",
    "    shuffle = \"--shuffle_tags \"\n",
    "\n",
    "textencode = \"\"\n",
    "if Disable_text_Encoder:\n",
    "    textencode = \"--disable_textenc_training \"\n",
    "\n",
    "!python train_colab.py --resume_ckpt \"$model\" \\\n",
    "  $textencode \\\n",
    "  $Gradient \\\n",
    "  $shuffle \\\n",
    "  $Drive \\\n",
    "  $DX \\\n",
    "  --amp \\\n",
    "  --batch_size $Batch_Size \\\n",
    "  --grad_accum $Gradient_steps \\\n",
    "  --cond_dropout 0.00 \\\n",
    "  --data_root \"$dataset\" \\\n",
    "  --flip_p 0.00 \\\n",
    "  --lr $Learning_Rate \\\n",
    "  --lr_decay_steps 0 \\\n",
    "  --lr_scheduler \"$Schedule\" \\\n",
    "  --lr_warmup_steps 0 \\\n",
    "  --max_epochs $Max_Epochs \\\n",
    "  --project_name \"$Project_Name\" \\\n",
    "  --resolution $Resolution \\\n",
    "  --sample_prompts \"sample_prompts.txt\" \\\n",
    "  --sample_steps $Steps_between_samples \\\n",
    "  --save_every_n_epoch $Save_every_N_epoch \\\n",
    "  --seed 555 \\\n",
    "  --useadam8bit \\\n",
    "  --notebook\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Iuoa_1B9jRGU"
   },
   "outputs": [],
   "source": [
    "#@title Alternate startup script\n",
    "#@markdown * Edit train.json to set up your parameters\n",
    "#@markdown * Edit chain0.json to make use of chaining\n",
    "#@markdown * Make sure to check each configuration; you will need one json per chain link, and 3 are provided\n",
    "\n",
    "\n",
    "%cd /content/EveryDream2trainer\n",
    "Chain_Length = 0 #@param{type:\"integer\"}\n",
    "l = Chain_Length\n",
    "I = 0  # repeat counter\n",
    "if l is None or l == 0:\n",
    "    l = 1\n",
    "while l > 0:\n",
    "    !python train_colab.py --config chain{I}.json\n",
    "    l -= 1\n",
    "    I += 1"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": true,
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "e602395b73d27e246c3f66de86a1ed4dc1e5a85e8356fd1a2f027b9d2f1f8162"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
@@ -13,9 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import bisect
import math
import os
import logging
import copy

import yaml
from PIL import Image

@@ -46,7 +48,7 @@ class DataLoaderMultiAspect():
        self.log_folder = log_folder
        self.seed = seed
        self.batch_size = batch_size
        self.runts = []
        self.has_scanned = False

        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
        logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")

@@ -56,135 +58,71 @@ class DataLoaderMultiAspect():

        self.__recurse_data_root(self=self, recurse_root=data_root)
        random.Random(seed).shuffle(self.image_paths)
        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
        self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
        self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
        print(f"DLMA Loaded {len(self.prepared_train_data)} images")
        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()

    def shuffle(self):
        self.runts = []
        self.seed = self.seed + 1
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=self.batch_size, debug_level=0)

    def unzip_all(self, path):
        try:
            for root, dirs, files in os.walk(path):
                for file in files:
                    if file.endswith('.zip'):
                        logging.info(f"Unzipping {file}")
                        with zipfile.ZipFile(path, 'r') as zip_ref:
                            zip_ref.extractall(path)
        except Exception as e:
            logging.error(f"Error unzipping files {e}")

    def get_all_images(self):
        return self.image_caption_pairs

    @staticmethod
    def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
        try:
            with open(file_path, encoding='utf-8', mode='r') as caption_file:
                caption_text = caption_file.read()
                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
        except:
            logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
            caption = fallback_caption
            pass
        return caption
    @staticmethod
    def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
        with open(file_path, "r") as stream:
            try:
                file_content = yaml.safe_load(stream)
                main_prompt = file_content.get("main_prompt", "")
                unparsed_tags = file_content.get("tags", [])

                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

                tags = []
                tag_weights = []
                last_weight = None
                weights_differ = False
                for unparsed_tag in unparsed_tags:
                    tag = unparsed_tag.get("tag", "").strip()
                    if len(tag) == 0:
                        continue

                    tags.append(tag)
                    tag_weight = unparsed_tag.get("weight", 1.0)
                    tag_weights.append(tag_weight)

                    if last_weight is not None and weights_differ is False:
                        weights_differ = last_weight != tag_weight

                    last_weight = tag_weight

                return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)

            except:
                logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
                return fallback_caption

    @staticmethod
    def __split_caption_into_tags(caption_string: str) -> ImageCaption:
    def __pick_multiplied_set(self, randomizer):
        """
        Splits a string by "," into the main prompt and additional tags with equal weights
        Deals with multiply.txt whole and fractional numbers
        """
        split_caption = caption_string.split(",")
        main_prompt = split_caption.pop(0).strip()
        tags = []
        for tag in split_caption:
            tags.append(tag.strip())
        #print(f"Picking multiplied set from {len(self.prepared_train_data)}")
        data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
        epoch_size = len(self.prepared_train_data)
        picked_images = []

        return ImageCaption(main_prompt, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
        # add by whole number part first and decrement multiplier in copy
        for iti in data_copy:
            #print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
            while iti.multiplier >= 1.0:
                picked_images.append(iti)
                #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, datalen: {len(picked_images)}")
                iti.multiplier -= 1.0

    def __prescan_images(self, image_paths: list, flip_p=0.0):
        remaining = epoch_size - len(picked_images)

        assert remaining >= 0, "Something went wrong with the multiplier calculation"
        #print(f"Remaining to fill epoch after whole number adds: {remaining}")
        #print(f"Remaining in data copy: {len(data_copy)}")

        # add remaining fractional numbers by random chance
        while remaining > 0:
            for iti in data_copy:
                if randomizer.uniform(0.0, 1.0) < iti.multiplier:
                    #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
                    picked_images.append(iti)
                    remaining -= 1
                    data_copy.remove(iti)
                    if remaining <= 0:
                        break

        del data_copy
        return picked_images

    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
        """
        Create ImageTrainItem objects with metadata for hydration later
        returns the current list of images including their captions in a randomized order,
        sorted into buckets with same sized images
        if dropout_fraction < 1.0, only a subset of the images will be returned
        if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
        :param dropout_fraction: must be between 0.0 and 1.0.
        :return: randomized list of (image, caption) pairs, sorted into same sized buckets
        """
        decorated_image_train_items = []

        for pathname in tqdm.tqdm(image_paths):
            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
            caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
        self.seed += 1
        randomizer = random.Random(self.seed)

            file_path_without_ext = os.path.splitext(pathname)[0]
            yaml_file_path = file_path_without_ext + ".yaml"
            txt_file_path = file_path_without_ext + ".txt"
            caption_file_path = file_path_without_ext + ".caption"
        if dropout_fraction < 1.0:
            picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
        else:
            picked_images = self.__pick_multiplied_set(randomizer)

            if os.path.exists(yaml_file_path):
                caption = self.__read_caption_from_yaml(yaml_file_path, caption)
            elif os.path.exists(txt_file_path):
                caption = self.__read_caption_from_file(txt_file_path, caption)
            elif os.path.exists(caption_file_path):
                caption = self.__read_caption_from_file(caption_file_path, caption)
        randomizer.shuffle(picked_images)

            try:
                image = Image.open(pathname)
                width, height = image.size
                image_aspect = width / height

                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))

                image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)

                decorated_image_train_items.append(image_train_item)
            except Exception as e:
                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
                logging.error(f" *** exception: {e}")
                pass

        return decorated_image_train_items

    def __bucketize_images(self, prepared_train_data: list, batch_size=1, debug_level=0):
        """
        Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
        """
        # TODO: this is not terribly efficient but at least linear time
        buckets = {}

        for image_caption_pair in prepared_train_data:
        batch_size = self.batch_size
        for image_caption_pair in picked_images:
            image_caption_pair.runt_size = 0
            target_wh = image_caption_pair.target_wh

@@ -215,27 +153,230 @@ class DataLoaderMultiAspect():
        return image_caption_pairs

    @staticmethod
    def __recurse_data_root(self, recurse_root):
        multiply = 1
        multiply_path = os.path.join(recurse_root, "multiply.txt")
        if os.path.exists(multiply_path):
    def unzip_all(path):
        try:
            for root, dirs, files in os.walk(path):
                for file in files:
                    if file.endswith('.zip'):
                        logging.info(f"Unzipping {file}")
                        with zipfile.ZipFile(path, 'r') as zip_ref:
                            zip_ref.extractall(path)
        except Exception as e:
            logging.error(f"Error unzipping files {e}")

    def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())

        rating_overall_sum: float = 0.0
        ratings_summed: list[float] = []
        for image in self.prepared_train_data:
            rating_overall_sum += image.caption.rating()
            ratings_summed.append(rating_overall_sum)

        return rating_overall_sum, ratings_summed

    @staticmethod
    def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
        try:
            with open(file_path, encoding='utf-8', mode='r') as caption_file:
                caption_text = caption_file.read()
                caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
        except:
            logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
            caption = fallback_caption
            pass
        return caption

    @staticmethod
    def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
        with open(file_path, "r") as stream:
            try:
                with open(multiply_path, encoding='utf-8', mode='r') as f:
                    multiply = int(float(f.read().strip()))
                    logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
                file_content = yaml.safe_load(stream)
                main_prompt = file_content.get("main_prompt", "")
                rating = file_content.get("rating", 1.0)
                unparsed_tags = file_content.get("tags", [])

                max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)

                tags = []
                tag_weights = []
                last_weight = None
                weights_differ = False
                for unparsed_tag in unparsed_tags:
                    tag = unparsed_tag.get("tag", "").strip()
                    if len(tag) == 0:
                        continue

                    tags.append(tag)
                    tag_weight = unparsed_tag.get("weight", 1.0)
                    tag_weights.append(tag_weight)

                    if last_weight is not None and weights_differ is False:
                        weights_differ = last_weight != tag_weight

                    last_weight = tag_weight

                return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)

            except:
                logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
                logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
                return fallback_caption

    @staticmethod
    def __split_caption_into_tags(caption_string: str) -> ImageCaption:
        """
        Splits a string by "," into the main prompt and additional tags with equal weights
        """
        split_caption = caption_string.split(",")
        main_prompt = split_caption.pop(0).strip()
        tags = []
        for tag in split_caption:
            tags.append(tag.strip())

        return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)

    def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
        """
        Create ImageTrainItem objects with metadata for hydration later
        """
        decorated_image_train_items = []

        if not self.has_scanned:
            undersized_images = []

        multipliers = {}
        skip_folders = []
        randomizer = random.Random(self.seed)

        for pathname in tqdm.tqdm(image_paths):
            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
            caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)

            file_path_without_ext = os.path.splitext(pathname)[0]
            yaml_file_path = file_path_without_ext + ".yaml"
            txt_file_path = file_path_without_ext + ".txt"
            caption_file_path = file_path_without_ext + ".caption"

            current_dir = os.path.dirname(pathname)

            try:
                if current_dir not in multipliers:
                    multiply_txt_path = os.path.join(current_dir, "multiply.txt")
                    #print(current_dir, multiply_txt_path)
                    if os.path.exists(multiply_txt_path):
                        with open(multiply_txt_path, 'r') as f:
                            val = float(f.read().strip())
                            multipliers[current_dir] = val
                            logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
                    else:
                        skip_folders.append(current_dir)
                        multipliers[current_dir] = 1.0
            except Exception as e:
                logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
                skip_folders.append(current_dir)
                multipliers[current_dir] = 1.0

            if os.path.exists(yaml_file_path):
                caption = self.__read_caption_from_yaml(yaml_file_path, caption)
            elif os.path.exists(txt_file_path):
                caption = self.__read_caption_from_file(txt_file_path, caption)
            elif os.path.exists(caption_file_path):
                caption = self.__read_caption_from_file(caption_file_path, caption)

            try:
                image = Image.open(pathname)
                width, height = image.size
                image_aspect = width / height

                target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
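                # i.e. choose the aspect bucket whose ratio is closest to this image's aspect ratio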
                if not self.has_scanned:
                    if width * height < target_wh[0] * target_wh[1]:
                        undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")

                image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
                                                  caption=caption,
                                                  target_wh=target_wh,
                                                  pathname=pathname,
                                                  flip_p=flip_p,
                                                  multiplier=multipliers[current_dir],
                                                  )

                cur_file_multiplier = multipliers[current_dir]
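                # e.g. a multiplier of 2.5 appends the image twice, plus a 50% chance of a third copy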

                while cur_file_multiplier >= 1.0:
                    decorated_image_train_items.append(image_train_item)
                    cur_file_multiplier -= 1

                if cur_file_multiplier > 0:
                    if randomizer.random() < cur_file_multiplier:
                        decorated_image_train_items.append(image_train_item)

            except Exception as e:
                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
                logging.error(f" *** exception: {e}")
                pass

        if not self.has_scanned:
            self.has_scanned = True
            if len(undersized_images) > 0:
                undersized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
                logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
                logging.warning(f"{Fore.LIGHTRED_EX} ** Check {undersized_log_path} for more information.{Style.RESET_ALL}")
                with open(undersized_log_path, "w") as undersized_images_file:
                    undersized_images_file.write("The following images are smaller than the target size, consider removing or sourcing a larger copy:\n")
                    for undersized_image in undersized_images:
                        undersized_images_file.write(f"{undersized_image}\n")

        print(f" * DLMA: {len(decorated_image_train_items)} images loaded from {len(image_paths)} files")

        return decorated_image_train_items

    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
        """
        Picks a random subset of all images
        - The size of the subset is limited by dropout_fraction
        - The chance of an image to be picked is influenced by its rating. Double the rating -> double the chance
        :param dropout_fraction: must be between 0.0 and 1.0
        :param picker: seeded random picker
        :return: list of picked ImageTrainItem
        """

        prepared_train_data = self.prepared_train_data.copy()
        ratings_summed = self.ratings_summed.copy()
        rating_overall_sum = self.rating_overall_sum

        num_images = len(prepared_train_data)
        num_images_to_pick = math.ceil(num_images * dropout_fraction)
        num_images_to_pick = max(min(num_images_to_pick, num_images), 0)

        # logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}")

        picked_images: list[ImageTrainItem] = []
        while num_images_to_pick > len(picked_images):
            # find a random sample in the dataset
            point = picker.uniform(0.0, rating_overall_sum)
            pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)

            # pick random sample
            picked_image = prepared_train_data[pos]
            picked_images.append(picked_image)

            # kick the picked item out of the data set so it is not picked again
            rating_overall_sum = max(rating_overall_sum - picked_image.caption.rating(), 0.0)
            ratings_summed.pop(pos)
            prepared_train_data.pop(pos)

        return picked_images

    @staticmethod
    def __recurse_data_root(self, recurse_root):
        for f in os.listdir(recurse_root):
            current = os.path.join(recurse_root, f)

            if os.path.isfile(current):
                ext = os.path.splitext(f)[1].lower()
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
                    # add image multiply repeats number of times
                    for _ in range(multiply):
                        self.image_paths.append(current)
                    self.image_paths.append(current)

        sub_dirs = []
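
The rating-weighted pick in __pick_random_subset above is an inverse-CDF draw over cumulative rating sums. A minimal sketch of the idea in isolation (illustrative names, not the trainer's API), keeping the same approximation that the cumulative list is not rebuilt after each pop:

    import bisect
    import math
    import random

    def pick_weighted_subset(items: list, ratings: list, fraction: float, picker: random.Random) -> list:
        # cumulative sums turn the ratings into a CDF; a uniform draw then lands
        # on an item with probability proportional to its rating
        cumulative, total = [], 0.0
        for r in ratings:
            total += r
            cumulative.append(total)
        count = max(min(math.ceil(len(items) * fraction), len(items)), 0)
        picked = []
        while len(picked) < count:
            point = picker.uniform(0.0, total)
            pos = min(bisect.bisect_left(cumulative, point), len(items) - 1)
            picked.append(items[pos])
            # remove the picked item; later cumulative entries go slightly stale,
            # as in the trainer, keeping the pick approximate but cheap
            total = max(total - ratings.pop(pos), 0.0)
            items.pop(pos)
            cumulative.pop(pos)
        return picked

    picked = pick_weighted_subset(list("abcd"), [1.0, 1.0, 2.0, 4.0], 0.5, random.Random(555))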
@@ -51,6 +51,8 @@ class EveryDreamBatch(Dataset):
                 retain_contrast=False,
                 write_schedule=False,
                 shuffle_tags=False,
                 rated_dataset=False,
                 rated_dataset_dropout_target=0.5
                 ):
        self.data_root = data_root
        self.batch_size = batch_size

@@ -66,6 +68,8 @@ class EveryDreamBatch(Dataset):
        self.write_schedule = write_schedule
        self.shuffle_tags = shuffle_tags
        self.seed = seed
        self.rated_dataset = rated_dataset
        self.rated_dataset_dropout_target = rated_dataset_dropout_target

        if seed == -1:
            seed = random.randint(0, 99999)

@@ -80,18 +84,16 @@ class EveryDreamBatch(Dataset):
            resolution=resolution,
            log_folder=self.log_folder,
        )

        self.image_train_items = dls.shared_dataloader.get_all_images()

        self.num_images = len(self.image_train_items)
        self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images

        self._length = self.num_images
        num_images = len(self.image_train_items)

        logging.info(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}")
        logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
        if self.write_schedule:
            self.write_batch_schedule(0)
            self.__write_batch_schedule(0)

    def write_batch_schedule(self, epoch_n):
    def __write_batch_schedule(self, epoch_n):
        with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
            for i in range(len(self.image_train_items)):
                try:

@@ -102,19 +104,23 @@ class EveryDreamBatch(Dataset):
    def get_runts():
        return dls.shared_dataloader.runts

    def shuffle(self, epoch_n):
    def shuffle(self, epoch_n: int, max_epochs: int):
        self.seed += 1
        if dls.shared_dataloader:
            dls.shared_dataloader.shuffle()
            self.image_train_items = dls.shared_dataloader.get_all_images()
            if self.rated_dataset:
                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
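                # anneals linearly from 1.0 at epoch 0 toward (1 - dropout_target) at the final epoch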
            else:
                dropout_fraction = 1.0

            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
        else:
            raise Exception("No dataloader singleton to shuffle")

        if self.write_schedule:
            self.write_batch_schedule(epoch_n)
            self.__write_batch_schedule(epoch_n + 1)

    def __len__(self):
        return self._length
        return len(self.image_train_items)

    def __getitem__(self, i):
        example = {}

@@ -31,7 +31,7 @@ class ImageCaption:
    Represents the various parts of an image caption
    """

    def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
    def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
        """
        :param main_prompt: The part of the caption which should always be included
        :param tags: list of tags to pick from to fill the caption

@@ -40,6 +40,7 @@ class ImageCaption:
        :param use_weights: if true, weights are considered when shuffling tags
        """
        self.__main_prompt = main_prompt
        self.__rating = rating
        self.__tags = tags
        self.__tag_weights = tag_weights
        self.__max_target_length = max_target_length

@@ -50,6 +51,9 @@ class ImageCaption:
        if use_weights and len(tag_weights) > len(tags):
            self.__tag_weights = tag_weights[:len(tags)]

    def rating(self) -> float:
        return self.__rating

    def get_shuffled_caption(self, seed: int) -> str:
        """
        returns the caption as a string with a random selection of the tags in random order

@@ -97,22 +101,23 @@ class ImageCaption:
        return ", ".join(tags)


class ImageTrainItem():
class ImageTrainItem:
    """
    image: PIL.Image
    identifier: caption,
    target_aspect: (width, height),
    pathname: path to image file
    flip_p: probability of flipping image (0.0 to 1.0)
    rating: the relative rating of the image, measured in comparison to the other images
    """

    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
    def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
        self.caption = caption
        self.target_wh = target_wh
        self.pathname = pathname
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.cropped_img = None
        self.runt_size = 0
        self.multiplier = multiplier

        if image is None:
            self.image = []

@@ -108,15 +108,17 @@ Some experimentation shows if you already have batch size in the 6-8 range then

## Gradient checkpointing

While traditionally used to reduce VRAM for smaller GPUs, gradient checkpointing can offer a higher batch size and/or higher resolution within whatever VRAM you have, so it may be useful even on a 24GB+ GPU.
This is mostly useful to reduce VRAM for smaller GPUs, and together with AdamW 8 bit and AMP mode it can enable <12GB GPU training.

Gradient checkpointing can also offer a higher batch size and/or higher resolution within whatever VRAM you have, so it may be useful even on a 24GB+ GPU if you specifically want to run a very large batch size. The other option is using gradient accumulation instead.

    --gradient_checkpointing ^

This drastically reduces VRAM (by many GB) and will allow a considerably larger batch size or resolution, for example, 13-14 instead of 7-8 on a 24GB card at 512 training resolution.
While gradient checkpointing reduces performance, the ability to run a higher batch size brings performance back fairly close to running without it.

While gradient checkpointing reduces performance, the ability to run a higher batch size brings performance back fairly close to running without it. My personal tests show a 25% performance hit from simply turning on gradient checkpointing on a 3090 (batch 7, 512), but almost all of that is made up for by the ability to use a larger batch size (up to 14). You may NOT want to use a batch size as large as 13-14, or you may find you need to tweak the learning rate all over again to find the right balance.
You may NOT want to use a batch size as large as 13-14, or you may find you need to tweak the learning rate all over again to find the right balance. Generally I would not turn it on for a 24GB GPU training at <640 resolution.

This probably IS a good idea for training at higher resolutions. Balancing this toggle, resolution, batch_size, and grad_accum will take some experimentation, but you might try using this with 768+ resolutions, grad_accum 3-4, and then as high of a batch size as you can get to work without crashing, while adjusting LR with respect to your (batch_size * grad_accum) value.
This probably IS a good idea for training at higher resolutions and allows >768 training on 24GB GPUs. Balancing this toggle, resolution, and batch_size will take a few quick experiments to see what you can run safely.
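
As a concrete illustration, a plausible starting point for higher-resolution training on a 24GB card might look like this (the batch size here is illustrative, not a tested recommendation):

    python train.py --resume_ckpt "sd_v1-5_vae" ^
    --data_root "X:\my_project_data\project_abc" ^
    --resolution 768 ^
    --batch_size 6 ^
    --gradient_checkpointing ^
    --useadam8bit ^
    --amp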

## Flip_p

@@ -16,6 +16,17 @@ You may wish to consider adding "sd1" or "sd2v" or similar to remember what the

    --project_name "jets_sd21768v" ^

## Stuff you probably want on

    --amp

Enables automatic mixed precision. Greatly improves training speed and can help a bit with VRAM use. [Torch](https://pytorch.org/docs/stable/amp.html) will automatically use FP16 precision for specific model components where FP16 is sufficient precision, and FP32 otherwise. This also enables xformers to work with the SD1.x attention head schema.
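
Under the hood this is standard PyTorch automatic mixed precision with a gradient scaler. A minimal self-contained sketch of the pattern (not the trainer's exact loop; the tiny model here is just for illustration):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(16, 1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    for _ in range(10):
        batch = torch.randn(4, 16, device="cuda")
        optimizer.zero_grad()
        with autocast():                  # ops run in FP16 where safe, FP32 otherwise
            loss = model(batch).pow(2).mean()
        scaler.scale(loss).backward()     # scale the loss to avoid FP16 gradient underflow
        scaler.step(optimizer)            # unscales grads, then steps the optimizer
        scaler.update()                   # adapt the loss scale for the next step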

    --useadam8bit

Uses [Tim Dettmers' reduced precision AdamW 8 Bit optimizer](https://github.com/TimDettmers/bitsandbytes). This seems to have no noticeable impact on quality but is considerably faster and more VRAM efficient. See more below in AdamW vs AdamW 8bit.

## Epochs

EveryDream 2.0 has done away with repeats; instead you set max_epochs. Changing epochs has the same effect as changing repeats in DreamBooth or EveryDream 1.0. For example, if you had 50 repeats and 5 epochs, you would now set max_epochs to 250 (50x5=250). This is a bit more intuitive as there is no longer a double meaning between epochs and repeats.
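
For the example above, that is simply:

    --max_epochs 250 ^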

@@ -28,6 +39,16 @@ With more training data for your subjects and concepts, you can slowly scale thi

With less training data, this value should be higher, because more repetition on the images is needed to learn.

## Resolution

The resolution to train at. All buckets for multiaspect training are based on the total pixel count of your resolution squared.

    --resolution 768

The currently supported resolutions can be printed by running the trainer without any arguments:

    python train.py

## Save interval for checkpoints

While EveryDream 1.0 saved a checkpoint every epoch, this is no longer the case, as it would produce too many files now that "repeats" are removed in favor of just using epochs. To balance the fact that EveryDream users sometimes train small datasets and sometimes huge ones, you can now set the interval at which checkpoints are saved. The default is 30 minutes, but you can change it to whatever you want.

@@ -76,16 +97,6 @@ At this time, ED2.0 supports constant or cosine scheduler.

The constant scheduler is the default and keeps your LR set to the value you set in the command line. That's really it for constant! I recommend sticking with it until you are comfortable with general training. More info in the [Advanced Tweaking](ATWEAKING.md) document.

## AdamW vs AdamW 8bit

The AdamW optimizer is the default and what was used by EveryDream 1.0. It's a good optimizer for Stable Diffusion and appears to be what was used to train SD itself.

AdamW 8bit is quite a bit faster and uses less VRAM while still having the same basic behavior. I currently **recommend** using it for most cases, as it seems worth a potential slight reduction in quality for a *significant speed boost and lower VRAM cost*.

    --useadam8bit ^

This may become a default in the future, and be replaced with an option to use standard AdamW instead. For now, it's an option, *but I recommend always using it.*
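
Conceptually, the switch just swaps the optimizer class for the bitsandbytes implementation, which stores optimizer state in 8 bit. A hedged sketch of the wiring (illustrative, not the trainer's exact code; the tiny model stands in for whatever parameters you train):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(16, 1).cuda()
    # drop-in replacement for torch.optim.AdamW, keeping optimizer state in 8 bit
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-6, betas=(0.9, 0.999), weight_decay=0.01)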

## Sampling

You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt. Or you can point to another file with --sample_prompts.
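
For example, a sample_prompts.txt might contain (placeholder prompts, substitute your own subjects):

    a photo of john doe riding a bicycle
    a painting of john doe in the style of van gogh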

Binary image file added (37 KiB), not shown.

train.json

@@ -1,5 +1,5 @@
{
    "amp": false,
    "amp": true,
    "batch_size": 10,
    "ckpt_every_n_minutes": null,
    "clip_grad_norm": null,

@@ -8,7 +8,6 @@
    "data_root": "X:\\my_project_data\\project_abc",
    "disable_textenc_training": false,
    "disable_xformers": false,
    "ed1_mode": true,
    "flip_p": 0.0,
    "gpuid": 0,
    "gradient_checkpointing": true,

@@ -16,11 +15,12 @@
    "logdir": "logs",
    "log_step": 25,
    "lowvram": false,
    "lr": 3.5e-06,
    "lr": 1.5e-06,
    "lr_decay_steps": 0,
    "lr_scheduler": "constant",
    "lr_warmup_steps": null,
    "max_epochs": 30,
    "notebook": false,
    "project_name": "project_abc",
    "resolution": 512,
    "resume_ckpt": "sd_v1-5_vae",

@@ -34,5 +34,7 @@
    "shuffle_tags": false,
    "useadam8bit": true,
    "wandb": false,
    "write_schedule": false
}
    "write_schedule": false,
    "rated_dataset": false,
    "rated_dataset_target_dropout_rate": 50
}

train.py

@@ -1,5 +1,5 @@
"""
Copyright [2022] Victor C Hall
Copyright [2022-2023] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.

@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import sys
import math

@@ -23,9 +24,10 @@ import time
import gc
import random
import traceback
import shutil

import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms

from colorama import Fore, Style, Cursor

@@ -47,30 +49,14 @@ from accelerate.utils import set_seed
import wandb
from torch.utils.tensorboard import SummaryWriter

import keyboard

from data.every_dream import EveryDreamBatch
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU


_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9

# def is_notebook() -> bool:
#     try:
#         from IPython import get_ipython
#         shell = get_ipython().__class__.__name__
#         if shell == 'ZMQInteractiveShell':
#             return True   # Jupyter notebook or qtconsole
#         elif shell == 'TerminalInteractiveShell':
#             return False  # Terminal running IPython
#         else:
#             return False  # Other type (?)
#     except NameError:
#         return False      # Probably standard Python interpreter

def clean_filename(filename):
    """
    removes all non-alphanumeric characters from a string so it is safe to use as a filename

@@ -90,19 +76,19 @@ def convert_to_hf(ckpt_path):
            import utils.convert_original_stable_diffusion_to_diffusers as convert
            convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
        except:
            logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
            logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.")
            exit()
        else:
            logging.info(f"Found cached checkpoint at {hf_cache}")

        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
        return hf_cache
        is_sd1attn, yaml = patch_unet(hf_cache)
        return hf_cache, is_sd1attn, yaml
    elif os.path.isdir(hf_cache):
        patch_unet(hf_cache, args.ed1_mode, args.lowvram)
        return hf_cache
        is_sd1attn, yaml = patch_unet(hf_cache)
        return hf_cache, is_sd1attn, yaml
    else:
        patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
        return ckpt_path
        is_sd1attn, yaml = patch_unet(ckpt_path)
        return ckpt_path, is_sd1attn, yaml

def setup_local_logger(args):
    """

@@ -225,15 +211,14 @@ def setup_args(args):
    Sets defaults for missing args (possible if missing from json config)
    Forces some args to be set based on others for compatibility reasons
    """
    if args.disable_unet_training and args.disable_textenc_training:
        raise ValueError("Both unet and textenc are disabled, nothing to train")

    if args.resume_ckpt == "findlast":
        logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
        # find the last checkpoint in the logdir
        args.resume_ckpt = find_last_checkpoint(args.logdir)

    if args.ed1_mode and not args.disable_xformers:
        args.disable_xformers = True
        logging.info(" ED1 mode: Overriding disable_xformers to True")

    if args.lowvram:
        set_args_12gb(args)

@@ -241,7 +226,7 @@ def setup_args(args):
        args.shuffle_tags = False

    args.clip_skip = max(min(4, args.clip_skip), 0)

    if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
        logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
        args.ckpt_every_n_minutes = 20

@@ -251,7 +236,7 @@ def setup_args(args):

    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
        args.save_every_n_epochs = _VERY_LARGE_NUMBER

    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")

@@ -272,8 +257,35 @@ def setup_args(args):
    if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
        os.makedirs(args.save_ckpt_dir)

    if args.rated_dataset:
        args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)

        logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")

    return args

def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||
if global_step == 250 or (epoch >= 4 and step == 1):
|
||||
factor = 1.8
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 500 or (epoch >= 8 and step == 1):
|
||||
factor = 1.6
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 1000 or (epoch >= 10 and step == 1):
|
||||
factor = 1.3
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
if global_step == 3000 or (epoch >= 20 and step == 1):
|
||||
factor = 1.15
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Main entry point
|
||||
|
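# Aside: a minimal standalone sketch of the schedule update_grad_scaler applies
# above (illustrative only; the names below are not from this commit). The loss
# scaler starts out aggressive and is progressively relaxed as training settles:
#
#   from torch.cuda.amp import GradScaler
#
#   MILESTONES = [(250, 1.8, 50), (500, 1.6, 50), (1000, 1.3, 100), (3000, 1.15, 100)]
#
#   def relax_scaler(scaler: GradScaler, global_step: int):
#       for at_step, factor, interval in MILESTONES:
#           if global_step == at_step:
#               scaler.set_growth_factor(factor)
#               scaler.set_backoff_factor(1 / factor)
#               scaler.set_growth_interval(interval)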
@@ -284,9 +296,10 @@ def main(args):
    if args.notebook:
        from tqdm.notebook import tqdm
    else:
-        from tqdm.auto import tqdm
+        from tqdm.auto import tqdm

    seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
    logging.info(f" Seed: {seed}")
    set_seed(seed)
    gpu = GPU()
    device = torch.device(f"cuda:{args.gpuid}")

@@ -294,12 +307,12 @@ def main(args):
    torch.backends.cudnn.benchmark = True

    log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
    logging.info(f"Logging to {log_folder}")

    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    @torch.no_grad()
-    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
+    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
        """
        Save the model to disk
        """

@@ -320,17 +333,24 @@ def main(args):
        )
        pipeline.save_pretrained(save_path)
        sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"

        if save_ckpt_dir is not None:
            sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
        else:
            sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
+            save_ckpt_dir = os.curdir

        half = not save_full_precision

        logging.info(f" * Saving SD model to {sd_ckpt_full}")
        converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
-        # optimizer_path = os.path.join(save_path, "optimizer.pt")

+        if yaml_name and yaml_name != "v1-inference.yaml":
+            yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
+            logging.info(f" * Saving yaml to {yaml_save_path}")
+            shutil.copyfile(yaml_name, yaml_save_path)
+
+        # optimizer_path = os.path.join(save_path, "optimizer.pt")
        # if self.save_optimizer_flag:
        #     logging.info(f" Saving optimizer state to {save_path}")
        #     self.save_optimizer(self.ctx.optimizer, optimizer_path)

@@ -391,8 +411,10 @@ def main(args):
        generates samples at different cfg scales and saves them to disk
        """
        logging.info(f"Generating samples gs:{gs}, for {prompts}")
+        pipe.set_progress_bar_config(disable=True)

        seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
-        gen = torch.Generator(device="cuda").manual_seed(seed)
+        gen = torch.Generator(device=device).manual_seed(seed)

        i = 0
        for prompt in prompts:

@@ -437,11 +459,11 @@ def main(args):

    try:
        # first try to download from HF
-        model_root_folder = try_download_model_from_hf(repo_id=args.resume_ckpt,
-                                                       subfolder=args.hf_repo_subfolder)
+        model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
+                                                                         subfolder=args.hf_repo_subfolder)
        # if that doesn't work, try a local folder
        if model_root_folder is None:
-            model_root_folder = convert_to_hf(args.resume_ckpt)
+            model_root_folder, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)

        text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")

@@ -456,39 +478,46 @@ def main(args):
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

-    if args.ed1_mode and not args.lowvram:
-        unet.set_attention_slice(4)
-
-    if not args.disable_xformers and is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-            logging.info("Enabled xformers")
-        except Exception as ex:
-            logging.warning("failed to load xformers, continuing without it")
-            pass
+    if not args.disable_xformers:
+        if (args.amp and is_sd1attn) or (not is_sd1attn):
+            try:
+                unet.enable_xformers_memory_efficient_attention()
+                logging.info("Enabled xformers")
+            except Exception as ex:
+                logging.warning("failed to load xformers, using attention slicing instead")
+                unet.set_attention_slice("auto")
+                pass
    else:
-        logging.info("xformers not available or disabled")
+        logging.info("xformers disabled, using attention slicing instead")
+        unet.set_attention_slice("auto")

    default_lr = 2e-6
    curr_lr = args.lr if args.lr is not None else default_lr

-    # vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    # unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    # text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-    vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
-
+    vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
    unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=torch.float32)
+    if args.disable_textenc_training and args.amp:
+        text_encoder = text_encoder.to(device, dtype=torch.float16)
+    else:
+        text_encoder = text_encoder.to(device, dtype=torch.float32)

    if args.disable_textenc_training:
        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters())
    elif args.disable_unet_training:
        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
        params_to_train = itertools.chain(text_encoder.parameters())
    else:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
+        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

    betas = (0.9, 0.999)
-    epsilon = 1e-8 if not args.amp else 1e-8
+    epsilon = 1e-8
+    if args.amp:
+        epsilon = 2e-8

    weight_decay = 0.01
    if args.useadam8bit:
        import bitsandbytes as bnb
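# Aside: the attention setup above boils down to this policy (illustrative
# sketch, not code from this commit): prefer xformers memory-efficient
# attention when the attention dtypes allow it, otherwise fall back to
# Diffusers' built-in attention slicing.
#
#   def setup_attention(unet, disable_xformers, amp, is_sd1attn):
#       if not disable_xformers and ((amp and is_sd1attn) or not is_sd1attn):
#           try:
#               unet.enable_xformers_memory_efficient_attention()
#               return "xformers"
#           except Exception:
#               pass
#       unet.set_attention_slice("auto")
#       return "attention slicing"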
@@ -507,6 +536,8 @@ def main(args):
        amsgrad=False,
    )

+    log_optimizer(optimizer, betas, epsilon)
+
    train_batch = EveryDreamBatch(
        data_root=args.data_root,
        flip_p=args.flip_p,

@@ -519,6 +550,8 @@ def main(args):
        log_folder=log_folder,
        write_schedule=args.write_schedule,
        shuffle_tags=args.shuffle_tags,
+        rated_dataset=args.rated_dataset,
+        rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
    )

    torch.cuda.benchmark = False

@@ -543,11 +576,8 @@ def main(args):
            sample_prompts.append(line.strip())


-    if False: #args.wandb is not None and args.wandb:  # not yet supported
-        log_writer = wandb.init(project="EveryDream2FineTunes",
-                                name=args.project_name,
-                                dir=log_folder,
-                                )
+    if args.wandb is not None and args.wandb:
+        wandb.init(project=args.project_name, sync_tensorboard=True, )
    else:
        log_writer = SummaryWriter(log_dir=log_folder,
                                   flush_secs=5,

@@ -562,14 +592,13 @@ def main(args):

    log_args(log_writer, args)



    """
    Train the model

    """
    print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}")
-    print(f" (C) 2022 Victor C Hall  This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
+    print(f" (C) 2022-2023 Victor C Hall  This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
    print()
    print("** Trainer Starting **")

@@ -605,7 +634,6 @@ def main(args):
    logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")


    def collate_fn(batch):
        """
        Collates batches

@@ -635,7 +663,7 @@ def main(args):
        collate_fn=collate_fn
    )

-    unet.train()
+    unet.train() if not args.disable_unet_training else unet.eval()
    text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()

    logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")

@@ -646,15 +674,19 @@ def main(args):
    logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
    logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
    logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
    #logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
    logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")

+    scaler = GradScaler(
+        enabled=args.amp,
+        init_scale=2**17.5,
+        growth_factor=2,
+        backoff_factor=1.0/2,
+        growth_interval=25,
+    )
+    logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")

    epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
    epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")

-    steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
-    steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")

    epoch_times = []

    global global_step

@@ -664,28 +696,38 @@ def main(args):

    append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)

-    #loss = torch.tensor(0.0, device=device, dtype=torch.float32)

-    if args.amp:
-        #scaler = torch.cuda.amp.GradScaler()
-        scaler = torch.cuda.amp.GradScaler(
-            #enabled=False,
-            enabled=True,
-            init_scale=1024.0,
-            growth_factor=2.0,
-            backoff_factor=0.5,
-            growth_interval=50,
-        )
-        logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")

    loss_log_step = []

+    assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"

    try:
+        # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
+        # pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
+        # pixel_values = pixel_values.to(unet.device)
+        # with autocast(enabled=args.amp):
+        #     latents = vae.encode(pixel_values, return_dict=False)
+        # latents = latents[0].sample() * 0.18215
+        # noise = torch.randn_like(latents)
+        # bsz = latents.shape[0]
+        # timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+        # timesteps = timesteps.long()
+        # noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        # cuda_caption = torch.linspace(100,177, steps=77, dtype=int).to(text_encoder.device)
+        # encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True).last_hidden_state
+        # with autocast(enabled=args.amp):
+        #     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        # # discard the grads, just want to pin memory
+        # optimizer.zero_grad(set_to_none=True)

        for epoch in range(args.max_epochs):
            loss_epoch = []
            epoch_start_time = time.time()
-            steps_pbar.reset()
            images_per_sec_log_step = []

+            epoch_len = math.ceil(len(train_batch) / args.batch_size)
+            steps_pbar = tqdm(range(epoch_len), position=1)
+            steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")

            for step, batch in enumerate(train_dataloader):
                step_start_time = time.time()

@@ -722,23 +764,22 @@ def main(args):
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                del noise, latents, cuda_caption

-                #with autocast(enabled=args.amp):
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with autocast(enabled=args.amp):
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                del timesteps, encoder_hidden_states, noisy_latents
+                #del timesteps, encoder_hidden_states, noisy_latents
                #with autocast(enabled=args.amp):
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                del target, model_pred

-                if args.clip_grad_norm is not None:
-                    torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
-                    torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
-                scaler.scale(loss).backward()

+                if args.amp:
+                    scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+                if args.clip_grad_norm is not None:
+                    if not args.disable_unet_training:
+                        torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
+                    if not args.disable_textenc_training:
+                        torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)

                if batch["runt_size"] > 0:
                    grad_scale = batch["runt_size"] / args.batch_size

@@ -752,11 +793,8 @@ def main(args):
                                param.grad *= grad_scale

                if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-                    if args.amp:
-                        scaler.step(optimizer)
-                        scaler.update()
-                    else:
-                        optimizer.step()
+                    scaler.step(optimizer)
+                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)

                lr_scheduler.step()
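# Aside: collapsing the amp/non-amp branches above works because a GradScaler
# constructed with enabled=False is a pass-through: scaler.scale(loss) returns
# the loss unscaled and scaler.step(optimizer) simply calls optimizer.step().
# One code path then covers both modes (illustrative sketch, not from this commit):
#
#   def training_update(scaler, optimizer, loss, global_step, grad_accum):
#       scaler.scale(loss).backward()
#       if (global_step + 1) % grad_accum == 0:
#           scaler.step(optimizer)   # unscales grads; skips the step on inf/nan
#           scaler.update()          # adjusts the loss scale
#           optimizer.zero_grad(set_to_none=True)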
@@ -778,6 +816,7 @@ def main(args):
                    loss_log_step = []
                    logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
+                    log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                    sum_img = sum(images_per_sec_log_step)
                    avg = sum_img / len(images_per_sec_log_step)
                    images_per_sec_log_step = []

@@ -787,7 +826,7 @@ def main(args):
                    append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                    torch.cuda.empty_cache()

-                if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
+                if (global_step + 1) % args.sample_steps == 0:
                    pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
                    pipe = pipe.to(device)

@@ -809,33 +848,37 @@ def main(args):
                    last_epoch_saved_time = time.time()
                    logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                    save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)

                if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
                    logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                    save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)

                del batch
                global_step += 1
+                update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
                # end of step

            steps_pbar.close()

            elapsed_epoch_time = (time.time() - epoch_start_time) / 60
            epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
            log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)

            epoch_pbar.update(1)
            if epoch < args.max_epochs - 1:
-                train_batch.shuffle(epoch_n=epoch+1)
-
+                train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)

            loss_local = sum(loss_epoch) / len(loss_epoch)
            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
            gc.collect()
            # end of epoch

        # end of training

        save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)

        total_elapsed_time = time.time() - training_start_time
        logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")

@@ -845,7 +888,7 @@ def main(args):
    except Exception as ex:
        logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
        save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
        raise ex

    logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")

@@ -858,17 +901,28 @@ def update_old_args(t_args):
    Update old args to new args to deal with json config loading and missing args for compatibility
    """
    if not hasattr(t_args, "shuffle_tags"):
-        print(f" Config json is missing 'shuffle_tags'")
+        print(f" Config json is missing 'shuffle_tags' flag")
        t_args.__dict__["shuffle_tags"] = False
    if not hasattr(t_args, "save_full_precision"):
-        print(f" Config json is missing 'save_full_precision'")
+        print(f" Config json is missing 'save_full_precision' flag")
        t_args.__dict__["save_full_precision"] = False
    if not hasattr(t_args, "notebook"):
-        print(f" Config json is missing 'notebook'")
+        print(f" Config json is missing 'notebook' flag")
        t_args.__dict__["notebook"] = False
+    if not hasattr(t_args, "disable_unet_training"):
+        print(f" Config json is missing 'disable_unet_training' flag")
+        t_args.__dict__["disable_unet_training"] = False
+    if not hasattr(t_args, "rated_dataset"):
+        print(f" Config json is missing 'rated_dataset' flag")
+        t_args.__dict__["rated_dataset"] = False
+    if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
+        print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
+        t_args.__dict__["rated_dataset_target_dropout_percent"] = 50


if __name__ == "__main__":
    supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
    supported_precisions = ['fp16', 'fp32']
    argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
    argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
    args, _ = argparser.parse_known_args()

@@ -879,22 +933,22 @@ if __name__ == "__main__":
            t_args = argparse.Namespace()
            t_args.__dict__.update(json.load(f))
+            update_old_args(t_args)  # update args to support older configs
-            print(t_args.__dict__)
+            print(f" args: \n{t_args.__dict__}")
            args = argparser.parse_args(namespace=t_args)
    else:
        print("No config file specified, using command line args")
        argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
-        argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
+        argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
        argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
        argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
        argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
        argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
        argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
        argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
-        argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
+        argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
        argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
        argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
        argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
-        argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
        argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
        argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
        argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")

@@ -907,6 +961,7 @@ if __name__ == "__main__":
        argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
        argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
        argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
+        argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
        argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
        argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
        argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")

@@ -914,6 +969,7 @@ if __name__ == "__main__":
        argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
        argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
        argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
+        argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
        argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
        argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
        argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")

@@ -921,9 +977,9 @@ if __name__ == "__main__":
        argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
        argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
        argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
-        argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
-        argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
+        argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
+        argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")

-        args = argparser.parse_args()
+        args, _ = argparser.parse_known_args()

    main(args)
@@ -0,0 +1,956 @@
"""
Copyright [2022] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
import math
import signal
import argparse
import logging
import time
import gc
import random

import torch.nn.functional as F
from torch.cuda.amp import autocast
import torchvision.transforms as transforms

from colorama import Fore, Style, Cursor
import numpy as np
import itertools
import torch
import datetime
import json
from PIL import Image, ImageDraw, ImageFont

from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler
#from diffusers.models import AttentionBlock
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer
#from accelerate import Accelerator
from accelerate.utils import set_seed

import wandb
from torch.utils.tensorboard import SummaryWriter

import keyboard

from data.every_dream import EveryDreamBatch
from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU
forstepTime = time.time()

_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9

# def is_notebook() -> bool:
#     try:
#         from IPython import get_ipython
#         shell = get_ipython().__class__.__name__
#         if shell == 'ZMQInteractiveShell':
#             return True   # Jupyter notebook or qtconsole
#         elif shell == 'TerminalInteractiveShell':
#             return False  # Terminal running IPython
#         else:
#             return False  # Other type (?)
#     except NameError:
#         return False      # Probably standard Python interpreter

def clean_filename(filename):
    """
    removes all non-alphanumeric characters from a string so it is safe to use as a filename
    """
    return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
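# Illustrative example of clean_filename (not part of this file): only letters,
# digits and spaces survive, and trailing whitespace is stripped, e.g.
#   clean_filename("a photo of ted, 4k!")  ->  "a photo of ted 4k"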

def convert_to_hf(ckpt_path):
    hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
    from utils.patch_unet import patch_unet

    if os.path.isfile(ckpt_path):
        if not os.path.exists(hf_cache):
            os.makedirs(hf_cache)
            logging.info(f"Converting {ckpt_path} to Diffusers format")
            try:
                import utils.convert_original_stable_diffusion_to_diffusers as convert
                convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
            except:
                logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
                exit()
        else:
            logging.info(f"Found cached checkpoint at {hf_cache}")

        is_sd1attn = patch_unet(hf_cache)
        return hf_cache, is_sd1attn
    elif os.path.isdir(hf_cache):
        is_sd1attn = patch_unet(hf_cache)
        return hf_cache, is_sd1attn
    else:
        is_sd1attn = patch_unet(ckpt_path)
        return ckpt_path, is_sd1attn

def setup_local_logger(args):
    """
    configures logger with file and console logging, logs args, and returns the datestamp
    """
    log_path = args.logdir

    if not os.path.exists(log_path):
        os.makedirs(log_path)

    json_config = json.dumps(vars(args), indent=2)
    datetimestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    with open(os.path.join(log_path, f"{args.project_name}-{datetimestamp}_cfg.json"), "w") as f:
        f.write(f"{json_config}")

    logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log")
    print(f" logging to {logfilename}")
    logging.basicConfig(filename=logfilename,
                        level=logging.INFO,
                        format="%(asctime)s %(message)s",
                        datefmt="%m/%d/%Y %I:%M:%S %p",
                        )

    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    return datetimestamp

def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
    """
    logs the optimizer settings
    """
    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
    logging.info(f"    betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")

def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
    """
    Saves the optimizer state
    """
    torch.save(optimizer.state_dict(), path)

def load_optimizer(optimizer, path: str):
    """
    Loads the optimizer state
    """
    optimizer.load_state_dict(torch.load(path))

def get_gpu_memory(nvsmi):
    """
    returns the gpu memory usage
    """
    gpu_query = nvsmi.DeviceQuery('memory.used, memory.total')
    gpu_used_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['used'])
    gpu_total_mem = int(gpu_query['gpu'][0]['fb_memory_usage']['total'])
    return gpu_used_mem, gpu_total_mem

def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
    """
    updates the vram usage for the epoch
    """
    gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
    log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
    epoch_mem_color = Style.RESET_ALL
    if gpu_used_mem > 0.93 * gpu_total_mem:
        epoch_mem_color = Fore.LIGHTRED_EX
    elif gpu_used_mem > 0.85 * gpu_total_mem:
        epoch_mem_color = Fore.LIGHTYELLOW_EX
    elif gpu_used_mem > 0.7 * gpu_total_mem:
        epoch_mem_color = Fore.LIGHTGREEN_EX
    elif gpu_used_mem < 0.5 * gpu_total_mem:
        epoch_mem_color = Fore.LIGHTBLUE_EX

    if logs is not None:
        epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
    print(f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step} | Elapsed : {time.time() - forstepTime}s")


def set_args_12gb(args):
    logging.info(" Setting args to 12GB mode")
    if not args.gradient_checkpointing:
        logging.info("  - Overiding gradient checkpointing to True")
        args.gradient_checkpointing = True
    if args.batch_size != 1:
        logging.info("  - Overiding batch size to 1")
        args.batch_size = 1
    # if args.grad_accum != 1:
    #     logging.info("   Overiding grad accum to 1")
    args.grad_accum = 1
    if args.resolution > 512:
        logging.info("  - Overiding resolution to 512")
        args.resolution = 512
    if not args.useadam8bit:
        logging.info("  - Overiding adam8bit to True")
        args.useadam8bit = True

def find_last_checkpoint(logdir):
    """
    Finds the last checkpoint in the logdir, recursively
    """
    last_ckpt = None
    last_date = None

    for root, dirs, files in os.walk(logdir):
        for file in files:
            if os.path.basename(file) == "model_index.json":
                curr_date = os.path.getmtime(os.path.join(root,file))

                if last_date is None or curr_date > last_date:
                    last_date = curr_date
                    last_ckpt = root

    assert last_ckpt, f"Could not find last checkpoint in logdir: {logdir}"
    assert "errored" not in last_ckpt, f"Found last checkpoint: {last_ckpt}, but it was errored, cancelling"

    print(f" {Fore.LIGHTCYAN_EX}Found last checkpoint: {last_ckpt}, resuming{Style.RESET_ALL}")

    return last_ckpt

def setup_args(args):
    """
    Sets defaults for missing args (possible if missing from json config)
    Forces some args to be set based on others for compatibility reasons
    """
    if args.disable_unet_training and args.disable_textenc_training:
        raise ValueError("Both unet and textenc are disabled, nothing to train")

    if args.resume_ckpt == "findlast":
        logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
        # find the last checkpoint in the logdir
        args.resume_ckpt = find_last_checkpoint(args.logdir)

    if args.lowvram:
        set_args_12gb(args)

    if not args.shuffle_tags:
        args.shuffle_tags = False

    args.clip_skip = max(min(4, args.clip_skip), 0)

    if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
        logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
        args.ckpt_every_n_minutes = 20

    if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
        args.ckpt_every_n_minutes = _VERY_LARGE_NUMBER

    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
        args.save_every_n_epochs = _VERY_LARGE_NUMBER

    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")

    if args.cond_dropout > 0.26:
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")

    if args.grad_accum > 1:
        logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}")

    total_batch_size = args.batch_size * args.grad_accum

    if args.scale_lr is not None and args.scale_lr:
        tmp_lr = args.lr
        args.lr = args.lr * (total_batch_size**0.55)
        logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")

    if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
        os.makedirs(args.save_ckpt_dir)

    if args.rated_dataset:
        args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)

        logging.info(logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}"))

    return args

def main(args):
    """
    Main entry point
    """
    log_time = setup_local_logger(args)
    args = setup_args(args)

    if args.notebook:
        from tqdm.notebook import tqdm
    else:
        from tqdm.auto import tqdm

    seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
    logging.info(f" Seed: {seed}")
    set_seed(seed)
    gpu = GPU()
    device = torch.device(f"cuda:{args.gpuid}")

    torch.backends.cudnn.benchmark = True

    log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
    logging.info(f"Logging to {log_folder}")
    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    @torch.no_grad()
    def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
        """
        Save the model to disk
        """
        global global_step
        if global_step is None or global_step == 0:
            logging.warning("  No model to save, something likely blew up on startup, not saving")
            return
        logging.info(f" * Saving diffusers model to {save_path}")
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None, # save vram
            requires_safety_checker=None, # avoid nag
            feature_extractor=None, # must be none of no safety checker
        )
        pipeline.save_pretrained(save_path)
        sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
        if save_ckpt_dir is not None:
            sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
        else:
            sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)

        half = not save_full_precision

        logging.info(f" * Saving SD model to {sd_ckpt_full}")
        converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
        # optimizer_path = os.path.join(save_path, "optimizer.pt")

        # if self.save_optimizer_flag:
        #     logging.info(f" Saving optimizer state to {save_path}")
        #     self.save_optimizer(self.ctx.optimizer, optimizer_path)

    @torch.no_grad()
    def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
        """
        creates a pipeline for SD inference
        """
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None, # save vram
            requires_safety_checker=None, # avoid nag
            feature_extractor=None, # must be none of no safety checker
        )

        return pipe
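    # Usage sketch for the pipe above (illustrative, not part of this file):
    #   pipe = __create_inference_pipe(unet, text_encoder, tokenizer, sample_scheduler, vae).to(device)
    #   img = pipe("a photo of an astronaut", num_inference_steps=30).images[0]
    #   del pipe  # free VRAM once sampling is done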
|
||||
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
|
||||
"""
|
||||
generates a single sample at a given cfg scale and saves it to disk
|
||||
"""
|
||||
with torch.no_grad(), autocast():
|
||||
image = pipe(prompt,
|
||||
num_inference_steps=30,
|
||||
num_images_per_prompt=1,
|
||||
guidance_scale=cfg,
|
||||
generator=gen,
|
||||
height=resolution,
|
||||
width=resolution,
|
||||
).images[0]
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
try:
|
||||
font = ImageFont.truetype(font="arial.ttf", size=20)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
print_msg = f"cfg:{cfg:.1f}"
|
||||
|
||||
l, t, r, b = draw.textbbox(xy=(0,0), text=print_msg, font=font)
|
||||
text_width = r - l
|
||||
text_height = b - t
|
||||
|
||||
x = float(image.width - text_width - 10)
|
||||
y = float(image.height - text_height - 10)
|
||||
|
||||
draw.rectangle((x, y, image.width, image.height), fill="white")
|
||||
draw.text((x, y), print_msg, fill="black", font=font)
|
||||
del draw, font
|
||||
return image
|
||||
|
||||
def __generate_test_samples(pipe, prompts, gs, log_writer, log_folder, random_captions=False, resolution=512):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
logging.info(f"Generating samples gs:{gs}, for {prompts}")
|
||||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||
gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
i = 0
|
||||
for prompt in prompts:
|
||||
if prompt is None or len(prompt) < 2:
|
||||
#logging.warning("empty prompt in sample prompts, check your prompts file")
|
||||
continue
|
||||
images = []
|
||||
for cfg in [7.0, 4.0, 1.01]:
|
||||
image = __generate_sample(pipe, prompt, cfg, resolution=resolution, gen=gen)
|
||||
images.append(image)
|
||||
|
||||
width = 0
|
||||
height = 0
|
||||
for image in images:
|
||||
width += image.width
|
||||
height = max(height, image.height)
|
||||
|
||||
result = Image.new('RGB', (width, height))
|
||||
|
||||
x_offset = 0
|
||||
for image in images:
|
||||
result.paste(image, (x_offset, 0))
|
||||
x_offset += image.width
|
||||
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
f.write(prompt)
|
||||
f.write(f"\n seed: {seed}")
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if random_captions:
|
||||
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
|
||||
else:
|
||||
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
|
||||
i += 1
|
||||
|
||||
del result
|
||||
del tfimage
|
||||
del images
|
||||
|
||||
try:
|
||||
hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
|
||||
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
|
||||
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
|
||||
except:
|
||||
logging.ERROR(" * Failed to load checkpoint *")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
|
||||
try:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
logging.info("Enabled xformers")
|
||||
except Exception as ex:
|
||||
logging.warning("failed to load xformers, continuing without it")
|
||||
pass
|
||||
else:
|
||||
logging.info("xformers not available or disabled")
|
||||
|
||||
default_lr = 2e-6
|
||||
curr_lr = args.lr if args.lr is not None else default_lr
|
||||
|
||||
|
||||
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
if args.disable_textenc_training and args.amp:
|
||||
text_encoder = text_encoder.to(device, dtype=torch.float16)
|
||||
else:
|
||||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||
|
||||
if args.disable_textenc_training:
|
||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters())
|
||||
elif args.disable_unet_training:
|
||||
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(text_encoder.parameters())
|
||||
else:
|
||||
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
|
||||
betas = (0.9, 0.999)
|
||||
epsilon = 1e-8
|
||||
if args.amp:
|
||||
epsilon = 2e-8
|
||||
|
||||
weight_decay = 0.01
|
||||
if args.useadam8bit:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
|
||||
else:
|
||||
opt_class = torch.optim.AdamW
|
||||
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
|
||||
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=betas,
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=False,
|
||||
)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon)
|
||||
|
||||
train_batch = EveryDreamBatch(
|
||||
data_root=args.data_root,
|
||||
flip_p=args.flip_p,
|
||||
debug_level=1,
|
||||
batch_size=args.batch_size,
|
||||
conditional_dropout=args.cond_dropout,
|
||||
resolution=args.resolution,
|
||||
tokenizer=tokenizer,
|
||||
seed = seed,
|
||||
log_folder=log_folder,
|
||||
write_schedule=args.write_schedule,
|
||||
shuffle_tags=args.shuffle_tags,
|
||||
rated_dataset=args.rated_dataset,
|
||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||
)
|
||||
|
||||
torch.cuda.benchmark = False
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
|
||||
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
|
||||
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
sample_prompts = []
|
||||
with open(args.sample_prompts, "r") as f:
|
||||
for line in f:
|
||||
sample_prompts.append(line.strip())
|
||||
|
||||
|
||||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name, sync_tensorboard=True, )
|
||||
else:
|
||||
log_writer = SummaryWriter(log_dir=log_folder,
|
||||
flush_secs=5,
|
||||
comment="EveryDream2FineTunes",
|
||||
)
|
||||
|
||||
def log_args(log_writer, args):
|
||||
arglog = "args:\n"
|
||||
for arg, value in sorted(vars(args).items()):
|
||||
arglog += f"{arg}={value}, "
|
||||
log_writer.add_text("config", arglog)
|
||||
|
||||
log_args(log_writer, args)
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Train the model
|
||||
|
||||
"""
|
||||
print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}")
|
||||
print(f" (C) 2022 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
|
||||
print()
|
||||
print("** Trainer Starting **")
|
||||
|
||||
global interrupted
|
||||
interrupted = False
|
||||
|
||||
def sigterm_handler(signum, frame):
|
||||
"""
|
||||
handles sigterm
|
||||
"""
|
||||
global interrupted
|
||||
if not interrupted:
|
||||
interrupted=True
|
||||
global global_step
|
||||
#TODO: save model on ctrl-c
|
||||
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
|
||||
print()
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
|
||||
signal.signal(signal.SIGINT, sigterm_handler)
|
||||
|
||||
if not os.path.exists(f"{log_folder}/samples/"):
|
||||
os.makedirs(f"{log_folder}/samples/")
|
||||
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collates batches
|
||||
"""
|
||||
images = [example["image"] for example in batch]
|
||||
captions = [example["caption"] for example in batch]
|
||||
tokens = [example["tokens"] for example in batch]
|
||||
runt_size = batch[0]["runt_size"]
|
||||
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
ret = {
|
||||
"tokens": torch.stack(tuple(tokens)),
|
||||
"image": images,
|
||||
"captions": captions,
|
||||
"runt_size": runt_size,
|
||||
}
|
||||
del batch
|
||||
return ret
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_batch,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
|
||||
unet.train() if not args.disable_unet_training else unet.eval()
|
||||
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
|
||||
|
||||
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
|
||||
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
|
||||
logging.info(f" vae device: {vae.device}, precision: {vae.dtype}, training: {vae.training}")
|
||||
logging.info(f" scheduler: {noise_scheduler.__class__}")
|
||||
|
||||
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
|
||||
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
|
||||
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
||||
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
||||
|
||||
|
||||
#scaler = torch.cuda.amp.GradScaler()
|
||||
scaler = torch.cuda.amp.GradScaler(
|
||||
enabled=args.amp,
|
||||
#enabled=True,
|
||||
init_scale=2**17.5,
|
||||
growth_factor=1.8,
|
||||
backoff_factor=1.0/1.8,
|
||||
growth_interval=50,
|
||||
)
|
||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||
|
||||
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||
|
||||
# steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
|
||||
# steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
epoch_times = []
|
||||
|
||||
global global_step
|
||||
global_step = 0
|
||||
training_start_time = time.time()
|
||||
last_epoch_saved_time = training_start_time
|
||||
|
||||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
|
||||
|
||||
loss_log_step = []
|
||||
|
||||
try:
|
||||
for epoch in range(args.max_epochs):
|
||||
loss_epoch = []
|
||||
epoch_start_time = time.time()
|
||||
images_per_sec_log_step = []
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
step_start_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
with autocast(enabled=args.amp):
|
||||
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
latents = vae.encode(pixel_values, return_dict=False)
|
||||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
cuda_caption = batch["tokens"].to(text_encoder.device)
|
||||
|
||||
#with autocast(enabled=args.amp):
|
||||
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
|
||||
|
||||
if args.clip_skip > 0:
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                del noise, latents, cuda_caption

                with autocast(enabled=args.amp):
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                del timesteps, encoder_hidden_states, noisy_latents
                #with autocast(enabled=args.amp):
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                del target, model_pred

                #if args.amp:
                scaler.scale(loss).backward()
                #else:
                #    loss.backward()

                if args.clip_grad_norm is not None:
                    if not args.disable_unet_training:
                        torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
                    if not args.disable_textenc_training:
                        torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
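                # caveat: the gradients are still multiplied by the GradScaler's loss scale at
                # this point; PyTorch recommends calling scaler.unscale_(optimizer) before
                # clipping so max_norm applies to the true gradient magnitudes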

                if batch["runt_size"] > 0:
                    grad_scale = batch["runt_size"] / args.batch_size
                    with torch.no_grad():  # not required? just in case for now, needs more testing
                        for param in unet.parameters():
                            if param.grad is not None:
                                param.grad *= grad_scale
                        if text_encoder.training:
                            for param in text_encoder.parameters():
                                if param.grad is not None:
                                    param.grad *= grad_scale
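                # runt_size appears to be the image count of a short trailing batch; scaling
                # its gradients by runt_size/batch_size keeps a partial batch from carrying
                # the weight of a full one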

                if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
                    # if args.amp:
                    scaler.step(optimizer)
                    scaler.update()
                    # else:
                    #     optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
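                    # scaler.step() skips the underlying optimizer.step() if the grads contain
                    # inf/NaN, and scaler.update() then backs off the loss scale for the next step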

                lr_scheduler.step()

                loss_step = loss.detach().item()

                steps_pbar.set_postfix({"loss/step": loss_step, "gs": global_step})
                steps_pbar.update(1)

                images_per_sec = args.batch_size / (time.time() - step_start_time)
                images_per_sec_log_step.append(images_per_sec)

                loss_log_step.append(loss_step)
                loss_epoch.append(loss_step)

                if (global_step + 1) % args.log_step == 0:
                    curr_lr = lr_scheduler.get_last_lr()[0]
                    loss_local = sum(loss_log_step) / len(loss_log_step)
                    loss_log_step = []
                    logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
                    log_writer.add_scalar(tag="hyperparameter/lr", scalar_value=curr_lr, global_step=global_step)
                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                    sum_img = sum(images_per_sec_log_step)
                    avg = sum_img / len(images_per_sec_log_step)
                    images_per_sec_log_step = []
                    if args.amp:
                        log_writer.add_scalar(tag="hyperparameter/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
                    log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
                    append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                    torch.cuda.empty_cache()

                if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
                    pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
                    pipe = pipe.to(device)

                    with torch.no_grad():
                        if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
                            __generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
                        else:
                            max_prompts = min(4, len(batch["captions"]))
                            prompts = batch["captions"][:max_prompts]
                            __generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)

                    del pipe
                    gc.collect()
                    torch.cuda.empty_cache()

                min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60

                if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
                    last_epoch_saved_time = time.time()
                    logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                    save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)

                if args.save_every_n_epochs is not None and epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
                    logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                    save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)

                del batch
                global_step += 1

                if global_step == 500:
                    scaler.set_growth_factor(1.4)
                    scaler.set_backoff_factor(1/1.4)
                if global_step == 1000:
                    scaler.set_growth_factor(1.2)
                    scaler.set_backoff_factor(1/1.2)
                    scaler.set_growth_interval(100)
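                # loss-scale schedule: start aggressive (growth 1.8 every 50 steps) to find a
                # usable scale quickly, then relax at steps 500 and 1000 once training is stable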
                # end of step

            steps_pbar.close()

            elapsed_epoch_time = (time.time() - epoch_start_time) / 60
            epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
            log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)

            epoch_pbar.update(1)
            if epoch < args.max_epochs - 1:
                train_batch.shuffle(epoch_n=epoch, max_epochs=args.max_epochs)

            loss_local = sum(loss_epoch) / len(loss_epoch)
            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
            # end of epoch

        # end of training

        save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)

        total_elapsed_time = time.time() - training_start_time
        logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
        logging.info(f"Total training time: {total_elapsed_time/60:.2f} minutes, total steps: {global_step}")
        logging.info(f"Average epoch time: {np.mean([t['time'] for t in epoch_times]):.2f} minutes")

    except Exception as ex:
        logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
        save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
        raise ex

    logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
    logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
    logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")


def update_old_args(t_args):
    """
    Backfill args missing from older json configs so they remain loadable for compatibility.
    """
    if not hasattr(t_args, "shuffle_tags"):
        print(" Config json is missing 'shuffle_tags' flag")
        t_args.__dict__["shuffle_tags"] = False
    if not hasattr(t_args, "save_full_precision"):
        print(" Config json is missing 'save_full_precision' flag")
        t_args.__dict__["save_full_precision"] = False
    if not hasattr(t_args, "notebook"):
        print(" Config json is missing 'notebook' flag")
        t_args.__dict__["notebook"] = False
    if not hasattr(t_args, "disable_unet_training"):
        print(" Config json is missing 'disable_unet_training' flag")
        t_args.__dict__["disable_unet_training"] = False
    if not hasattr(t_args, "rated_dataset"):
        print(" Config json is missing 'rated_dataset' flag")
        t_args.__dict__["rated_dataset"] = False
    if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
        print(" Config json is missing 'rated_dataset_target_dropout_percent' flag")
        t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
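# For reference, a minimal json for --config might look like the following (keys mirror
# the command line flags parsed below; the values are illustrative, not recommendations):
#   {
#     "resume_ckpt": "sd_v1-5_vae.ckpt",
#     "data_root": "input",
#     "max_epochs": 30,
#     "lr": 1.5e-6,
#     "amp": true,
#     "useadam8bit": true
#   }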

if __name__ == "__main__":
    supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
    supported_precisions = ['fp16', 'fp32']
    argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
    argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
    args, _ = argparser.parse_known_args()

    if args.config is not None:
        print(f"Loading training config from {args.config}, all other command options will be ignored!")
        with open(args.config, 'rt') as f:
            t_args = argparse.Namespace()
            t_args.__dict__.update(json.load(f))
            update_old_args(t_args)  # update args to support older configs
            print(f" args: \n{t_args.__dict__}")
            args = argparser.parse_args(namespace=t_args)
    else:
        print("No config file specified, using command line args")
        argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
        argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
        argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
        argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
        argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan")
        argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
        argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
        argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
        argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
        argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
        argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
        argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
        argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
        argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
        argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1) (ex: 2)")
        argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
        argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
        argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
        argparser.add_argument("--lr", type=float, default=None, help="Learning rate; when using a scheduler, this is the maximum LR at the top of the curve")
        argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
        argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
        argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
        argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
        argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and use tqdm.notebook for jupyter notebook (def: False)")
        argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
        argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
        argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
        argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
        argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
        argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
        argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
        argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
        argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
        argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
        argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
        argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
        argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-bit optimizer, recommended!")
        argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
        argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
        argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, training less often on lower-rated images through the epochs")
        argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (def: 50)")

        args, _ = argparser.parse_known_args()

    main(args)
@@ -1,11 +1,13 @@
 import logging
 import os
-from typing import Optional
+from typing import Optional, Tuple

 import huggingface_hub
+from utils.patch_unet import patch_unet


 def try_download_model_from_hf(repo_id: str,
-                               subfolder: Optional[str]=None) -> Optional[str]:
+                               subfolder: Optional[str]=None) -> Tuple[Optional[str], Optional[bool], Optional[str]]:
     """
     Attempts to download files from the following subfolders under the given repo id:
     "text_encoder", "vae", "unet", "scheduler", "tokenizer".
@@ -25,9 +27,11 @@ def try_download_model_from_hf(repo_id: str,
     # check if the model exists
     model_info = huggingface_hub.model_info(repo_id)
     if model_info is None:
-        return None
+        return None, None, None

     model_subfolders = ["text_encoder", "vae", "unet", "scheduler", "tokenizer"]
     allow_patterns = [os.path.join(subfolder or '', f, "*") for f in model_subfolders]
     downloaded_folder = huggingface_hub.snapshot_download(repo_id=repo_id, allow_patterns=allow_patterns)
-    return downloaded_folder
+
+    is_sd1_attn, yaml_path = patch_unet(downloaded_folder)
+    return downloaded_folder, is_sd1_attn, yaml_path
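
A sketch of how a caller might consume the new three-part return value (the repo id and log messages below are illustrative, not part of this diff):

    downloaded_folder, is_sd1_attn, yaml_path = try_download_model_from_hf("stabilityai/stable-diffusion-2-1")
    if downloaded_folder is None:
        logging.warning("model not found on Hugging Face, falling back to local --resume_ckpt")
    else:
        logging.info(f"downloaded to {downloaded_folder}, sd1 attn: {is_sd1_attn}, yaml: {yaml_path}")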
@@ -16,24 +16,67 @@ limitations under the License.
 import logging
 import os
 import time
 from colorama import Fore, Style

-class LogWrapper(object):
+from torch.utils.tensorboard import SummaryWriter
+import wandb
+
+
+class LogWrapper():
     """
     singleton for logging
     """
-    def __init__(self, log_dir, project_name):
-        self.log_dir = log_dir
+    def __init__(self, args, use_wandb=False):
+        self.logdir = args.logdir
+        self.use_wandb = use_wandb
+
+        if use_wandb:
+            wandb.init(project=args.project_name, sync_tensorboard=True)
+        else:
+            self.log_writer = SummaryWriter(log_dir=args.logdir,
+                                            flush_secs=5,
+                                            comment="EveryDream2FineTunes",
+                                           )

         start_time = time.strftime("%Y%m%d-%H%M")
-        self.log_file = os.path.join(log_dir, f"log-{project_name}-{start_time}.txt")
+        log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")

         self.logger = logging.getLogger(__name__)

         console = logging.StreamHandler()
         self.logger.addHandler(console)

-        file = logging.FileHandler(self.log_file, mode="a", encoding=None, delay=False)
+        file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
         self.logger.addHandler(file)

-    def __call__(self):
-        return self.logger
+    def add_image(self):
+        """
+        log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
+        else:
+            log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
+        """
+        pass
+
+    def add_scalar(self, tag: str, scalar_value: float, global_step: int):
+        if self.use_wandb:
+            wandb.log({tag: scalar_value}, step=global_step)
+        else:
+            self.log_writer.add_scalar(tag, scalar_value, global_step)
+
+    def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
+        """
+        updates the vram usage for the epoch
+        """
+        gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
+        self.add_scalar("performance/vram", gpu_used_mem, global_step)
+        epoch_mem_color = Style.RESET_ALL
+        if gpu_used_mem > 0.93 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTRED_EX
+        elif gpu_used_mem > 0.85 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTYELLOW_EX
+        elif gpu_used_mem > 0.7 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTGREEN_EX
+        elif gpu_used_mem < 0.5 * gpu_total_mem:
+            epoch_mem_color = Fore.LIGHTBLUE_EX
+
+        if logs is not None:
+            epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
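
A minimal usage sketch for the wrapper above (assuming args carries the logdir, project_name, and wandb flags defined in train.py):

    log_writer = LogWrapper(args, use_wandb=args.wandb)
    log_writer.add_scalar("loss/epoch", 0.123, global_step=100)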
@@ -17,7 +17,7 @@ import os
 import json
 import logging

-def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
+def patch_unet(ckpt_path):
     """
     Patch the UNet to use updated attention heads for xformers support in FP32
     """
@@ -25,15 +25,27 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
     with open(unet_cfg_path, "r") as f:
         unet_cfg = json.load(f)

+    scheduler_cfg_path = os.path.join(ckpt_path, "scheduler", "scheduler_config.json")
+    with open(scheduler_cfg_path, "r") as f:
+        scheduler_cfg = json.load(f)
+
-    if force_sd1attn:
-        if low_vram:
-            unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
-        else:
-            unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
-    else:
-        unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
+    is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
+    is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
+
+    prediction_type = scheduler_cfg["prediction_type"]

     logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
-    with open(unet_cfg_path, "w") as f:
-        json.dump(unet_cfg, f, indent=2)
+
+    yaml = ''
+    if prediction_type in ["v_prediction", "v-prediction"] and not is_sd1attn:
+        yaml = "v2-inference-v.yaml"
+    elif prediction_type == "epsilon" and not is_sd1attn:
+        yaml = "v2-inference.yaml"
+    elif prediction_type == "epsilon" and is_sd1attn:
+        yaml = "v1-inference.yaml"
+    else:
+        raise ValueError(f"Unknown model format for: {prediction_type} and attention_head_dim {unet_cfg['attention_head_dim']}")
+
+    logging.info(f"Inferred yaml: {yaml}, attn: {'sd1' if is_sd1attn else 'sd2'}, prediction_type: {prediction_type}")
+
+    return is_sd1attn, yaml
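
With this change patch_unet no longer rewrites the UNet config; it reads the existing attention_head_dim and the scheduler's prediction_type to infer which inference yaml matches the model. Roughly (illustrative call, the path is an assumption):

    is_sd1attn, yaml = patch_unet("/path/to/diffusers_model")
    # v1-inference.yaml -> SD1.x (attention_head_dim 8), v2-inference(-v).yaml -> SD2.x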