Merge remote-tracking branch 'upstream/main' into hf_model_download

This commit is contained in:
Damian Stewart 2023-01-23 19:19:22 +01:00
commit d24dd681c0
14 changed files with 1984 additions and 303 deletions

View File

@ -17,6 +17,10 @@ Covers install, setup of base models, startning training, basic tweaking, and lo
Behind the scenes look at how the trainer handles multiaspect and crop jitter
### Tools repo
Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream); it has a grab bag of scripts to help with your data curation prior to training. It has automatic bulk captioning with BLIP, a script to web scrape based on Laion data files, a script to rename generic pronouns to proper names or append artist tags to your captions, etc.
## Docs
[Setup and installation](doc/SETUP.md)

439
Train_Colab.ipynb Normal file
View File

@ -0,0 +1,439 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "blaLMSbkPHhG"
},
"source": [
"# EveryDream2 Colab Edition\n",
"\n",
"Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
"\n",
"And join the discord: https://discord.gg/uheqxU6sXN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "WsYIcz9HY9lx"
},
"outputs": [],
"source": [
"#@title # Install python 3.10 \n",
"import os\n",
"from IPython.display import clear_output\n",
"!wget https://github.com/korakot/kora/releases/download/v0.10/py310.sh\n",
"!bash ./py310.sh -b -f -p /usr/local\n",
"!python -m ipykernel install --name \"py310\" --user\n",
"clear_output()\n",
"os.kill(os.getpid(), 9)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "f2cdMtCt9Wb6"
},
"outputs": [],
"source": [
"#@title Verify python version, should be 3.10.something\n",
"!python --version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "d1di4EC6ygw1"
},
"outputs": [],
"source": [
"#@title Optional connect Gdrive\n",
"#@markdown # but strongly recommended\n",
"#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
"\n",
"#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"!mkdir -p /content/drive/MyDrive/everydreamlogs/ckpt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "hAuBbtSvGpau"
},
"outputs": [],
"source": [
"#@markdown # Install Dependencies\n",
"#@markdown This will take a couple minutes, be patient and watch the output for \"DONE!\"\n",
"from IPython.display import clear_output\n",
"from subprocess import getoutput\n",
"# nvidia-smi output is kept in `s`; it is checked below (and in the training cell) to detect an A100.\n",
"s = getoutput('nvidia-smi')\n",
"!pip install -q torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url \"https://download.pytorch.org/whl/cu116\"\n",
"!pip install -q transformers==4.25.1\n",
"!pip install -q diffusers[torch]==0.10.2\n",
"!pip install -q pynvml==11.4.1\n",
"!pip install -q bitsandbytes==0.35.0\n",
"!pip install -q ftfy==6.1.1\n",
"!pip install -q aiohttp==3.8.3\n",
"# The specifier must be quoted: an unquoted '>' is treated by the shell as an output redirect.\n",
"!pip install -q \"tensorboard>=2.11.0\"\n",
"!pip install -q protobuf==3.20.1\n",
"!pip install -q wandb==0.13.6\n",
"!pip install -q pyre-extensions==0.0.23\n",
"# Pick the precompiled xformers wheel for the detected GPU; use /resolve/ URLs so pip gets the wheel file itself.\n",
"if \"A100\" in s:\n",
" !pip install -q https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/A100_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
"else:\n",
" !pip install -q https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/T4_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
"!pip install -q pytorch-lightning==1.6.5\n",
"!pip install -q OmegaConf==2.2.3\n",
"!pip install -q numpy==1.23.5\n",
"!pip install -q colorama\n",
"!pip install -q keyboard\n",
"clear_output()\n",
"!git clone https://github.com/victorchall/EveryDream2trainer.git\n",
"%cd /content/EveryDream2trainer\n",
"!wget \"https://raw.githubusercontent.com/nawnie/EveryDream2trainer/main/train_colab.py\"\n",
"clear_output()\n",
"print(\"DONE!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "unaffeqGP_0A"
},
"outputs": [],
"source": [
"#@title Get A Base Model\n",
"#@markdown Choose SD1.5 or Waifu Diffusion 1.3 from the dropdown, or paste your own URL in the box\n",
"\n",
"#@markdown If you already did this once with Gdrive connected, you can skip this step as the cached copy is on your gdrive\n",
"from IPython.display import clear_output\n",
"!mkdir input\n",
"%cd /content/EveryDream2trainer\n",
"!python utils/get_yamls.py\n",
"MODEL_URL = \"https://huggingface.co/panopstor/EveryDream/resolve/main/sd_v1-5_vae.ckpt\" #@param [\"https://huggingface.co/panopstor/EveryDream/resolve/main/sd_v1-5_vae.ckpt\", \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt\"] {allow-input: true}\n",
"print(\"Downloading \")\n",
"!wget $MODEL_URL\n",
"\n",
"%cd /content/EveryDream2trainer\n",
"\n",
"clear_output()\n",
"print(\"DONE!\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "nEzuEYH0536C"
},
"source": [
"In order to train, you need a base model on which to train. This is a one-time setup to configure base models when you want to use a particular base. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tPvQSo6ScF2c"
},
"outputs": [],
"source": [
"import os\n",
"#@title Setup conversion\n",
"\n",
"#@markdown **If you already did this once with Gdrive connected, you can skip this step as the cached copy is on your gdrive.** \n",
"# \n",
"# If you are not sure, look in your Gdrive for `everydreamlogs/ckpt` and see if you have a folder with the `save_name` below.\n",
"\n",
"#@markdown Pick the `model_type` in the dropdown. This is the model type that you are converting and you downloaded above. This is important as it will determine the model architecture and the correct settings to use.\n",
"\n",
"#@markdown * `SD1x` is all SD1.x based models *(SD1.4, SD1.5, Waifu Diffusion 1.3, etc)*\n",
"\n",
"#@markdown * `SD2_512_base` is the SD2 512 base model\n",
"\n",
"#@markdown * `SD21` is all SD2 768 models. *(ex. SD2.1 768, or trained models based on that)*\n",
"\n",
"#@markdown If you are not sure, double check the model author's page or ask for help on [Discord](https://discord.gg/uheqxU6sXN).\n",
"model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
"\n",
"#@markdown This is the temporary ckpt file that was downloaded above. If you downloaded a different model, you can change this. *Hint: look at your file manager in the EveryDream2trainer folder for .ckpt files*.\n",
"base_path = \"/content/EveryDream2trainer/sd_v1-5_vae.ckpt\" #@param {type:\"string\"}\n",
"\n",
"#@markdown The name that you will use when selecting this model in future training sessions.\n",
"save_name = \"SD15\" #@param{type:\"string\"}\n",
"\n",
"#@markdown If you are using Gdrive, this will save the converted model to your Gdrive for future use so you can skip downloading and converting the model.\n",
"cache_to_gdrive = True #@param{type:\"boolean\"}\n",
"\n",
"if cache_to_gdrive:\n",
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
"\n",
"# Per-model conversion settings. Defaults match SD1.x; the branches below override them.\n",
"img_size = 512\n",
"# NOTE(review): upscale_attention is set here but the command below hard-codes --upcast_attn False; confirm the intended value before wiring it through.\n",
"upscale_attention = False\n",
"# Only the SD2 768 models (converted with v2-inference-v.yaml) are v-prediction models; the others predict epsilon.\n",
"prediction_type = \"epsilon\"\n",
"if model_type == \"SD1x\":\n",
" inference_yaml = \"v1-inference.yaml\"\n",
"elif model_type == \"SD2_512_base\":\n",
" upscale_attention = True\n",
" inference_yaml = \"v2-inference.yaml\"\n",
"elif model_type == \"SD21\":\n",
" upscale_attention = True\n",
" inference_yaml = \"v2-inference-v.yaml\"\n",
" img_size = 768\n",
" prediction_type = \"v_prediction\"\n",
"\n",
"print(base_path)\n",
"print(inference_yaml)\n",
"\n",
"!python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
"--original_config_file {inference_yaml} \\\n",
"--image_size {img_size} \\\n",
"--checkpoint_path {base_path} \\\n",
"--prediction_type {prediction_type} \\\n",
"--upcast_attn False \\\n",
"--dump_path {save_name}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "bLpcvpGJB4Gu"
},
"outputs": [],
"source": [
"#@title Pick your base model from a diffusers model saved to your Gdrive (converted above)\n",
"\n",
"#@markdown Do not skip this cell.\n",
"\n",
"#@markdown * If you have previously saved a diffusers model on your drive, you can select it here\n",
"\n",
"#@markdown ex. */content/drive/MyDrive/everydreamlogs/myproject_202208/ckpts/interrupted-gs023*\n",
"\n",
"#@markdown The default for SD1.5 converted above would be */content/drive/MyDrive/everydreamlogs/ckpt/SD15*\n",
"Resume_Model = \"/content/drive/MyDrive/everydreamlogs/ckpt/SD15\" #@param{type:\"string\"} \n",
"save_name = Resume_Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JXVu-W2lCjwX"
},
"source": [
"For a more in-depth explanation of each of these parameters, check out /content/EveryDream2trainer/doc.\n",
"\n",
"\n",
"After you've tried a few models you will find /content/EveryDream2trainer/doc/ATWEAKING.md to be extremely helpful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "j9pEI69WXS9w"
},
"outputs": [],
"source": [
"#@title \n",
"#@markdown # Run Everydream 2\n",
"#@markdown If you want to use a .json config or upload your own, skip this cell and run the cell below instead\n",
"\n",
"#@markdown * Save logs and output ckpts to Gdrive (strongly suggested)\n",
"Save_to_Gdrive = True #@param{type:\"boolean\"}\n",
"#@markdown * Use resume to continue training you just ran, will also find latest diffusers log in your Gdrive to continue.\n",
"resume = False #@param{type:\"boolean\"}\n",
"#@markdown * Checkpointing saves VRAM to allow larger batch sizes. It causes a minor slowdown at a single batch size but can allow room for a higher training resolution (suggested on Colab Free tier, turn off for A100)\n",
"Gradient_checkpointing = True #@param{type:\"boolean\"}\n",
"Disable_Xformers = False\n",
"#@markdown * Tag shuffling, mainly for booru training. Best to just read this if interested in shuffling tags /content/EveryDream2trainer/doc/SHUFFLING_TAGS.md\n",
"shuffle_tags = False #@param{type:\"boolean\"}\n",
"#@markdown * You can turn off the text encoder training (generally not suggested)\n",
"Disable_text_Encoder= False #@param{type:\"boolean\"}\n",
"#@markdown * Name your project so you can find it in your logs\n",
"Project_Name = \"my_project\" #@param{type: 'string'}\n",
"\n",
"#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data and not wreck the model. \n",
"#@markdown Leave this default unless you are very comfortable with training and know what you are doing.\n",
"\n",
"Learning_Rate = 1e-6 #@param{type: 'number'}\n",
"\n",
"#@markdown * A learning rate scheduler can change your learning rate as training progresses.\n",
"\n",
"#@markdown I recommend sticking with constant until you are comfortable with general training. \n",
"\n",
"Schedule = \"constant\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
"\n",
"#@markdown * Resolution to train at (recommend 512). Higher resolution will require lower batch size (below).\n",
"Resolution = 512 #@param {type:\"slider\", min:256, max:768, step:64}\n",
"\n",
"#@markdown * Batch size is also another \"hyperparameter\" of itself and there are tradeoffs. It may not always be best to use the highest batch size possible. One of the primary reasons to change it is if you get \"CUDA out of memory\" errors where lowering the value may help.\n",
"\n",
"#@markdown * Batch size impacts VRAM use. 4 should work on SD1.x models and 3 for SD2.x models at 512 resolution. Lower this if you get CUDA out of memory errors.\n",
"\n",
"Batch_Size = 4 #@param{type: 'number'}\n",
"\n",
"#@markdown * Gradient accumulation is sort of like a virtual batch size increase; use this to increase batch size without increasing VRAM usage\n",
"#@markdown * Increasing this will not have much impact on VRAM use.\n",
"#@markdown * In Colab free tier you can expect the fastest performance from a batch of 4 and a gradient step of 2, giving us a total batch size of 8 at 512 resolution \n",
"#@markdown * Due to bucketing you may need to decrease batch size to 3\n",
"#@markdown * Remember more gradient accumulation (or batch size) doesn't automatically mean better\n",
"\n",
"Gradient_steps = 1 #@param{type:\"slider\", min:1, max:10, step:1}\n",
"\n",
"#@markdown * Location on your Gdrive where your training images are.\n",
"Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
"dataset = Dataset_Location\n",
"# `save_name` is set by the \"Pick your base model\" cell above, which must be run first.\n",
"model = save_name\n",
"\n",
"#@markdown * Max Epochs to train for, this defines how many total times all your training data is used.\n",
"\n",
"Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:5}\n",
"\n",
"#@markdown * How often to save checkpoints.\n",
"Save_every_N_epoch = 20 #@param{type:\"integer\"}\n",
"\n",
"#@markdown * Test sample generation steps, how often to generate samples during training.\n",
"\n",
"#@markdown You can set your own sample prompts by adding them, one line at a time, to `/content/EveryDream2trainer/sample_prompts.txt`. If left empty, it will use the captions from your training images.\n",
"\n",
"#@markdown Use the steps_between_samples to set how often the samples are generated.\n",
"Steps_between_samples = 300 #@param{type:\"integer\"}\n",
"\n",
"#@markdown * That's it! Run the cell!\n",
"\n",
"# Each optional feature becomes a CLI fragment that stays empty when the feature is off.\n",
"Drive=\"\"\n",
"if Save_to_Gdrive:\n",
" Drive = \"--logdir /content/drive/MyDrive/everydreamlogs --save_ckpt_dir /content/drive/MyDrive/everydreamlogs/ckpt\"\n",
"\n",
"# The slider allows 0; train for at least one epoch in that case.\n",
"if Max_Epochs==0:\n",
" Max_Epochs=1\n",
"\n",
"if resume:\n",
" model = \"findlast\"\n",
"\n",
"Gradient = \"\"\n",
"if Gradient_checkpointing:\n",
" Gradient = \"--gradient_checkpointing \"\n",
"# `s` is the nvidia-smi output captured in the install cell; checkpointing is dropped on A100.\n",
"if \"A100\" in s:\n",
" Gradient = \"\"\n",
"\n",
"DX = \"\" \n",
"if Disable_Xformers:\n",
" DX = \"--disable_xformers \"\n",
"\n",
"# Tag shuffling is passed only through this fragment so the checkbox above actually controls it.\n",
"shuffle = \"\"\n",
"if shuffle_tags:\n",
" shuffle = \"--shuffle_tags \"\n",
"\n",
"textencode = \"\"\n",
"if Disable_text_Encoder:\n",
" textencode = \"--disable_textenc_training Train_text \"\n",
"\n",
"!python train_colab.py --resume_ckpt \"$model\" \\\n",
" $textencode \\\n",
" $Gradient \\\n",
" $shuffle \\\n",
" $Drive \\\n",
" $DX \\\n",
" --amp \\\n",
" --batch_size $Batch_Size \\\n",
" --grad_accum $Gradient_steps \\\n",
" --cond_dropout 0.00 \\\n",
" --data_root \"$dataset\" \\\n",
" --flip_p 0.00 \\\n",
" --lr $Learning_Rate \\\n",
" --lr_decay_steps 0 \\\n",
" --lr_scheduler \"$Schedule\" \\\n",
" --lr_warmup_steps 0 \\\n",
" --max_epochs $Max_Epochs \\\n",
" --project_name \"$Project_Name\" \\\n",
" --resolution $Resolution \\\n",
" --sample_prompts \"sample_prompts.txt\" \\\n",
" --sample_steps $Steps_between_samples \\\n",
" --save_every_n_epoch $Save_every_N_epoch \\\n",
" --seed 555 \\\n",
" --useadam8bit \\\n",
" --notebook\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Iuoa_1B9jRGU"
},
"outputs": [],
"source": [
"#@title Alternate startup script\n",
"#@markdown * Edit train.json to setup your parameters\n",
"#@markdown * Edit chain0.json to make use of chaining\n",
"#@markdown * Make sure to check each configuration; you will need 1 JSON per chain length (3 are provided)\n",
"\n",
"\n",
"%cd /content/EveryDream2trainer\n",
"Chain_Length=0 #@param{type:\"integer\"}\n",
"l = Chain_Length \n",
"I=0 #repeat counter, selects chain0.json, chain1.json, ...\n",
"# Always run at least once, even when Chain_Length is left at 0 or empty.\n",
"if l is None or l == 0:\n",
" l=1\n",
"while l > 0:\n",
" !python train_colab.py --config chain{I}.json\n",
" l -= 1\n",
" # Advance to the next chain config (the original `I =+ 1` reset I to 1 on every pass).\n",
" I += 1"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)]"
},
"vscode": {
"interpreter": {
"hash": "e602395b73d27e246c3f66de86a1ed4dc1e5a85e8356fd1a2f027b9d2f1f8162"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -13,9 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
import math
import os
import logging
import copy
import yaml
from PIL import Image
@ -46,7 +48,7 @@ class DataLoaderMultiAspect():
self.log_folder = log_folder
self.seed = seed
self.batch_size = batch_size
self.runts = []
self.has_scanned = False
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
@ -56,135 +58,71 @@ class DataLoaderMultiAspect():
self.__recurse_data_root(self=self, recurse_root=data_root)
random.Random(seed).shuffle(self.image_paths)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=batch_size, debug_level=debug_level)
self.prepared_train_data = self.__prescan_images(self.image_paths, flip_p)
print(f"DLMA Loaded {len(self.prepared_train_data)} images")
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
def shuffle(self):
self.runts = []
self.seed = self.seed + 1
random.Random(self.seed).shuffle(self.prepared_train_data)
self.image_caption_pairs = self.__bucketize_images(self.prepared_train_data, batch_size=self.batch_size, debug_level=0)
def unzip_all(self, path):
try:
for root, dirs, files in os.walk(path):
for file in files:
if file.endswith('.zip'):
logging.info(f"Unzipping {file}")
with zipfile.ZipFile(path, 'r') as zip_ref:
zip_ref.extractall(path)
except Exception as e:
logging.error(f"Error unzipping files {e}")
def get_all_images(self):
return self.image_caption_pairs
@staticmethod
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption_text = caption_file.read()
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
caption = fallback_caption
pass
return caption
@staticmethod
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
with open(file_path, "r") as stream:
try:
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
return fallback_caption
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
def __pick_multiplied_set(self, randomizer):
"""
Splits a string by "," into the main prompt and additional tags with equal weights
Deals with multiply.txt whole and fractional numbers
"""
split_caption = caption_string.split(",")
main_prompt = split_caption.pop(0).strip()
tags = []
for tag in split_caption:
tags.append(tag.strip())
#print(f"Picking multiplied set from {len(self.prepared_train_data)}")
data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
epoch_size = len(self.prepared_train_data)
picked_images = []
return ImageCaption(main_prompt, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
# add by whole number part first and decrement multiplier in copy
for iti in data_copy:
#print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
while iti.multiplier >= 1.0:
picked_images.append(iti)
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
iti.multiplier -= 1.0
def __prescan_images(self, image_paths: list, flip_p=0.0):
remaining = epoch_size - len(picked_images)
assert remaining >= 0, "Something went wrong with the multiplier calculation"
#print(f"Remaining to fill epoch after whole number adds: {remaining}")
#print(f"Remaining in data copy: {len(data_copy)}")
# add by renaming fractional numbers by random chance
while remaining > 0:
for iti in data_copy:
if randomizer.uniform(0.0, 1.0) < iti.multiplier:
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
picked_images.append(iti)
remaining -= 1
data_copy.remove(iti)
if remaining <= 0:
break
del data_copy
return picked_images
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
"""
Create ImageTrainItem objects with metadata for hydration later
returns the current list of images including their captions in a randomized order,
sorted into buckets with same sized images
if dropout_fraction < 1.0, only a subset of the images will be returned
if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
:param dropout_fraction: must be between 0.0 and 1.0.
:return: randomized list of (image, caption) pairs, sorted into same sized buckets
"""
decorated_image_train_items = []
for pathname in tqdm.tqdm(image_paths):
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
self.seed += 1
randomizer = random.Random(self.seed)
file_path_without_ext = os.path.splitext(pathname)[0]
yaml_file_path = file_path_without_ext + ".yaml"
txt_file_path = file_path_without_ext + ".txt"
caption_file_path = file_path_without_ext + ".caption"
if dropout_fraction < 1.0:
picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
else:
picked_images = self.__pick_multiplied_set(randomizer)
if os.path.exists(yaml_file_path):
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
elif os.path.exists(txt_file_path):
caption = self.__read_caption_from_file(txt_file_path, caption)
elif os.path.exists(caption_file_path):
caption = self.__read_caption_from_file(caption_file_path, caption)
randomizer.shuffle(picked_images)
try:
image = Image.open(pathname)
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
decorated_image_train_items.append(image_train_item)
except Exception as e:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {e}")
pass
return decorated_image_train_items
def __bucketize_images(self, prepared_train_data: list, batch_size=1, debug_level=0):
"""
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
"""
# TODO: this is not terribly efficient but at least linear time
buckets = {}
for image_caption_pair in prepared_train_data:
batch_size = self.batch_size
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh
@ -215,27 +153,230 @@ class DataLoaderMultiAspect():
return image_caption_pairs
@staticmethod
def __recurse_data_root(self, recurse_root):
multiply = 1
multiply_path = os.path.join(recurse_root, "multiply.txt")
if os.path.exists(multiply_path):
def unzip_all(path):
try:
for root, dirs, files in os.walk(path):
for file in files:
if file.endswith('.zip'):
logging.info(f"Unzipping {file}")
with zipfile.ZipFile(path, 'r') as zip_ref:
zip_ref.extractall(path)
except Exception as e:
logging.error(f"Error unzipping files {e}")
def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
rating_overall_sum: float = 0.0
ratings_summed: list[float] = []
for image in self.prepared_train_data:
rating_overall_sum += image.caption.rating()
ratings_summed.append(rating_overall_sum)
return rating_overall_sum, ratings_summed
@staticmethod
def __read_caption_from_file(file_path, fallback_caption: ImageCaption) -> ImageCaption:
try:
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption_text = caption_file.read()
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_text)
except:
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
caption = fallback_caption
pass
return caption
@staticmethod
def __read_caption_from_yaml(file_path: str, fallback_caption: ImageCaption) -> ImageCaption:
with open(file_path, "r") as stream:
try:
with open(multiply_path, encoding='utf-8', mode='r') as f:
multiply = int(float(f.read().strip()))
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
file_content = yaml.safe_load(stream)
main_prompt = file_content.get("main_prompt", "")
rating = file_content.get("rating", 1.0)
unparsed_tags = file_content.get("tags", [])
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
tags = []
tag_weights = []
last_weight = None
weights_differ = False
for unparsed_tag in unparsed_tags:
tag = unparsed_tag.get("tag", "").strip()
if len(tag) == 0:
continue
tags.append(tag)
tag_weight = unparsed_tag.get("weight", 1.0)
tag_weights.append(tag_weight)
if last_weight is not None and weights_differ is False:
weights_differ = last_weight != tag_weight
last_weight = tag_weight
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
except:
logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
return fallback_caption
@staticmethod
def __split_caption_into_tags(caption_string: str) -> ImageCaption:
"""
Splits a string by "," into the main prompt and additional tags with equal weights
"""
split_caption = caption_string.split(",")
main_prompt = split_caption.pop(0).strip()
tags = []
for tag in split_caption:
tags.append(tag.strip())
return ImageCaption(main_prompt, 1.0, tags, [1.0] * len(tags), DEFAULT_MAX_CAPTION_LENGTH, False)
def __prescan_images(self, image_paths: list, flip_p=0.0) -> list[ImageTrainItem]:
"""
Create ImageTrainItem objects with metadata for hydration later
"""
decorated_image_train_items = []
if not self.has_scanned:
undersized_images = []
multipliers = {}
skip_folders = []
randomizer = random.Random(self.seed)
for pathname in tqdm.tqdm(image_paths):
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
file_path_without_ext = os.path.splitext(pathname)[0]
yaml_file_path = file_path_without_ext + ".yaml"
txt_file_path = file_path_without_ext + ".txt"
caption_file_path = file_path_without_ext + ".caption"
current_dir = os.path.dirname(pathname)
try:
if current_dir not in multipliers:
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
#print(current_dir, multiply_txt_path)
if os.path.exists(multiply_txt_path):
with open(multiply_txt_path, 'r') as f:
val = float(f.read().strip())
multipliers[current_dir] = val
logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}")
else:
skip_folders.append(current_dir)
multipliers[current_dir] = 1.0
except Exception as e:
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
skip_folders.append(current_dir)
multipliers[current_dir] = 1.0
if os.path.exists(yaml_file_path):
caption = self.__read_caption_from_yaml(yaml_file_path, caption)
elif os.path.exists(txt_file_path):
caption = self.__read_caption_from_file(txt_file_path, caption)
elif os.path.exists(caption_file_path):
caption = self.__read_caption_from_file(caption_file_path, caption)
try:
image = Image.open(pathname)
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
if not self.has_scanned:
if width * height < target_wh[0] * target_wh[1]:
undersized_images.append(f" {pathname}, size: {width},{height}, target size: {target_wh}")
image_train_item = ImageTrainItem(image=None, # image loaded at runtime to apply jitter
caption=caption,
target_wh=target_wh,
pathname=pathname,
flip_p=flip_p,
multiplier=multipliers[current_dir],
)
cur_file_multiplier = multipliers[current_dir]
while cur_file_multiplier >= 1.0:
decorated_image_train_items.append(image_train_item)
cur_file_multiplier -= 1
if cur_file_multiplier > 0:
if randomizer.random() < cur_file_multiplier:
decorated_image_train_items.append(image_train_item)
except Exception as e:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {e}")
pass
if not self.has_scanned:
self.has_scanned = True
if len(undersized_images) > 0:
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
with open(underized_log_path, "w") as undersized_images_file:
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
for undersized_image in undersized_images:
undersized_images_file.write(f"{undersized_image}\n")
print (f" * DLMA: {len(decorated_image_train_items)} images loaded from {len(image_paths)} files")
return decorated_image_train_items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
Picks a random subset of all images
- The size of the subset is limited by dropout_faction
- The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance
:param dropout_fraction: must be between 0.0 and 1.0
:param picker: seeded random picker
:return: list of picked ImageTrainItem
"""
prepared_train_data = self.prepared_train_data.copy()
ratings_summed = self.ratings_summed.copy()
rating_overall_sum = self.rating_overall_sum
num_images = len(prepared_train_data)
num_images_to_pick = math.ceil(num_images * dropout_fraction)
num_images_to_pick = max(min(num_images_to_pick, num_images), 0)
# logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}")
picked_images: list[ImageTrainItem] = []
while num_images_to_pick > len(picked_images):
# find random sample in dataset
point = picker.uniform(0.0, rating_overall_sum)
pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) -1 )
# pick random sample
picked_image = prepared_train_data[pos]
picked_images.append(picked_image)
# kick picked item out of data set to not pick it again
rating_overall_sum = max(rating_overall_sum - picked_image.caption.rating(), 0.0)
ratings_summed.pop(pos)
prepared_train_data.pop(pos)
return picked_images
@staticmethod
def __recurse_data_root(self, recurse_root):
for f in os.listdir(recurse_root):
current = os.path.join(recurse_root, f)
if os.path.isfile(current):
ext = os.path.splitext(f)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
# add image multiplyrepeats number of times
for _ in range(multiply):
self.image_paths.append(current)
self.image_paths.append(current)
sub_dirs = []

View File

@ -51,6 +51,8 @@ class EveryDreamBatch(Dataset):
retain_contrast=False,
write_schedule=False,
shuffle_tags=False,
rated_dataset=False,
rated_dataset_dropout_target=0.5
):
self.data_root = data_root
self.batch_size = batch_size
@ -66,6 +68,8 @@ class EveryDreamBatch(Dataset):
self.write_schedule = write_schedule
self.shuffle_tags = shuffle_tags
self.seed = seed
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
if seed == -1:
seed = random.randint(0, 99999)
@ -80,18 +84,16 @@ class EveryDreamBatch(Dataset):
resolution=resolution,
log_folder=self.log_folder,
)
self.image_train_items = dls.shared_dataloader.get_all_images()
self.num_images = len(self.image_train_items)
self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
self._length = self.num_images
num_images = len(self.image_train_items)
logging.info(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}")
logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
if self.write_schedule:
self.write_batch_schedule(0)
self.__write_batch_schedule(0)
def write_batch_schedule(self, epoch_n):
def __write_batch_schedule(self, epoch_n):
with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
for i in range(len(self.image_train_items)):
try:
@ -102,19 +104,23 @@ class EveryDreamBatch(Dataset):
def get_runts():
return dls.shared_dataloader.runts
def shuffle(self, epoch_n):
def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1
if dls.shared_dataloader:
dls.shared_dataloader.shuffle()
self.image_train_items = dls.shared_dataloader.get_all_images()
if self.rated_dataset:
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
else:
dropout_fraction = 1.0
self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
else:
raise Exception("No dataloader singleton to shuffle")
if self.write_schedule:
self.write_batch_schedule(epoch_n)
self.__write_batch_schedule(epoch_n + 1)
def __len__(self):
return self._length
return len(self.image_train_items)
def __getitem__(self, i):
example = {}

View File

@ -31,7 +31,7 @@ class ImageCaption:
Represents the various parts of an image caption
"""
def __init__(self, main_prompt: str, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
"""
:param main_prompt: The part of the caption which should always be included
:param tags: list of tags to pick from to fill the caption
@ -40,6 +40,7 @@ class ImageCaption:
:param use_weights: if ture, weights are considered when shuffling tags
"""
self.__main_prompt = main_prompt
self.__rating = rating
self.__tags = tags
self.__tag_weights = tag_weights
self.__max_target_length = max_target_length
@ -50,6 +51,9 @@ class ImageCaption:
if use_weights and len(tag_weights) > len(tags):
self.__tag_weights = tag_weights[:len(tags)]
def rating(self) -> float:
    """
    :return: the rating value that was passed to the constructor for this caption
    """
    return self.__rating
def get_shuffled_caption(self, seed: int) -> str:
"""
returns the caption a string with a random selection of the tags in random order
@ -97,22 +101,23 @@ class ImageCaption:
return ", ".join(tags)
class ImageTrainItem():
class ImageTrainItem:
"""
image: PIL.Image
identifier: caption,
target_aspect: (width, height),
pathname: path to image file
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0):
def __init__(self, image: PIL.Image, caption: ImageCaption, target_wh: list, pathname: str, flip_p=0.0, multiplier: float=1.0):
self.caption = caption
self.target_wh = target_wh
self.pathname = pathname
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.cropped_img = None
self.runt_size = 0
self.multiplier = multiplier
if image is None:
self.image = []

View File

@ -108,15 +108,17 @@ Some experimentation shows if you already have batch size in the 6-8 range than
## Gradient checkpointing
While traditionally used to reduce VRAM for smaller GPUs, gradient checkpointing can offer a higher batch size and/or higher resolution within whatever VRAM you have, so it may be useful even on a 24GB+ GPU.
This is mostly useful to reduce VRAM for smaller GPUs, and together with AdamW 8 bit and AMP mode can enable <12GB GPU training.
Gradient checkpointing can also offer a higher batch size and/or higher resolution within whatever VRAM you have, so it may be useful even on a 24GB+ GPU if you specifically want to run a very large batch size. The other option is using gradient accumulation instead.
--gradient_checkpointing ^
This drastically reduces VRAM use (by many GB) and will allow a considerably larger batch size or resolution, for example, 13-14 instead of 7-8 on a 24GB card using 512 training resolution.
While gradient checkpointing reduces performance, the ability to run a higher batch size brings performance back fairly close to without it.
While gradient checkpointing reduces performance, the ability to run a higher batch size brings performance back fairly close to without it. My personal tests show a 25% performance hit simply turning on gradient checkpointing on a 3090 (batch 7, 512), but almost all of that is made up by the ability to use a larger batch size (up to 14). You may NOT want to use a batch size as large as 13-14, or you may find you need to tweak learning rate all over again to find the right balance.
You may NOT want to use a batch size as large as 13-14, or you may find you need to tweak learning rate all over again to find the right balance. Generally I would not turn it on for a 24GB GPU training at <640 resolution.
This probably IS a good idea for training at higher resolutions. Balancing this toggle, resolution, batch_size, and grad_accum will take some experimentation, but you might try using this with 768+ resolutions, grad_accum 3-4, and then as high of a batch size as you can get to work without crashing, while adjusting LR with respect to your (batch_size * grad_accum) value.
This probably IS a good idea for training at higher resolutions and allows >768 training on 24GB GPUs. Balancing this toggle, resolution, and batch_size will take a few quick experiments to see what you can run safely.
## Flip_p

View File

@ -16,6 +16,17 @@ You may wish to consider adding "sd1" or "sd2v" or similar to remember what the
--project_name "jets_sd21768v" ^
## Stuff you probably want on
--amp
Enables automatic mixed precision. Greatly improves training speed and can help a bit with VRAM use. [Torch](https://pytorch.org/docs/stable/amp.html) will automatically use FP16 precision for specific model components where FP16 is sufficient precision, and FP32 otherwise. This also enables xformers to work with the SD1.x attention head schema.
--useadam8bit
Uses [Tim Dettmers' reduced precision AdamW 8 Bit optimizer](https://github.com/TimDettmers/bitsandbytes). This seems to have no noticeable impact on quality but is considerably faster and more VRAM efficient. See more below in AdamW vs AdamW 8bit.
## Epochs
EveryDream 2.0 has done away with repeats and instead you should set your max_epochs. Changing epochs has the same effect as changing repeats in DreamBooth or EveryDream1. For example, if you had 50 repeats and 5 epochs, you would now set max_epochs to 250 (50x5=250). This is a bit more intuitive as there is no more double meaning for epochs and repeats.
@ -28,6 +39,16 @@ With more training data for your subjects and concepts, you can slowly scale thi
With less training data, this value should be higher, because more repetition on the images is needed to learn.
## Resolution
The resolution for training. All buckets for multiaspect will be based on the total pixel count of your resolution squared.
--resolution 768
Currently supported resolutions can be printed by running the trainer without any arguments.
python train.py
## Save interval for checkpoints
While EveryDream 1.0 saved a checkpoint every epoch, this is no longer the case as it would produce too many files as "repeats" are removed in favor of just using epochs instead. To balance the fact EveryDream users are sometimes training small datasets and sometimes huge datasets, you can now set the interval at which checkpoints are saved. The default is 30 minutes, but you can change it to whatever you want.
@ -76,16 +97,6 @@ At this time, ED2.0 supports constant or cosine scheduler.
The constant scheduler is the default and keeps your LR set to the value you set in the command line. That's really it for constant! I recommend sticking with it until you are comfortable with general training. More info in the [Advanced Tweaking](ATWEAKING.md) document.
## AdamW vs AdamW 8bit
The AdamW optimizer is the default and what was used by EveryDream 1.0. It's a good optimizer for Stable Diffusion and appears to be what was used to train SD itself.
AdamW 8bit is quite a bit faster and uses less VRAM while still having the same basic behavior. I currently **recommend** using it for most cases as it seems worth a potential slight reduction in quality for a *significant speed boost and lower VRAM cost*.
--useadam8bit ^
This may become a default in the future, and replaced with an option to use standard AdamW instead. For now, it's an option, *but I recommend always using it.*
## Sampling
You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt. Or you can point to another file with --sample_prompts.

BIN
doc/runpodinstance.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

View File

@ -1,5 +1,5 @@
{
"amp": false,
"amp": true,
"batch_size": 10,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
@ -8,7 +8,6 @@
"data_root": "X:\\my_project_data\\project_abc",
"disable_textenc_training": false,
"disable_xformers": false,
"ed1_mode": true,
"flip_p": 0.0,
"gpuid": 0,
"gradient_checkpointing": true,
@ -16,11 +15,12 @@
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 3.5e-06,
"lr": 1.5e-06,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 30,
"notebook": false,
"project_name": "project_abc",
"resolution": 512,
"resume_ckpt": "sd_v1-5_vae",
@ -34,5 +34,7 @@
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": false
}
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50
}

292
train.py
View File

@ -1,5 +1,5 @@
"""
Copyright [2022] Victor C Hall
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
import math
@ -23,9 +24,10 @@ import time
import gc
import random
import traceback
import shutil
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from colorama import Fore, Style, Cursor
@ -47,30 +49,14 @@ from accelerate.utils import set_seed
import wandb
from torch.utils.tensorboard import SummaryWriter
import keyboard
from data.every_dream import EveryDreamBatch
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU
_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9
# def is_notebook() -> bool:
# try:
# from IPython import get_ipython
# shell = get_ipython().__class__.__name__
# if shell == 'ZMQInteractiveShell':
# return True # Jupyter notebook or qtconsole
# elif shell == 'TerminalInteractiveShell':
# return False # Terminal running IPython
# else:
# return False # Other type (?)
# except NameError:
# return False # Probably standard Python interpreter
def clean_filename(filename):
"""
removes all non-alphanumeric characters from a string so it is safe to use as a filename
@ -90,19 +76,19 @@ def convert_to_hf(ckpt_path):
import utils.convert_original_stable_diffusion_to_diffusers as convert
convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
except:
logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.")
exit()
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
return hf_cache
is_sd1attn, yaml = patch_unet(hf_cache)
return hf_cache, is_sd1attn, yaml
elif os.path.isdir(hf_cache):
patch_unet(hf_cache, args.ed1_mode, args.lowvram)
return hf_cache
is_sd1attn, yaml = patch_unet(hf_cache)
return hf_cache, is_sd1attn, yaml
else:
patch_unet(ckpt_path, args.ed1_mode, args.lowvram)
return ckpt_path
is_sd1attn, yaml = patch_unet(ckpt_path)
return ckpt_path, is_sd1attn, yaml
def setup_local_logger(args):
"""
@ -225,15 +211,14 @@ def setup_args(args):
Sets defaults for missing args (possible if missing from json config)
Forces some args to be set based on others for compatibility reasons
"""
if args.disable_unet_training and args.disable_textenc_training:
raise ValueError("Both unet and textenc are disabled, nothing to train")
if args.resume_ckpt == "findlast":
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and not args.disable_xformers:
args.disable_xformers = True
logging.info(" ED1 mode: Overiding disable_xformers to True")
if args.lowvram:
set_args_12gb(args)
@ -241,7 +226,7 @@ def setup_args(args):
args.shuffle_tags = False
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20
@ -251,7 +236,7 @@ def setup_args(args):
if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
args.save_every_n_epochs = _VERY_LARGE_NUMBER
if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
@ -272,8 +257,35 @@ def setup_args(args):
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
os.makedirs(args.save_ckpt_dir)
if args.rated_dataset:
    # Clamp the configured percentage into the valid 0-100 range before anything consumes it.
    args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
    # Log exactly once: logging.info() returns None, so wrapping it in a second
    # logging.info() call would emit an extra "None" record after the real message.
    logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")
return args
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
    """
    Adjusts the grad scaler's growth factor, backoff factor and growth interval
    at fixed training milestones.

    A stage applies when global_step equals its step milestone exactly, or when
    step == 1 and epoch has reached its epoch milestone. Stages are evaluated in
    order and every matching stage is applied, so the last matching stage
    determines the final settings.

    :param scaler: scaler to reconfigure in place
    :param global_step: number of steps trained so far across all epochs
    :param epoch: current epoch index
    :param step: step index within the current epoch
    """
    # (global_step milestone, epoch milestone, growth factor, growth interval)
    stages = (
        (250, 4, 1.8, 50),
        (500, 8, 1.6, 50),
        (1000, 10, 1.3, 100),
        (3000, 20, 1.15, 100),
    )
    at_step_one = step == 1
    for step_milestone, epoch_milestone, factor, interval in stages:
        reached_by_epoch = at_step_one and epoch >= epoch_milestone
        if global_step != step_milestone and not reached_by_epoch:
            continue
        scaler.set_growth_factor(factor)
        # backoff is always the reciprocal of the growth factor
        scaler.set_backoff_factor(1/factor)
        scaler.set_growth_interval(interval)
def main(args):
"""
Main entry point
@ -284,9 +296,10 @@ def main(args):
if args.notebook:
from tqdm.notebook import tqdm
else:
from tqdm.auto import tqdm
from tqdm.auto import tqdm
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
logging.info(f" Seed: {seed}")
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
@ -294,12 +307,12 @@ def main(args):
torch.backends.cudnn.benchmark = True
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
"""
Save the model to disk
"""
@ -320,17 +333,24 @@ def main(args):
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag:
# logging.info(f" Saving optimizer state to {save_path}")
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
@ -391,8 +411,10 @@ def main(args):
generates samples at different cfg scales and saves them to disk
"""
logging.info(f"Generating samples gs:{gs}, for {prompts}")
pipe.set_progress_bar_config(disable=True)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
gen = torch.Generator(device="cuda").manual_seed(seed)
gen = torch.Generator(device=device).manual_seed(seed)
i = 0
for prompt in prompts:
@ -437,11 +459,11 @@ def main(args):
try:
# first try to download from HF
model_root_folder = try_download_model_from_hf(repo_id=args.resume_ckpt,
subfolder=args.hf_repo_subfolder)
model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
subfolder=args.hf_repo_subfolder)
# if that doesn't work, try a local folder
if model_root_folder is None:
model_root_folder = convert_to_hf(args.resume_ckpt)
model_root_folder, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
@ -456,39 +478,46 @@ def main(args):
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if args.ed1_mode and not args.lowvram:
unet.set_attention_slice(4)
if not args.disable_xformers and is_xformers_available():
try:
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
except Exception as ex:
logging.warning("failed to load xformers, continuing without it")
pass
if not args.disable_xformers:
if (args.amp and is_sd1attn) or (not is_sd1attn):
try:
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
except Exception as ex:
logging.warning("failed to load xformers, using attention slicing instead")
unet.set_attention_slice("auto")
pass
else:
logging.info("xformers not available or disabled")
logging.info("xformers disabled, using attention slicing instead")
unet.set_attention_slice("auto")
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training and args.amp:
text_encoder = text_encoder.to(device, dtype=torch.float16)
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
betas = (0.9, 0.999)
epsilon = 1e-8 if not args.amp else 1e-8
epsilon = 1e-8
if args.amp:
epsilon = 2e-8
weight_decay = 0.01
if args.useadam8bit:
import bitsandbytes as bnb
@ -507,6 +536,8 @@ def main(args):
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon)
train_batch = EveryDreamBatch(
data_root=args.data_root,
flip_p=args.flip_p,
@ -519,6 +550,8 @@ def main(args):
log_folder=log_folder,
write_schedule=args.write_schedule,
shuffle_tags=args.shuffle_tags,
rated_dataset=args.rated_dataset,
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
)
torch.cuda.benchmark = False
@ -543,11 +576,8 @@ def main(args):
sample_prompts.append(line.strip())
if False: #args.wandb is not None and args.wandb: # not yet supported
log_writer = wandb.init(project="EveryDream2FineTunes",
name=args.project_name,
dir=log_folder,
)
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, )
else:
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
@ -562,14 +592,13 @@ def main(args):
log_args(log_writer, args)
"""
Train the model
"""
print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}")
print(f" (C) 2022 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
print(f" (C) 2022-2023 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
print()
print("** Trainer Starting **")
@ -605,7 +634,6 @@ def main(args):
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
def collate_fn(batch):
"""
Collates batches
@ -635,7 +663,7 @@ def main(args):
collate_fn=collate_fn
)
unet.train()
unet.train() if not args.disable_unet_training else unet.eval()
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
@ -646,15 +674,19 @@ def main(args):
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
#logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
scaler = GradScaler(
enabled=args.amp,
init_scale=2**17.5,
growth_factor=2,
backoff_factor=1.0/2,
growth_interval=25,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
epoch_times = []
global global_step
@ -664,28 +696,38 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
#loss = torch.tensor(0.0, device=device, dtype=torch.float32)
if args.amp:
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
#enabled=False,
enabled=True,
init_scale=1024.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
loss_log_step = []
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
try:
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
# pixel_values = pixel_values.to(unet.device)
# with autocast(enabled=args.amp):
# latents = vae.encode(pixel_values, return_dict=False)
# latents = latents[0].sample() * 0.18215
# noise = torch.randn_like(latents)
# bsz = latents.shape[0]
# timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
# timesteps = timesteps.long()
# noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# cuda_caption = torch.linspace(100,177, steps=77, dtype=int).to(text_encoder.device)
# encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True).last_hidden_state
# with autocast(enabled=args.amp):
# model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# # discard the grads, just want to pin memory
# optimizer.zero_grad(set_to_none=True)
for epoch in range(args.max_epochs):
loss_epoch = []
epoch_start_time = time.time()
steps_pbar.reset()
images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
@ -722,23 +764,22 @@ def main(args):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
#with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
#del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
scaler.scale(loss).backward()
if args.amp:
scaler.scale(loss).backward()
else:
loss.backward()
if args.clip_grad_norm is not None:
if not args.disable_unet_training:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
if not args.disable_textenc_training:
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
@ -752,11 +793,8 @@ def main(args):
param.grad *= grad_scale
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
if args.amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
lr_scheduler.step()
@ -778,6 +816,7 @@ def main(args):
loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
images_per_sec_log_step = []
@ -787,7 +826,7 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device)
@ -809,33 +848,37 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
del batch
global_step += 1
update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
# end of step
steps_pbar.close()
elapsed_epoch_time = (time.time() - epoch_start_time) / 60
epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
epoch_pbar.update(1)
if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch+1)
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
gc.collect()
# end of epoch
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -845,7 +888,7 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@ -858,17 +901,28 @@ def update_old_args(t_args):
Update old args to new args to deal with json config loading and missing args for compatibility
"""
if not hasattr(t_args, "shuffle_tags"):
print(f" Config json is missing 'shuffle_tags'")
print(f" Config json is missing 'shuffle_tags' flag")
t_args.__dict__["shuffle_tags"] = False
if not hasattr(t_args, "save_full_precision"):
print(f" Config json is missing 'save_full_precision'")
print(f" Config json is missing 'save_full_precision' flag")
t_args.__dict__["save_full_precision"] = False
if not hasattr(t_args, "notebook"):
print(f" Config json is missing 'notebook'")
print(f" Config json is missing 'notebook' flag")
t_args.__dict__["notebook"] = False
if not hasattr(t_args, "disable_unet_training"):
print(f" Config json is missing 'disable_unet_training' flag")
t_args.__dict__["disable_unet_training"] = False
if not hasattr(t_args, "rated_dataset"):
print(f" Config json is missing 'rated_dataset' flag")
t_args.__dict__["rated_dataset"] = False
if not hasattr(t_args, "rated_dataset_target_dropout_percent"):
print(f" Config json is missing 'rated_dataset_target_dropout_percent' flag")
t_args.__dict__["rated_dataset_target_dropout_percent"] = 50
if __name__ == "__main__":
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
supported_precisions = ['fp16', 'fp32']
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, _ = argparser.parse_known_args()
@ -879,22 +933,22 @@ if __name__ == "__main__":
t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f))
update_old_args(t_args) # update args to support older configs
print(t_args.__dict__)
print(f" args: \n{t_args.__dict__}")
args = argparser.parse_args(namespace=t_args)
else:
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=False, help="use floating point 16 bit training, experimental, reduces quality")
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--ed1_mode", action="store_true", default=False, help="Disables xformers and reduces attention heads to 8 (SD1.x style)")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
@ -907,6 +961,7 @@ if __name__ == "__main__":
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
@ -914,6 +969,7 @@ if __name__ == "__main__":
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
@ -921,9 +977,9 @@ if __name__ == "__main__":
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
args = argparser.parse_args()
args, _ = argparser.parse_known_args()
main(args)

956
train_colab.py Normal file
View File

@ -0,0 +1,956 @@
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
import math
import signal
import argparse
import logging
import time
import gc
import random
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torchvision.transforms as transforms
from colorama import Fore, Style, Cursor
import numpy as np
import itertools
import torch
import datetime
import json
from PIL import Image, ImageDraw, ImageFont
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler
#from diffusers.models import AttentionBlock
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer
#from accelerate import Accelerator
from accelerate.utils import set_seed
import wandb
from torch.utils.tensorboard import SummaryWriter
import keyboard
from data.every_dream import EveryDreamBatch
from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU
forstepTime = time.time()
_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9
# def is_notebook() -> bool:
# try:
# from IPython import get_ipython
# shell = get_ipython().__class__.__name__
# if shell == 'ZMQInteractiveShell':
# return True # Jupyter notebook or qtconsole
# elif shell == 'TerminalInteractiveShell':
# return False # Terminal running IPython
# else:
# return False # Other type (?)
# except NameError:
# return False # Probably standard Python interpreter
def clean_filename(filename):
    """
    Reduce ``filename`` to characters that are safe in a file name.

    Letters, digits and plain spaces are kept; every other character is dropped.
    Trailing whitespace is stripped from the result.
    """
    def _keep(ch):
        return ch == ' ' or ch.isdigit() or ch.isalpha()

    return "".join(filter(_keep, filename)).rstrip()
def convert_to_hf(ckpt_path):
    """
    Resolve ``ckpt_path`` to a Diffusers-format model location, converting it if needed.

    Three cases are handled:
      * ``ckpt_path`` is a file: it is converted into ``ckpt_cache/<basename>`` unless
        that cache folder already exists, and the cache folder is returned.
      * ``ckpt_path`` is not a file but ``ckpt_cache/<basename>`` is a directory:
        the cached folder is returned.
      * otherwise ``ckpt_path`` itself is passed through unchanged.

    In every case ``patch_unet`` is called on the returned location.

    Returns:
        tuple: ``(model_path, is_sd1attn)`` where ``is_sd1attn`` is the value
        returned by ``utils.patch_unet.patch_unet``.

    Exits the process with status 1 if the conversion fails.
    """
    import shutil
    from utils.patch_unet import patch_unet

    hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))

    if os.path.isfile(ckpt_path):
        if not os.path.exists(hf_cache):
            os.makedirs(hf_cache)
            logging.info(f"Converting {ckpt_path} to Diffusers format")
            try:
                import utils.convert_original_stable_diffusion_to_diffusers as convert
                # Write into hf_cache (basename only) so the converted model lands in the
                # same folder that is checked, patched and returned below, even when
                # ckpt_path contains directory components.
                convert.convert(ckpt_path, hf_cache)
            except Exception as ex:
                # The folder was created by this call; remove it so a later run does not
                # mistake an empty/partial folder for a valid cached conversion.
                shutil.rmtree(hf_cache, ignore_errors=True)
                logging.error(f"Failed to convert {ckpt_path}: {ex}")
                logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
                sys.exit(1)
        else:
            logging.info(f"Found cached checkpoint at {hf_cache}")

        is_sd1attn = patch_unet(hf_cache)
        return hf_cache, is_sd1attn
    elif os.path.isdir(hf_cache):
        is_sd1attn = patch_unet(hf_cache)
        return hf_cache, is_sd1attn
    else:
        is_sd1attn = patch_unet(ckpt_path)
        return ckpt_path, is_sd1attn
def setup_local_logger(args):
    """
    Set up file + console logging for a run and dump the arguments next to the log.

    Creates ``args.logdir`` when it is missing, writes the arguments as indented JSON
    to ``<project_name>-<timestamp>_cfg.json``, points the root logger at
    ``<project_name>-<timestamp>.log`` and mirrors log output to stdout.

    Returns:
        str: the ``YYYYmmdd-HHMMSS`` timestamp used in both file names.
    """
    log_dir = args.logdir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    args_as_json = json.dumps(vars(args), indent=2)
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_prefix = os.path.join(log_dir, f"{args.project_name}-{stamp}")

    with open(f"{run_prefix}_cfg.json", "w") as cfg_file:
        cfg_file.write(args_as_json)

    log_file = f"{run_prefix}.log"
    print(f" logging to {log_file}")

    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    console_handler = logging.StreamHandler(sys.stdout)
    logging.getLogger().addHandler(console_handler)

    return stamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
    """
    Write the optimizer class name and its betas/epsilon settings to the log.
    """
    optimizer_name = optimizer.__class__.__name__
    header = f"{Fore.CYAN} * Optimizer: {optimizer_name} *{Style.RESET_ALL}"
    details = f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}"
    for message in (header, details):
        logging.info(message)
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
    """
    Serialize the optimizer's ``state_dict()`` to ``path`` with ``torch.save``.
    """
    optimizer_state = optimizer.state_dict()
    torch.save(optimizer_state, path)
def load_optimizer(optimizer, path: str):
    """
    Restore ``optimizer`` from a state dict previously written to ``path``.
    """
    saved_state = torch.load(path)
    optimizer.load_state_dict(saved_state)
def get_gpu_memory(nvsmi):
    """
    Query ``nvsmi`` for the framebuffer memory of the first reported GPU.

    Returns:
        tuple: ``(used, total)`` as integers, taken from the ``fb_memory_usage``
        entry of the first element of the query result's ``gpu`` list.
    """
    report = nvsmi.DeviceQuery('memory.used, memory.total')
    fb_usage = report['gpu'][0]['fb_memory_usage']
    return int(fb_usage['used']), int(fb_usage['total'])
def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
    """
    Record current VRAM usage and refresh the epoch progress bar postfix.

    The used-memory figure from ``gpu.get_gpu_memory()`` is written to ``log_writer``
    under ``performance/vram``. The same figure is shown, colour-coded by how full the
    GPU is, in the progress bar postfix together with any extra ``logs`` values, and
    printed along with the seconds elapsed since the module was loaded.
    """
    used_mb, total_mb = gpu.get_gpu_memory()
    log_writer.add_scalar("performance/vram", used_mb, global_step)

    # Highest matching usage band wins; between 50% and 70% no colour is applied.
    high_usage_colors = (
        (0.93, Fore.LIGHTRED_EX),
        (0.85, Fore.LIGHTYELLOW_EX),
        (0.7, Fore.LIGHTGREEN_EX),
    )
    vram_color = next(
        (color for fraction, color in high_usage_colors if used_mb > fraction * total_mb),
        None,
    )
    if vram_color is None:
        vram_color = Fore.LIGHTBLUE_EX if used_mb < 0.5 * total_mb else Style.RESET_ALL

    vram_text = f"{vram_color}{used_mb}/{total_mb} MB{Style.RESET_ALL} gs:{global_step}"

    # **logs is always a dict, so this guard is effectively always taken.
    if logs is not None:
        epoch_pbar.set_postfix(**logs, vram=vram_text)
    print(f"{vram_text} | Elapsed : {time.time() - forstepTime}s")
def set_args_12gb(args):
    """
    Force the low-VRAM ("12GB") settings onto ``args`` in place.

    Gradient checkpointing and the 8-bit Adam optimizer are switched on, batch size
    and gradient accumulation are set to 1, and resolution is capped at 512. Each
    value that actually had to be changed is logged, except gradient accumulation,
    which is reset silently.
    """
    logging.info(" Setting args to 12GB mode")

    # (attribute, "needs overriding?" predicate, forced value, log message)
    overrides = (
        ("gradient_checkpointing", lambda value: not value, True,
         " - Overiding gradient checkpointing to True"),
        ("batch_size", lambda value: value != 1, 1,
         " - Overiding batch size to 1"),
        ("resolution", lambda value: value > 512, 512,
         " - Overiding resolution to 512"),
        ("useadam8bit", lambda value: not value, True,
         " - Overiding adam8bit to True"),
    )
    for attr, needs_override, forced_value, message in overrides:
        if needs_override(getattr(args, attr)):
            logging.info(message)
            setattr(args, attr, forced_value)

    # Gradient accumulation is always reset, without a log line.
    args.grad_accum = 1
def find_last_checkpoint(logdir):
    """
    Locate the most recently written Diffusers checkpoint folder under ``logdir``.

    The tree is walked recursively; every folder containing a ``model_index.json``
    is a candidate and the one whose ``model_index.json`` has the newest
    modification time is returned.

    Raises:
        AssertionError: if no candidate exists, or if the path of the newest
            candidate contains the text "errored".
    """
    marker = "model_index.json"
    candidates = [root for root, _dirs, files in os.walk(logdir) if marker in files]

    last_ckpt = None
    if candidates:
        # max() keeps the first of equally-new candidates, i.e. walk order breaks ties.
        last_ckpt = max(candidates, key=lambda folder: os.path.getmtime(os.path.join(folder, marker)))

    assert last_ckpt, f"Could not find last checkpoint in logdir: {logdir}"
    assert "errored" not in last_ckpt, f"Found last checkpoint: {last_ckpt}, but it was errored, cancelling"

    print(f" {Fore.LIGHTCYAN_EX}Found last checkpoint: {last_ckpt}, resuming{Style.RESET_ALL}")
    return last_ckpt
def setup_args(args):
    """
    Validate and normalize the parsed training arguments in place.

    Sets defaults for values that may be missing (e.g. when loaded from a json
    config) and forces some arguments based on others:

      * rejects disabling both unet and text encoder training,
      * resolves ``resume_ckpt == "findlast"`` to the newest checkpoint in ``logdir``,
      * applies the low-VRAM overrides when ``lowvram`` is set,
      * clamps ``clip_skip`` to 0..4,
      * defaults checkpoint saving to every 20 minutes when neither interval is
        given, and maps unset/non-positive intervals to ``_VERY_LARGE_NUMBER``,
      * scales ``lr`` by ``(batch_size * grad_accum) ** 0.55`` when ``scale_lr`` is set,
      * creates ``save_ckpt_dir`` if it is given and missing,
      * clamps ``rated_dataset_target_dropout_percent`` to 0..100 for rated datasets.

    Returns:
        The same ``args`` object, after modification.

    Raises:
        ValueError: if both unet and text encoder training are disabled.
    """
    if args.disable_unet_training and args.disable_textenc_training:
        raise ValueError("Both unet and textenc are disabled, nothing to train")

    if args.resume_ckpt == "findlast":
        logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
        # find the last checkpoint in the logdir
        args.resume_ckpt = find_last_checkpoint(args.logdir)

    if args.lowvram:
        set_args_12gb(args)

    # Normalizes a missing/None value to an explicit False.
    if not args.shuffle_tags:
        args.shuffle_tags = False

    args.clip_skip = max(min(4, args.clip_skip), 0)

    if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
        logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
        args.ckpt_every_n_minutes = 20

    # An unset or non-positive interval means "never" for that trigger.
    if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
        args.ckpt_every_n_minutes = _VERY_LARGE_NUMBER

    if args.save_every_n_epochs is None or args.save_every_n_epochs < 1:
        args.save_every_n_epochs = _VERY_LARGE_NUMBER

    if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")

    if args.cond_dropout > 0.26:
        logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")

    if args.grad_accum > 1:
        logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}")

    total_batch_size = args.batch_size * args.grad_accum

    if args.scale_lr is not None and args.scale_lr:
        tmp_lr = args.lr
        # Compute the factor once so the logged value is the one actually applied.
        lr_scale = total_batch_size**0.55
        args.lr = args.lr * lr_scale
        logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {lr_scale}, new value: {args.lr}{Style.RESET_ALL}")

    if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
        os.makedirs(args.save_ckpt_dir)

    if args.rated_dataset:
        args.rated_dataset_target_dropout_percent = min(max(args.rated_dataset_target_dropout_percent, 0), 100)
        logging.info(f"{Fore.CYAN} * Activating rated images learning with a target rate of {args.rated_dataset_target_dropout_percent}% {Style.RESET_ALL}")

    return args
def main(args):
"""
Main entry point
"""
log_time = setup_local_logger(args)
args = setup_args(args)
if args.notebook:
from tqdm.notebook import tqdm
else:
from tqdm.auto import tqdm
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
logging.info(f" Seed: {seed}")
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, save_full_precision=False):
    """
    Write the current model to disk in both Diffusers and single-file .ckpt form.

    The components are assembled into a ``StableDiffusionPipeline`` and saved to
    ``save_path``; that folder is then converted to ``<basename(save_path)>.ckpt``
    inside ``save_ckpt_dir`` (or the current directory when it is None). The .ckpt
    is stored at half precision unless ``save_full_precision`` is set.

    Nothing is saved while the global step counter is still None or 0.
    """
    global global_step

    if global_step in (None, 0):
        logging.warning(" No model to save, something likely blew up on startup, not saving")
        return

    logging.info(f" * Saving diffusers model to {save_path}")
    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=None,  # save vram
        requires_safety_checker=None,  # avoid nag
        feature_extractor=None,  # must be none of no safety checker
    )
    StableDiffusionPipeline(**components).save_pretrained(save_path)

    ckpt_dir = save_ckpt_dir if save_ckpt_dir is not None else os.curdir
    sd_ckpt_full = os.path.join(ckpt_dir, f"{os.path.basename(save_path)}.ckpt")

    logging.info(f" * Saving SD model to {sd_ckpt_full}")
    converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=not save_full_precision)

    # Optimizer state is not saved here yet (previously sketched as writing
    # "optimizer.pt" into save_path when a save-optimizer flag is set).
@torch.no_grad()
def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
    """
    Assemble a ``StableDiffusionPipeline`` for inference from the given components.

    The safety checker and feature extractor are left out.
    """
    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=None,  # save vram
        requires_safety_checker=None,  # avoid nag
        feature_extractor=None,  # must be none of no safety checker
    )
    return StableDiffusionPipeline(**components)
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
    """
    Generate a single square sample image and stamp the cfg value onto it.

    Args:
        pipe: inference pipeline used to render the image.
        prompt: text prompt to render.
        cfg: guidance scale passed to the pipeline and printed on the image.
        resolution: used for both height and width of the generated image.
        gen: generator object forwarded to the pipeline as ``generator``.

    Returns:
        The generated image with a white "cfg:<value>" label box drawn in its
        bottom-right corner. The image is not written to disk here.
    """
    with torch.no_grad(), autocast():
        image = pipe(prompt,
                     num_inference_steps=30,
                     num_images_per_prompt=1,
                     guidance_scale=cfg,
                     generator=gen,
                     height=resolution,
                     width=resolution,
                     ).images[0]

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype(font="arial.ttf", size=20)
    except Exception:
        # arial.ttf is often unavailable (e.g. on Linux/Colab); fall back to the
        # built-in font. Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still propagate.
        font = ImageFont.load_default()

    print_msg = f"cfg:{cfg:.1f}"
    l, t, r, b = draw.textbbox(xy=(0,0), text=print_msg, font=font)
    text_width = r - l
    text_height = b - t

    # Anchor the label 10px inside the bottom-right corner.
    x = float(image.width - text_width - 10)
    y = float(image.height - text_height - 10)

    draw.rectangle((x, y, image.width, image.height), fill="white")
    draw.text((x, y), print_msg, fill="black", font=font)
    del draw, font
    return image
def __generate_test_samples(pipe, prompts, gs, log_writer, log_folder, random_captions=False, resolution=512):
"""
generates samples at different cfg scales and saves them to disk
"""
# `args` and `device` are not parameters; they are read from the enclosing scope.
logging.info(f"Generating samples gs:{gs}, for {prompts}")
# A seed of -1 in args means "pick a random seed" for this round of samples.
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
# One generator is created here and passed to every __generate_sample call below.
gen = torch.Generator(device=device).manual_seed(seed)
# i numbers the prompts that were actually rendered; it is used in file names and log tags.
i = 0
for prompt in prompts:
# Skip missing prompts and prompts shorter than two characters.
if prompt is None or len(prompt) < 2:
#logging.warning("empty prompt in sample prompts, check your prompts file")
continue
images = []
# Render the same prompt at three guidance scales.
for cfg in [7.0, 4.0, 1.01]:
image = __generate_sample(pipe, prompt, cfg, resolution=resolution, gen=gen)
images.append(image)
# The composite is as wide as all images together and as tall as the tallest one.
width = 0
height = 0
for image in images:
width += image.width
height = max(height, image.height)
result = Image.new('RGB', (width, height))
# Paste the images left to right.
x_offset = 0
for image in images:
result.paste(image, (x_offset, 0))
x_offset += image.width
# The cleaned prompt is truncated to 100 characters wherever it is used in a name.
clean_prompt = clean_filename(prompt)
result.save(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
# A sidecar .txt next to the image records the full prompt and the seed.
with open(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(prompt)
f.write(f"\n seed: {seed}")
tfimage = transforms.ToTensor()(result)
# With random_captions the log tag omits the prompt text and uses only the index.
if random_captions:
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
else:
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
i += 1
# NOTE(review): indentation is not visible here; these del statements presumably
# sit inside the per-prompt loop - confirm against the original file.
del result
del tfimage
del images
try:
hf_ckpt_path, is_sd1attn = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
except:
logging.ERROR(" * Failed to load checkpoint *")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
try:
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
except Exception as ex:
logging.warning("failed to load xformers, continuing without it")
pass
else:
logging.info("xformers not available or disabled")
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
unet = unet.to(device, dtype=torch.float32)
if args.disable_textenc_training and args.amp:
text_encoder = text_encoder.to(device, dtype=torch.float16)
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text Encoder *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
betas = (0.9, 0.999)
epsilon = 1e-8
if args.amp:
epsilon = 2e-8
weight_decay = 0.01
if args.useadam8bit:
import bitsandbytes as bnb
opt_class = bnb.optim.AdamW8bit
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
else:
opt_class = torch.optim.AdamW
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
optimizer = opt_class(
itertools.chain(params_to_train),
lr=curr_lr,
betas=betas,
eps=epsilon,
weight_decay=weight_decay,
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon)
train_batch = EveryDreamBatch(
data_root=args.data_root,
flip_p=args.flip_p,
debug_level=1,
batch_size=args.batch_size,
conditional_dropout=args.cond_dropout,
resolution=args.resolution,
tokenizer=tokenizer,
seed = seed,
log_folder=log_folder,
write_schedule=args.write_schedule,
shuffle_tags=args.shuffle_tags,
rated_dataset=args.rated_dataset,
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
)
torch.cuda.benchmark = False
epoch_len = math.ceil(len(train_batch) / args.batch_size)
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps,
num_training_steps=args.lr_decay_steps,
)
sample_prompts = []
with open(args.sample_prompts, "r") as f:
for line in f:
sample_prompts.append(line.strip())
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, )
else:
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
)
def log_args(log_writer, args):
    """
    Record every argument as ``name=value, `` pairs, sorted by name, in one text
    entry under the "config" tag of ``log_writer``.
    """
    pairs = "".join(f"{name}={value}, " for name, value in sorted(vars(args).items()))
    log_writer.add_text("config", "args:\n" + pairs)
log_args(log_writer, args)
"""
Train the model
"""
print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}")
print(f" (C) 2022 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
print()
print("** Trainer Starting **")
global interrupted
interrupted = False
def sigterm_handler(signum, frame):
"""
handles sigterm
"""
# The handler takes the standard (signum, frame) signal-handler arguments; neither is used.
# The model objects, `args` and `log_folder` are read from the enclosing scope.
global interrupted
# `interrupted` ensures the save below is attempted only once, even if the handler runs again.
if not interrupted:
interrupted=True
global global_step
#TODO: save model on ctrl-c
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
print()
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
# NOTE(review): a second interrupt during this sleep re-enters the handler with
# `interrupted` already True; whether that cancels the save depends on whether the
# exit() below sits inside or outside the `if` - indentation is not visible here, confirm.
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
# Terminates the process with the module-level exit code constant.
exit(_SIGTERM_EXIT_CODE)
signal.signal(signal.SIGINT, sigterm_handler)
if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/")
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
def collate_fn(batch):
    """
    Merge a list of per-example dicts into a single batch dict.

    Each example provides "image", "caption", "tokens" and "runt_size". Images and
    tokens are stacked along a new leading dimension (images are additionally made
    contiguous and cast to float), captions are collected into a list, and
    "runt_size" is taken from the first example only.
    """
    image_list = [example["image"] for example in batch]
    caption_list = [example["caption"] for example in batch]
    token_list = [example["tokens"] for example in batch]
    runt_size = batch[0]["runt_size"]

    stacked_images = torch.stack(image_list).to(memory_format=torch.contiguous_format).float()
    stacked_tokens = torch.stack(tuple(token_list))

    collated = dict(
        tokens=stacked_tokens,
        image=stacked_images,
        captions=caption_list,
        runt_size=runt_size,
    )
    del batch
    return collated
train_dataloader = torch.utils.data.DataLoader(
train_batch,
batch_size=args.batch_size,
shuffle=False,
num_workers=0,
collate_fn=collate_fn
)
unet.train() if not args.disable_unet_training else unet.eval()
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
logging.info(f" text_encoder device: {text_encoder.device}, precision: {text_encoder.dtype}, training: {text_encoder.training}")
logging.info(f" vae device: {vae.device}, precision: {vae.dtype}, training: {vae.training}")
logging.info(f" scheduler: {noise_scheduler.__class__}")
logging.info(f" {Fore.GREEN}Project name: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.project_name}{Style.RESET_ALL}")
logging.info(f" {Fore.GREEN}grad_accum: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.grad_accum}{Style.RESET_ALL}"),
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
enabled=args.amp,
#enabled=True,
init_scale=2**17.5,
growth_factor=1.8,
backoff_factor=1.0/1.8,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
# steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
# steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
epoch_times = []
global global_step
global_step = 0
training_start_time = time.time()
last_epoch_saved_time = training_start_time
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
loss_log_step = []
try:
for epoch in range(args.max_epochs):
loss_epoch = []
epoch_start_time = time.time()
images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
latents = vae.encode(pixel_values, return_dict=False)
del pixel_values
latents = latents[0].sample() * 0.18215
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = batch["tokens"].to(text_encoder.device)
#with autocast(enabled=args.amp):
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
if args.clip_skip > 0:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
#if args.amp:
scaler.scale(loss).backward()
#else:
# loss.backward()
if args.clip_grad_norm is not None:
if not args.disable_unet_training:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
if not args.disable_textenc_training:
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
with torch.no_grad(): # not required? just in case for now, needs more testing
for param in unet.parameters():
if param.grad is not None:
param.grad *= grad_scale
if text_encoder.training:
for param in text_encoder.parameters():
if param.grad is not None:
param.grad *= grad_scale
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
# if args.amp:
scaler.step(optimizer)
scaler.update()
# else:
# optimizer.step()
optimizer.zero_grad(set_to_none=True)
lr_scheduler.step()
loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step},{"gs": global_step})
steps_pbar.update(1)
images_per_sec = args.batch_size / (time.time() - step_start_time)
images_per_sec_log_step.append(images_per_sec)
loss_log_step.append(loss_step)
loss_epoch.append(loss_step)
if (global_step + 1) % args.log_step == 0:
curr_lr = lr_scheduler.get_last_lr()[0]
loss_local = sum(loss_log_step) / len(loss_log_step)
loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
images_per_sec_log_step = []
if args.amp:
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device)
with torch.no_grad():
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
else:
max_prompts = min(4,len(batch["captions"]))
prompts=batch["captions"][:max_prompts]
__generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)
del pipe
gc.collect()
torch.cuda.empty_cache()
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 1 and epoch < args.max_epochs - 1:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
del batch
global_step += 1
if global_step == 500:
scaler.set_growth_factor(1.4)
scaler.set_backoff_factor(1/1.4)
if global_step == 1000:
scaler.set_growth_factor(1.2)
scaler.set_backoff_factor(1/1.2)
scaler.set_growth_interval(100)
# end of step
steps_pbar.close()
elapsed_epoch_time = (time.time() - epoch_start_time) / 60
epoch_times.append(dict(epoch=epoch, time=elapsed_epoch_time))
log_writer.add_scalar("performance/minutes per epoch", elapsed_epoch_time, global_step)
epoch_pbar.update(1)
if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# end of epoch
# end of training
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
logging.info(f"Total training time took {total_elapsed_time/60:.2f} minutes, total steps: {global_step}")
logging.info(f"Average epoch time: {np.mean([t['time'] for t in epoch_times]):.2f} minutes")
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
logging.info(f"{Fore.LIGHTWHITE_EX} **** Finished training ****{Style.RESET_ALL}")
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
def update_old_args(t_args):
    """
    Backfill options that an older JSON config does not define.

    Every option in the table below that is absent from ``t_args`` is reported
    on stdout and then stored on the namespace with its fallback value.
    Options that are already present are left untouched. Returns nothing; the
    namespace is modified in place.
    """
    # (option name, value used when the loaded config lacks it), checked in this order
    fallbacks = (
        ("shuffle_tags", False),
        ("save_full_precision", False),
        ("notebook", False),
        ("disable_unet_training", False),
        ("rated_dataset", False),
        ("rated_dataset_target_dropout_percent", 50),
    )
    for option, fallback in fallbacks:
        if hasattr(t_args, option):
            continue
        print(f" Config json is missing '{option}' flag")
        vars(t_args)[option] = fallback
if __name__ == "__main__":
    # Script entry point: build the training options either from a JSON config
    # file (--config) or from the command line, then hand them to main().
    supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
    supported_precisions = ['fp16', 'fp32']

    # First pass: this parser only knows --config. parse_known_args tolerates
    # every other option so we can decide which of the two paths to take.
    argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
    argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
    args, _ = argparser.parse_known_args()

    if args.config is not None:
        print(f"Loading training config from {args.config}, all other command options will be ignored!")
        with open(args.config, 'rt') as f:
            t_args = argparse.Namespace()
            t_args.__dict__.update(json.load(f))
            update_old_args(t_args) # update args to support older configs
            print(f" args: \n{t_args.__dict__}")
            # parse_known_args (not parse_args): this parser only declares
            # --config, so parse_args would exit with "unrecognized arguments"
            # whenever any other option is on the command line, contradicting
            # the "will be ignored" message printed above.
            args, _ = argparser.parse_known_args(namespace=t_args)
    else:
        print("No config file specified, using command line args")
        # Second parser carrying the full option set (it does not declare --config).
        argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
        argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
        argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
        argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
        argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
        argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
        argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
        argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
        argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
        argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
        argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
        argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
        argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
        argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
        argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
        argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
        argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
        argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
        argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
        argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
        argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
        argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
        argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
        argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
        argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
        argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
        # NOTE(review): required=True means argparse never falls back to the default given here.
        argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
        argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
        argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
        argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
        argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
        argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
        argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
        argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
        argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
        argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
        argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
        argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
        argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
        argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
        argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
        # Unknown options are tolerated here as well.
        args, _ = argparser.parse_known_args()

    main(args)

View File

@ -1,11 +1,13 @@
import logging
import os
from typing import Optional
from typing import Optional, Tuple
import huggingface_hub
from utils.patch_unet import patch_unet
def try_download_model_from_hf(repo_id: str,
subfolder: Optional[str]=None) -> Optional[str]:
subfolder: Optional[str]=None) -> Tuple[Optional[str], Optional[bool], Optional[str]]:
"""
Attempts to download files from the following subfolders under the given repo id:
"text_encoder", "vae", "unet", "scheduler", "tokenizer".
@ -25,9 +27,11 @@ def try_download_model_from_hf(repo_id: str,
# check if the model exists
model_info = huggingface_hub.model_info(repo_id)
if model_info is None:
return None
return None, None, None
model_subfolders = ["text_encoder", "vae", "unet", "scheduler", "tokenizer"]
allow_patterns = [os.path.join(subfolder or '', f, "*") for f in model_subfolders]
downloaded_folder = huggingface_hub.snapshot_download(repo_id=repo_id, allow_patterns=allow_patterns)
return downloaded_folder
is_sd1_attn, yaml_path = patch_unet(downloaded_folder)
return downloaded_folder, is_sd1_attn, yaml_path

View File

@ -16,24 +16,67 @@ limitations under the License.
import logging
import os
import time
from colorama import Fore, Style
class LogWrapper(object):
from tensorboard import SummaryWriter
import wandb
class LogWrapper():
    """
    Logging facade for the trainer.

    Scalar metrics are sent either to Weights & Biases or to a tensorboard
    ``SummaryWriter``, depending on the ``wandb`` flag given at construction.
    The instance also owns a ``logging.Logger`` that writes to the console and
    to a timestamped text file inside ``args.logdir``; calling the instance
    returns that logger.
    """
    def __init__(self, args, wandb=False):
        """
        args: object providing ``logdir`` and ``project_name``
        wandb: when truthy, metrics go to wandb; otherwise a SummaryWriter is created
        """
        self.logdir = args.logdir
        self.wandb = wandb

        # FileHandler below opens the file immediately (delay=False), so the
        # folder has to exist even when no SummaryWriter is created.
        os.makedirs(args.logdir, exist_ok=True)

        if wandb:
            # The boolean parameter shadows the module-level ``wandb`` import
            # inside this method, so bind the module under another name
            # instead of calling ``.init`` on the flag itself.
            import wandb as wandb_module
            wandb_module.init(project=args.project_name, sync_tensorboard=True)
        else:
            # NOTE(review): SummaryWriter comes from the module-level
            # ``from tensorboard import SummaryWriter`` - confirm that import
            # resolves (the writer normally lives in torch.utils.tensorboard).
            self.log_writer = SummaryWriter(log_dir=args.logdir,
                                            flush_secs=5,
                                            comment="EveryDream2FineTunes",
                                            )

        start_time = time.strftime("%Y%m%d-%H%M")
        log_file = os.path.join(args.logdir, f"log-{args.project_name}-{start_time}.txt")

        self.logger = logging.getLogger(__name__)

        console = logging.StreamHandler()
        self.logger.addHandler(console)

        file = logging.FileHandler(log_file, mode="a", encoding=None, delay=False)
        self.logger.addHandler(file)

    def __call__(self):
        """Return the wrapped ``logging.Logger``."""
        return self.logger

    def add_image(self):
        """
        Placeholder for image logging; currently does nothing.

        Intended calls, kept from the original stub for reference:
            log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
        else:
            log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
        """
        pass

    def add_scalar(self, tag: str, img_tensor: float, global_step: int):
        """
        Record one scalar metric under ``tag`` at ``global_step``.

        img_tensor: the scalar value to record (parameter name kept unchanged
                    so existing keyword callers keep working)
        """
        if self.wandb:
            wandb.log({tag: img_tensor}, step=global_step)
        else:
            # Scalars go through add_scalar; add_image expects image data.
            self.log_writer.add_scalar(tag, img_tensor, global_step)

    def append_epoch_log(self, global_step: int, epoch_pbar, gpu, log_writer, **logs):
        """
        updates the vram usage for the epoch

        Reads (used, total) memory from ``gpu.get_gpu_memory()``, records the
        used amount as the "performance/vram" scalar and refreshes the postfix
        of ``epoch_pbar`` with ``logs`` plus a colour-coded vram entry.
        ``log_writer`` is accepted for call compatibility and is not used;
        scalars are routed through ``self.add_scalar``.
        """
        gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
        self.add_scalar("performance/vram", gpu_used_mem, global_step)

        # Colour by fraction of total memory in use; between 50% and 70% the
        # default style is kept.
        epoch_mem_color = Style.RESET_ALL
        if gpu_used_mem > 0.93 * gpu_total_mem:
            epoch_mem_color = Fore.LIGHTRED_EX
        elif gpu_used_mem > 0.85 * gpu_total_mem:
            epoch_mem_color = Fore.LIGHTYELLOW_EX
        elif gpu_used_mem > 0.7 * gpu_total_mem:
            epoch_mem_color = Fore.LIGHTGREEN_EX
        elif gpu_used_mem < 0.5 * gpu_total_mem:
            epoch_mem_color = Fore.LIGHTBLUE_EX

        # ``logs`` collects keyword arguments and is therefore always a dict,
        # never None, so the postfix is updated unconditionally.
        epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")

View File

@ -17,7 +17,7 @@ import os
import json
import logging
def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
def patch_unet(ckpt_path):
"""
Patch the UNet to use updated attention heads for xformers support in FP32
"""
@ -25,15 +25,27 @@ def patch_unet(ckpt_path, force_sd1attn: bool = False, low_vram: bool = False):
with open(unet_cfg_path, "r") as f:
unet_cfg = json.load(f)
scheduler_cfg_path = os.path.join(ckpt_path, "scheduler", "scheduler_config.json")
with open(scheduler_cfg_path, "r") as f:
scheduler_cfg = json.load(f)
if force_sd1attn:
if low_vram:
unet_cfg["attention_head_dim"] = [5, 8, 8, 8]
else:
unet_cfg["attention_head_dim"] = [8, 8, 8, 8]
else:
unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
prediction_type = scheduler_cfg["prediction_type"]
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
with open(unet_cfg_path, "w") as f:
json.dump(unet_cfg, f, indent=2)
yaml = ''
if prediction_type in ["v_prediction","v-prediction"] and not is_sd1attn:
yaml = "v2-inference-v.yaml"
elif prediction_type == "epsilon" and not is_sd1attn:
yaml = "v2-inference.yaml"
elif prediction_type == "epsilon" and is_sd1attn:
yaml = "v1-inference.yaml"
else:
raise ValueError(f"Unknown model format for: {prediction_type} and attention_head_dim {unet_cfg['attention_head_dim']}")
logging.info(f"Inferred yaml: {yaml}, attn: {'sd1' if is_sd1attn else 'sd2'}, prediction_type: {prediction_type}")
return is_sd1attn, yaml