Merge remote-tracking branch 'upstream/main' into val_partial_epochs
This commit is contained in:
commit
f3468fe7e7
|
@ -20,5 +20,12 @@
|
|||
|
||||
// Mimic RunPod/Vast setup
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace/EveryDream2trainer,type=bind",
|
||||
"workspaceFolder": "/workspace/EveryDream2trainer"
|
||||
"workspaceFolder": "/workspace/EveryDream2trainer",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,5 +16,12 @@
|
|||
|
||||
// Mimic RunPod/Vast setup
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace/EveryDream2trainer,type=bind",
|
||||
"workspaceFolder": "/workspace/EveryDream2trainer"
|
||||
"workspaceFolder": "/workspace/EveryDream2trainer",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: For bugs that are NOT ERRORS
|
||||
title: "[BUG]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**If you are getting an error, use the ERROR template, not this BUG template. **
|
||||
|
||||
Have you joined discord? You're far more likely to get a response there: https://discord.gg/uheqxU6sXN
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the undesired behavior:
|
||||
1. Configure training in this way (please attach _cfg.json)
|
||||
2. Execute training with "this" command
|
||||
|
||||
**Describe expected behavior and actual behavior**
|
||||
ex. It does XYZ but should do ABC instead
|
||||
ex. It does not do ABC when it should do ABC
|
||||
ex. It should not do XYZ at all
|
||||
|
||||
**Describe why you think this is a bug**
|
||||
It should do ABC because... XYZ is wrong because...
|
||||
|
||||
**Attach log and cfg**
|
||||
*PLEASE* attach the ".log" and "_cfg.json" from your logs folder for the run. These are in the "logs" folder under "project_name_timestamp" subfolder. This will assist greatly in identifying problems with configurations or system problems that may be causing your problems.
|
||||
|
||||
**Runtime environment (please complete the following information):**
|
||||
- OS: [e.g. Windows 10, Ubuntu Linux 22.04, etc]
|
||||
- Is this your local computer or a cloud host? Please list the cloud host (Vast, Google Colab, etc)
|
||||
- GPU [e.g. 3090 24GB, A100 40GB, 2080 Ti 11GB, etc]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
|
@ -0,0 +1,22 @@
|
|||
---
|
||||
name: Error Report
|
||||
about: For ERRORS that halt training
|
||||
title: "[ERROR]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Attach log and cfg**
|
||||
*PLEASE* attach the ".log" and "_cfg.json" from your logs folder for the run that failed. This is absolutely critical to providing assistance. These files are always generated and saved in the "logs" folder under project_name_timestamp folder every run. For cloud hosts, you can download the files. For Google Colab these are likely being saved to your Gdrive.
|
||||
|
||||
**Runtime environment (please complete the following information):**
|
||||
- OS: [e.g. Windows 10, Ubuntu Linux 22.04, etc]
|
||||
- Is this your local computer or a cloud host? Please list the cloud host (Vast, Google Colab, etc)
|
||||
- GPU [e.g. 3090 24GB, A100 40GB]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
|
||||
Also consider posting your cfg and log to the Discord #help channel here instead, there are far more people there than will read your issue here on Github:
|
||||
Have you joined discord? You're far more likely to get a response there: https://discord.gg/uheqxU6sXN
|
Binary file not shown.
After Width: | Height: | Size: 5.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.5 KiB |
Binary file not shown.
After Width: | Height: | Size: 5.8 KiB |
|
@ -8,7 +8,6 @@ on:
|
|||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
|
@ -29,3 +28,4 @@ jobs:
|
|||
file: docker/Dockerfile
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: CACHEBUST=$(date +%s)
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.9
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-test.txt
|
||||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
python -m unittest
|
|
@ -13,3 +13,4 @@
|
|||
/.vscode/**
|
||||
.ssh_config
|
||||
*inference*.yaml
|
||||
.idea
|
||||
|
|
49
README.md
49
README.md
|
@ -1,10 +1,14 @@
|
|||
# EveryDream Trainer 2.0
|
||||
|
||||
Welcome to v2.0 of EveryDream trainer! Now with more diffusers and even more features!
|
||||
Welcome to v2.0 of EveryDream trainer! Now with more Diffusers, faster, and even more features!
|
||||
|
||||
Please join us on Discord! https://discord.gg/uheqxU6sXN
|
||||
For the most up to date news and community discussions, please join us on Discord!
|
||||
|
||||
If you find this tool useful, please consider subscribing to the project on [Patreon](https://www.patreon.com/everydream) or a one-time donation at [Ko-fi](https://ko-fi.com/everydream).
|
||||
[![Discord!](.github/discord_sm.png)](https://discord.gg/uheqxU6sXN)
|
||||
|
||||
If you find this tool useful, please consider subscribing to the project on Patreon or a one-time donation on Ko-fi. Your donations keep this project alive as a free open source tool with ongoing enhancements.
|
||||
|
||||
[![Patreon](.github/patreon-medium-button.png)](https://www.patreon.com/everydream) or [![Kofi](.github/kofibutton_sm.png)](https://ko-fi.com/everydream).
|
||||
|
||||
If you're coming from Dreambooth, please [read this](doc/NOTDREAMBOOTH.md) for an explanation of why EveryDream is not Dreambooth.
|
||||
|
||||
|
@ -22,16 +26,18 @@ Single GPU is currently supported
|
|||
|
||||
32GB of system RAM recommended for 50k+ training images, but may get away with sufficient swap file and 16GB
|
||||
|
||||
Ampere or newer 24GB+ (3090/A5000/4090, etc) recommended for 10k+ images unless you want to wait a long time
|
||||
Ampere or newer 24GB+ (3090/A5000/4090, etc) recommended for 10k+ images
|
||||
|
||||
...Or use any computer with a web browser and run on Vast/Runpod/Colab. See [Cloud](#cloud) section below.
|
||||
...Or use any computer with a web browser and run on Vast/Colab. See [Cloud](#cloud) section below.
|
||||
|
||||
## Video tutorials
|
||||
|
||||
### [Basic setup and getting started](https://www.youtube.com/watch?v=OgpJK8SUW3c)
|
||||
|
||||
Covers install, setup of base models, startning training, basic tweaking, and looking at your logs
|
||||
### [Multiaspect and crop jitter](https://www.youtube.com/watch?v=0xswM8QYFD0)
|
||||
|
||||
### [Multiaspect and crop jitter explainer](https://www.youtube.com/watch?v=0xswM8QYFD0)
|
||||
|
||||
|
||||
Behind the scenes look at how the trainer handles multiaspect and crop jitter
|
||||
|
||||
|
@ -39,6 +45,16 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter
|
|||
|
||||
Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream), it has a grab bag of scripts to help with your data curation prior to training. It has automatic bulk BLIP captioning for BLIP, script to web scrape based on Laion data files, script to rename generic pronouns to proper names or append artist tags to your captions, etc.
|
||||
|
||||
## Cloud/Docker
|
||||
|
||||
### [Free tier Google Colab notebook](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb)
|
||||
|
||||
### * [RunPod / Vast Instructions](/doc/CLOUD_SETUP.md)
|
||||
#### * [Vast.ai Video Tutorial](https://www.youtube.com/watch?v=PKQesb4om9I)
|
||||
#### [Runpod Video Tutorial](https://www.youtube.com/watch?v=XAULP-4hsnA)
|
||||
|
||||
### [Docker image link](https://github.com/victorchall/EveryDream2trainer/pkgs/container/everydream2trainer)
|
||||
|
||||
## Docs
|
||||
|
||||
[Setup and installation](doc/SETUP.md)
|
||||
|
@ -49,28 +65,21 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea
|
|||
|
||||
[Training](doc/TRAINING.md) - How to start training
|
||||
|
||||
[Troubleshooting](doc/TROUBLESHOOTING.md)
|
||||
|
||||
[Basic Tweaking](doc/TWEAKING.md) - Important args to understand to get started
|
||||
|
||||
[Logging](doc/LOGGING.md)
|
||||
|
||||
[Advanced Tweaking](doc/ATWEAKING.md) - More stuff to tweak once you are comfortable
|
||||
|
||||
[Advanced Optimizer Tweaking](/doc/OPTIMIZER.md) - Even more stuff to tweak if you are *very adventurous*
|
||||
[Advanced Tweaking](doc/ADVANCED_TWEAKING.md) and [Advanced Optimizer Tweaking](/doc/OPTIMIZER.md)
|
||||
|
||||
[Chaining training sessions](doc/CHAINING.md) - Modify training parameters by chaining training sessions together end to end
|
||||
|
||||
[Shuffling Tags](doc/SHUFFLING_TAGS.md)
|
||||
|
||||
[Data Balancing](doc/BALANCING.md) - Includes my small treatise on model preservation with ground truth data
|
||||
[Data Balancing](doc/BALANCING.md) - Includes my small treatise on model "preservation" with additional ground truth data
|
||||
|
||||
[Logging](doc/LOGGING.md)
|
||||
|
||||
[Validation](doc/VALIDATION.md) - Use a validation split on your data to see when you are overfitting and tune hyperparameters
|
||||
|
||||
[Troubleshooting](doc/TROUBLESHOOTING.md)
|
||||
[Contributing](doc/CONTRIBUTING.md)
|
||||
|
||||
## Cloud
|
||||
|
||||
[Free tier Google Colab notebook](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb)
|
||||
|
||||
[RunPod / Vast](/doc/CLOUD_SETUP.md)
|
||||
|
||||
[Docker image link](https://github.com/victorchall/EveryDream2trainer/pkgs/container/everydream2trainer)
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/nawnie/EveryDream2trainer/blob/main/Train_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
|
@ -33,49 +33,88 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"#@title # Install python 3.10 \n",
|
||||
"#@markdown # This will show a runtime error, its ok, its on purpose to restart the kernel to update python.\n",
|
||||
"#@markdown # This will show a runtime error, it's ok, it's on purpose to restart the kernel to update python.\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import sys\n",
|
||||
"from IPython.display import clear_output\n",
|
||||
"!wget https://github.com/korakot/kora/releases/download/v0.10/py310.sh\n",
|
||||
"!bash ./py310.sh -b -f -p /usr/local\n",
|
||||
"!python -m ipykernel install --name \"py310\" --user\n",
|
||||
"clear_output()\n",
|
||||
"time.sleep(1) #needed to clear is before kill\n",
|
||||
"os.kill(os.getpid(), 9)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "f2cdMtCt9Wb6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Verify python version, should be 3.10.something\n",
|
||||
"!python --version"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "d1di4EC6ygw1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Optional connect Gdrive\n",
|
||||
"#@markdown # But strongly recommended\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"#@markdown Optional connect Gdrive But strongly recommended\n",
|
||||
"#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
|
||||
"\n",
|
||||
"#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
|
||||
"from google.colab import drive\n",
|
||||
"drive.mount('/content/drive')\n",
|
||||
"Mount_to_Gdrive = True #@param{type:\"boolean\"} \n",
|
||||
"\n",
|
||||
"!mkdir -p /content/drive/MyDrive/everydreamlogs/ckpt"
|
||||
"if Mount_to_Gdrive:\n",
|
||||
" from google.colab import drive\n",
|
||||
" drive.mount('/content/drive')\n",
|
||||
"\n",
|
||||
" !mkdir -p /content/drive/MyDrive/everydreamlogs/ckpt\n",
|
||||
"\n",
|
||||
"# Define a custom function to display a progress bar\n",
|
||||
"def display_progress_bar(progress, total, prefix=\"\"):\n",
|
||||
" sys.stdout.write(f\"\\r{prefix}[{'=' * progress}>{' ' * (total - progress - 1)}] {progress + 1}/{total}\")\n",
|
||||
" sys.stdout.flush()\n",
|
||||
"\n",
|
||||
"total_steps = 9\n",
|
||||
"current_step = 0\n",
|
||||
"\n",
|
||||
"!pip install transformers==4.25.1 --progress-bar on --quiet\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"!pip install watchdog --progress-bar on --quiet\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"!pip install matplotlib --progress-bar on --quiet\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"# Install the alive-package library\n",
|
||||
"!pip install alive-progress --progress-bar on --quiet\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Install the tqdm library\n",
|
||||
"!pip install tqdm --progress-bar on --quiet\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"# Download the py310.sh script\n",
|
||||
"!wget https://github.com/korakot/kora/releases/download/v0.10/py310.sh -q\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"# Run the py310.sh script\n",
|
||||
"try:\n",
|
||||
" output = os.popen('bash ./py310.sh -b -f -p /usr/local 2>&1').read()\n",
|
||||
" total_lines = len(output.splitlines())\n",
|
||||
" for i, line in enumerate(output.splitlines()):\n",
|
||||
" clear_output(wait=True)\n",
|
||||
" display_progress_bar(i, total_lines, \"install progress:\")\n",
|
||||
"except Exception as e:\n",
|
||||
" print(str(e))\n",
|
||||
"\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"# Install the py310 kernel\n",
|
||||
"!python -m ipykernel install --name \"py310\" --user > /dev/null 2>&1\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"\n",
|
||||
"# Clear output\n",
|
||||
"!rm /content/py310.sh\n",
|
||||
"current_step += 1\n",
|
||||
"display_progress_bar(current_step, total_steps, \"install progress:\")\n",
|
||||
"clear_output()\n",
|
||||
"time.sleep(1) #needed to clear is before kill\n",
|
||||
"os.kill(os.getpid(), 9)\n",
|
||||
"print(\"\\nInstallation completed.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -87,36 +126,52 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@markdown # Install Dependencies\n",
|
||||
"#@markdown # Finish Install Dependencies into the new python\n",
|
||||
"#@markdown This will take a couple minutes, be patient and watch the output for \"DONE!\"\n",
|
||||
"from IPython.display import clear_output\n",
|
||||
"from subprocess import getoutput\n",
|
||||
"s = getoutput('nvidia-smi')\n",
|
||||
"!pip install -q torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url \"https://download.pytorch.org/whl/cu117\"\n",
|
||||
"!pip install -q transformers==4.25.1\n",
|
||||
"!pip install -q diffusers[torch]==0.13.0\n",
|
||||
"!pip install -q pynvml==11.4.1\n",
|
||||
"!pip install -q bitsandbytes==0.35.0\n",
|
||||
"!pip install -q ftfy==6.1.1\n",
|
||||
"!pip install -q aiohttp==3.8.3\n",
|
||||
"!pip install -q tensorboard>=2.11.0\n",
|
||||
"!pip install -q protobuf==3.20.1\n",
|
||||
"!pip install -q wandb==0.13.6\n",
|
||||
"!pip install -q pyre-extensions==0.0.23\n",
|
||||
"!pip install -q xformers==0.0.16\n",
|
||||
"!pip install -q pytorch-lightning==1.6.5\n",
|
||||
"!pip install -q OmegaConf==2.2.3\n",
|
||||
"!pip install -q numpy==1.23.5\n",
|
||||
"!pip install -q colorama\n",
|
||||
"!pip install -q keyboard\n",
|
||||
"!pip install -q triton\n",
|
||||
"!pip install -q lion-pytorch\n",
|
||||
"import subprocess\n",
|
||||
"from tqdm.notebook import tqdm\n",
|
||||
"\n",
|
||||
"packages = [\n",
|
||||
" ('torch==1.13.1+cu117 torchvision==0.14.1+cu117', 'https://download.pytorch.org/whl/cu117'),\n",
|
||||
" 'transformers==4.25.1',\n",
|
||||
" 'diffusers[torch]==0.13.0',\n",
|
||||
" 'pynvml==11.4.1',\n",
|
||||
" 'bitsandbytes==0.35.0',\n",
|
||||
" 'ftfy==6.1.1',\n",
|
||||
" 'aiohttp==3.8.3',\n",
|
||||
" 'tensorboard>=2.11.0',\n",
|
||||
" 'protobuf==3.20.1',\n",
|
||||
" 'wandb==0.13.6',\n",
|
||||
" 'pyre-extensions==0.0.23',\n",
|
||||
" 'xformers==0.0.16',\n",
|
||||
" 'pytorch-lightning==1.6.5',\n",
|
||||
" 'OmegaConf==2.2.3',\n",
|
||||
" 'numpy==1.23.5',\n",
|
||||
" 'colorama',\n",
|
||||
" 'keyboard',\n",
|
||||
" 'triton',\n",
|
||||
" 'lion-pytorch',\n",
|
||||
" 'compel'\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for package in tqdm(packages, desc='Installing packages', unit='package'):\n",
|
||||
" if isinstance(package, tuple):\n",
|
||||
" package_name, extra_index_url = package\n",
|
||||
" cmd = f\"pip install -q {package_name} --extra-index-url {extra_index_url}\"\n",
|
||||
" else:\n",
|
||||
" cmd = f\"pip install -q {package}\"\n",
|
||||
" \n",
|
||||
" subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)\n",
|
||||
"\n",
|
||||
"clear_output()\n",
|
||||
"\n",
|
||||
"!git clone https://github.com/victorchall/EveryDream2trainer.git\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"!python utils/get_yamls.py\n",
|
||||
"clear_output()\n",
|
||||
"print(\"DONE!\")"
|
||||
"print(\"DONE! installing dependcies make sure we are using python 3.10.x\")\n",
|
||||
"!python --version"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -129,128 +184,86 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Get A Base Model\n",
|
||||
"#@markdown Choose SD1.5 or Waifu Diffusion 1.3 from the dropdown, or paste your own URL in the box\n",
|
||||
"#@markdown Choose SD1.5, Waifu Diffusion 1.3, SD2.1, or 2.1(512) from the dropdown, or paste your own URL in the box\n",
|
||||
"#@markdown * alternately you can link to a HF repo using NAME/MODEL\n",
|
||||
"#@markdown * link to a set of diffusers on your Gdrive\n",
|
||||
"#@markdown * paste a url, atm there is no support for .safetensors\n",
|
||||
"\n",
|
||||
"#@markdown If you already did this once with Gdrive connected, you can skip this step as the cached copy is on your gdrive\n",
|
||||
"from IPython.display import clear_output\n",
|
||||
"!mkdir input\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"MODEL_URL = \"https://huggingface.co/panopstor/EveryDream/resolve/main/sd_v1-5_vae.ckpt\" #@param [\"https://huggingface.co/panopstor/EveryDream/resolve/main/sd_v1-5_vae.ckpt\", \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
|
||||
"print(\"Downloading \")\n",
|
||||
"!wget $MODEL_URL\n",
|
||||
"MODEL_LOCATION = \"sd_v1-5+vae.ckpt\" #@param [\"sd_v1-5+vae.ckpt\", \"hakurei/waifu-diffusion-v1-3\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
|
||||
"if MODEL_LOCATION == \"sd_v1-5+vae.ckpt\":\n",
|
||||
" MODEL_LOCATION = \"panopstor/EveryDream\"\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"download_path = \"\"\n",
|
||||
"\n",
|
||||
"if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION: #maybe just add a radio button to download this should work for now\n",
|
||||
" print(\"Downloading \")\n",
|
||||
" !wget $MODEL_LOCATION\n",
|
||||
" clear_output()\n",
|
||||
" print(\"DONE!\")\n",
|
||||
" download_path = os.path.join(os.getcwd(), os.path.basename(MODEL_URL))\n",
|
||||
"\n",
|
||||
"else:\n",
|
||||
" save_name = MODEL_LOCATION\n",
|
||||
"\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"#@markdown * if you chose to link to a .ckpt Select the correct model version in the drop down menu for conversion\n",
|
||||
"\n",
|
||||
"clear_output()\n",
|
||||
"print(\"DONE!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "nEzuEYH0536C"
|
||||
},
|
||||
"source": [
|
||||
"In order to train, you need a base model on which to train. This is a one-time setup to configure base models when you want to use a particular base. \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "tPvQSo6ScF2c"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"#@title Setup conversion\n",
|
||||
"inference_yaml = \" \"\n",
|
||||
"\n",
|
||||
"#@markdown **If you already did this once with Gdrive connected, you can skip this step as the cached copy is on your gdrive.** \n",
|
||||
"# \n",
|
||||
"# If you are not sure, look in your Gdrive for `everydreamlogs/ckpt` and see if you have a folder with the `save_name` below.\n",
|
||||
"# Check if the downloaded or copied model is a .ckpt file\n",
|
||||
"#@markdown is the model 1.5 or 2.1 based\n",
|
||||
"if download_path.endswith(\".ckpt\"):\n",
|
||||
" model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
|
||||
" save_path = download_path\n",
|
||||
" if \".ckpt\" in save_name:\n",
|
||||
" save_name = save_name.replace(\".ckpt\", \"\")\n",
|
||||
"\n",
|
||||
"#@markdown Pick the `model_type` in the dropdown. This is the model type that you are converting and you downloaded above. This is important as it will determine the model architecture and the correct settings to use.\n",
|
||||
"\n",
|
||||
"#@markdown * `SD1x` is all SD1.x based models *(SD1.4, SD1.5, Waifu Diffusion 1.3, etc)*\n",
|
||||
"\n",
|
||||
"#@markdown * `SD2_512_base` is the SD2 512 base model\n",
|
||||
"\n",
|
||||
"#@markdown * `SD21` is all SD2 768 models. *(ex. SD2.1 768, or trained models based on that)*\n",
|
||||
"\n",
|
||||
"#@markdown If you are not sure, double check the model author's page or ask for help on [Discord](https://discord.gg/uheqxU6sXN).\n",
|
||||
"model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
|
||||
"\n",
|
||||
"#@markdown This is the temporary ckpt file that was downloaded above. If you downloaded a different model, you can change this. *Hint: look at your file manager in the EveryDream2trainer folder for .ckpt files*.\n",
|
||||
"base_path = \"/content/EveryDream2trainer/sd_v1-5_vae.ckpt\" #@param {type:\"string\"}\n",
|
||||
"\n",
|
||||
"#@markdown The name that you will use when selecting this model in the future training sessons.\n",
|
||||
"save_name = \"SD15\" #@param{type:\"string\"}\n",
|
||||
"\n",
|
||||
"#@markdown If you are using Gdrive, this will save the converted model to your Gdrive for future use so you can skip downloading and converting the model.\n",
|
||||
"cache_to_gdrive = True #@param{type:\"boolean\"}\n",
|
||||
"\n",
|
||||
"if cache_to_gdrive:\n",
|
||||
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
|
||||
"\n",
|
||||
"img_size = 512\n",
|
||||
"upscale_attention = False\n",
|
||||
"prediction_type = \"epsilon\"\n",
|
||||
"if model_type == \"SD1x\":\n",
|
||||
" img_size = 512\n",
|
||||
" upscale_attention = False\n",
|
||||
" prediction_type = \"epsilon\"\n",
|
||||
" if model_type == \"SD1x\":\n",
|
||||
" inference_yaml = \"v1-inference.yaml\"\n",
|
||||
"elif model_type == \"SD2_512_base\":\n",
|
||||
" elif model_type == \"SD2_512_base\":\n",
|
||||
" upscale_attention = True\n",
|
||||
" inference_yaml = \"v2-inference.yaml\"\n",
|
||||
"elif model_type == \"SD21\":\n",
|
||||
" elif model_type == \"SD21\":\n",
|
||||
" upscale_attention = True\n",
|
||||
" prediction_type = \"v_prediction\"\n",
|
||||
" inference_yaml = \"v2-inference-v.yaml\"\n",
|
||||
" img_size = 768\n",
|
||||
"\n",
|
||||
"print(base_path)\n",
|
||||
"print(inference_yaml)\n",
|
||||
" !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
|
||||
" --original_config_file $inference_yaml \\\n",
|
||||
" --image_size $img_size \\\n",
|
||||
" --checkpoint_path $save_path \\\n",
|
||||
" --prediction_type $prediction_type \\\n",
|
||||
" --upcast_attn False \\\n",
|
||||
" --dump_path $save_name\n",
|
||||
"\n",
|
||||
"!python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
|
||||
"--original_config_file {inference_yaml} \\\n",
|
||||
"--image_size {img_size} \\\n",
|
||||
"--checkpoint_path {base_path} \\\n",
|
||||
"--prediction_type {prediction_type} \\\n",
|
||||
"--upcast_attn False \\\n",
|
||||
"--dump_path {save_name}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "bLpcvpGJB4Gu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Pick your base model from a diffusers model saved to your Gdrive (converted above)\n",
|
||||
"\n",
|
||||
"#@markdown Do not skip this cell.\n",
|
||||
"\n",
|
||||
"#@markdown * If you have preveiously saved diffusers on your drive you can select it here\n",
|
||||
"\n",
|
||||
"#@markdown ex. */content/drive/MyDrive/everydreamlogs/myproject_202208/ckpts/interrupted-gs023*\n",
|
||||
"\n",
|
||||
"#@markdown The default for SD1.5 converted above would be */content/drive/MyDrive/everydreamlogs/ckpt/SD15*\n",
|
||||
"Resume_Model = \"/content/drive/MyDrive/everydreamlogs/ckpt/SD15\" #@param{type:\"string\"} \n",
|
||||
"save_name = Resume_Model"
|
||||
" # Set the save path to the GDrive directory if cache_to_gdrive is True\n",
|
||||
" if cache_to_gdrive:\n",
|
||||
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
|
||||
"if inference_yaml != \" \":\n",
|
||||
" print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
|
||||
"print(\"Model \" + save_name + \" will be used!\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "JXVu-W2lCjwX"
|
||||
"id": "EHyFzKWXX9kB"
|
||||
},
|
||||
"source": [
|
||||
"# Training\n",
|
||||
"\n",
|
||||
"For a more indepth Explanation of each of these paramaters check out /content/EveryDream2trainer/doc.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"After youve tried a few models you will find /content/EveryDream2trainer/doc/ATWEAKING.md to be extremly helpful."
|
||||
"After youve tried a few models you will find /content/EveryDream2trainer/doc/ADVANCED_TWEAKING.md to be extremly helpful."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -262,13 +275,23 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title \n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"#@markdown # Run Everydream 2\n",
|
||||
"#@markdown If you want to use a .json config or upload your own, skip this cell and run the cell below instead\n",
|
||||
"from google.colab import runtime\n",
|
||||
"from IPython.display import clear_output\n",
|
||||
"import time\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import re\n",
|
||||
"import sys\n",
|
||||
"import time\n",
|
||||
"import shutil\n",
|
||||
"\n",
|
||||
"#@title #Run Everydream 2\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"#@markdown If you want to use a .json config or upload your own, skip this cell and run the cell below instead\n",
|
||||
"!rm -r /content/EveryDream2trainer/Training_Data\n",
|
||||
"#@markdown * Save logs and output ckpts to Gdrive (strongly suggested)\n",
|
||||
"Save_to_Gdrive = True #@param{type:\"boolean\"}\n",
|
||||
"#@markdown * Disconnect after training to save Credits \n",
|
||||
"Disconnect_after_training = False #@param{type:\"boolean\"}\n",
|
||||
"#@markdown * Use resume to contnue training you just ran, will also find latest diffusers log in your Gdrive to continue.\n",
|
||||
"resume = False #@param{type:\"boolean\"}\n",
|
||||
"#@markdown * Name your project so you can find it in your logs\n",
|
||||
|
@ -331,6 +354,8 @@
|
|||
"#@markdown * Using the same seed each time you train allows for more accurate a/b comparison of models, leave at -1 for random\n",
|
||||
"#@markdown * The seed also effects your training samples, if you want the same seed each sample you will need to change it from -1\n",
|
||||
"Training_Seed = -1 #@param{type:\"integer\"}\n",
|
||||
"#@markdown * warm up steps are useful for validation and cosine lrs\n",
|
||||
"Lr_warmup_steps = 20 #@param{type:\"integer\"}\n",
|
||||
"#@markdown * use this option to configure a sample_prompts.json\n",
|
||||
"#@markdown * check out /content/EveryDream2trainer/doc/logging.md. for more details\n",
|
||||
"Advance_Samples = False #@param{type:\"boolean\"}\n",
|
||||
|
@ -362,11 +387,28 @@
|
|||
" !wandb login $wandb_token\n",
|
||||
" wandb_settings = \"--wandb\"\n",
|
||||
"\n",
|
||||
"if \"zip\" in Dataset_Location:\n",
|
||||
" !rm -r /Training_Data/\n",
|
||||
" !mkdir Training_Data\n",
|
||||
" !unzip $Dataset_Location -d /Training_Data\n",
|
||||
" Dataset_Location = \"/Training_Data\"\n",
|
||||
"#@markdown use validation with wandb\n",
|
||||
"\n",
|
||||
"validatation = False #@param{type:\"boolean\"}\n",
|
||||
"validate = \"\"\n",
|
||||
"if validatation:\n",
|
||||
" validate = \"--validation_config validation_default.json\"\n",
|
||||
"\n",
|
||||
"extensions = ['.zip', '.7z', '.rar', '.tgz']\n",
|
||||
"uncompressed_dir = 'Training_Data'\n",
|
||||
"\n",
|
||||
"if any(ext in Dataset_Location for ext in extensions):\n",
|
||||
" # Create the uncompressed directory if it doesn't exist\n",
|
||||
" if not os.path.exists(uncompressed_dir):\n",
|
||||
" os.makedirs(uncompressed_dir)\n",
|
||||
" \n",
|
||||
" # Extract the compressed file to the uncompressed directory\n",
|
||||
" shutil.unpack_archive(Dataset_Location, uncompressed_dir)\n",
|
||||
"\n",
|
||||
" # Set the dataset location to the uncompressed directory\n",
|
||||
" Dataset_Location = uncompressed_dir\n",
|
||||
"\n",
|
||||
"# Use the dataset location in the rest of your code\n",
|
||||
"dataset = Dataset_Location\n",
|
||||
"\n",
|
||||
"Drive=\"\"\n",
|
||||
|
@ -382,8 +424,6 @@
|
|||
"Gradient = \"\"\n",
|
||||
"if Gradient_checkpointing:\n",
|
||||
" Gradient = \"--gradient_checkpointing \"\n",
|
||||
"if \"A100\" in s:\n",
|
||||
" Gradient = \"\"\n",
|
||||
"\n",
|
||||
"DX = \"\" \n",
|
||||
"if Disable_Xformers:\n",
|
||||
|
@ -393,35 +433,72 @@
|
|||
"if shuffle_tags:\n",
|
||||
" shuffle = \"--shuffle_tags \"\n",
|
||||
"\n",
|
||||
"def parse_progress(log_line):\n",
|
||||
" match = re.search(r'\\((\\d+)%\\)', log_line)\n",
|
||||
" if match:\n",
|
||||
" return int(match.group(1))\n",
|
||||
" return None\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"textencode = \"\"\n",
|
||||
"if Disable_text_Encoder:\n",
|
||||
" textencode = \"--disable_textenc_training\"\n",
|
||||
"\n",
|
||||
"!python train.py --resume_ckpt \"$model\" \\\n",
|
||||
" $textencode \\\n",
|
||||
" $Gradient \\\n",
|
||||
" $shuffle \\\n",
|
||||
" $Drive \\\n",
|
||||
" $DX \\\n",
|
||||
" $wandb_settings \\\n",
|
||||
" --amp \\\n",
|
||||
" --clip_skip $Clip_skip \\\n",
|
||||
" --batch_size $Batch_Size \\\n",
|
||||
" --grad_accum $Gradient_steps \\\n",
|
||||
" --cond_dropout $Conditional_DropOut \\\n",
|
||||
" --data_root \"$dataset\" \\\n",
|
||||
" --flip_p $Picture_flip \\\n",
|
||||
" --lr $Learning_Rate \\\n",
|
||||
" --lr_scheduler \"$Schedule\" \\\n",
|
||||
" --max_epochs $Max_Epochs \\\n",
|
||||
" --project_name \"$Project_Name\" \\\n",
|
||||
" --resolution $Resolution \\\n",
|
||||
" --sample_prompts \"$Sample_File\" \\\n",
|
||||
" --sample_steps $Steps_between_samples \\\n",
|
||||
" --save_every_n_epoch $Save_every_N_epoch \\\n",
|
||||
" --seed $Training_Seed \\\n",
|
||||
" --zero_frequency_noise_ratio $zero_frequency_noise\n",
|
||||
"\n"
|
||||
"def update_progress_bar(progress: float):\n",
|
||||
" print(\"Training progress: {:.2f}%\".format(progress))\n",
|
||||
" print(\"[{0}{1}]\".format('#' * int(progress // 2), ' ' * (50 - int(progress // 2))))\n",
|
||||
"\n",
|
||||
"# Start the training process and capture the output\n",
|
||||
"command = f\"\"\"python train.py --resume_ckpt \"{model}\" \\\n",
|
||||
" {textencode} \\\n",
|
||||
" {Gradient} \\\n",
|
||||
" {shuffle} \\\n",
|
||||
" {Drive} \\\n",
|
||||
" {DX} \\\n",
|
||||
" {validate} \\\n",
|
||||
" {wandb_settings} \\\n",
|
||||
" --clip_skip {Clip_skip} \\\n",
|
||||
" --batch_size {Batch_Size} \\\n",
|
||||
" --grad_accum {Gradient_steps} \\\n",
|
||||
" --cond_dropout {Conditional_DropOut} \\\n",
|
||||
" --data_root \"{dataset}\" \\\n",
|
||||
" --flip_p {Picture_flip} \\\n",
|
||||
" --lr {Learning_Rate} \\\n",
|
||||
" --log_step 25 \\\n",
|
||||
" --lr_warmup_steps {Lr_warmup_steps} \\\n",
|
||||
" --lr_scheduler \"{Schedule}\" \\\n",
|
||||
" --max_epochs {Max_Epochs} \\\n",
|
||||
" --project_name \"{Project_Name}\" \\\n",
|
||||
" --resolution {Resolution} \\\n",
|
||||
" --sample_prompts \"{Sample_File}\" \\\n",
|
||||
" --sample_steps {Steps_between_samples} \\\n",
|
||||
" --save_every_n_epoch {Save_every_N_epoch} \\\n",
|
||||
" --seed {Training_Seed} \\\n",
|
||||
" --zero_frequency_noise_ratio {zero_frequency_noise}\"\"\"\n",
|
||||
"\n",
|
||||
"process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)\n",
|
||||
"\n",
|
||||
"# Initialize the progress bar\n",
|
||||
"progress_bar = tqdm(total=100, desc=\"Training progress\", ncols=100)\n",
|
||||
"\n",
|
||||
"for log_line in process.stdout:\n",
|
||||
" global last_output\n",
|
||||
" last_output = None\n",
|
||||
" log_line = log_line.strip()\n",
|
||||
" if log_line:\n",
|
||||
" if log_line != last_output:\n",
|
||||
" progress = parse_progress(log_line)\n",
|
||||
" if progress is not None:\n",
|
||||
" update_progress_bar(progress)\n",
|
||||
" else:\n",
|
||||
" print(log_line)\n",
|
||||
" last_output = log_line\n",
|
||||
"\n",
|
||||
"# Finish the training process\n",
|
||||
"process.wait()\n",
|
||||
"if Disconnect_after_training:\n",
|
||||
" time.sleep(40)\n",
|
||||
" runtime.unassign()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -456,7 +533,6 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "8HmIWtODuE6p"
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -490,8 +566,8 @@
|
|||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"include_colab_link": true
|
||||
"include_colab_link": true,
|
||||
"provenance": []
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
|
|
|
@ -32,6 +32,28 @@
|
|||
"Come visit us at [EveryDream Discord](https://discord.gg/uheqxU6sXN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ffff47f7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Optional Speed Test\n",
|
||||
"If all goes well you may find yourself downloading (or pushing to the cloud) 2-8GB of model data per saved checkpoint. Make sure your pod is not a dud. ~1000Mbit/s up/dn is probably good, though the location of the pod also makes a difference.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "934ba107",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import speedtest\n",
|
||||
"st = speedtest.Speedtest()\n",
|
||||
"print(f\"Your download speed: {round(st.download() / 1000 / 1000, 1)} Mbit/s\")\n",
|
||||
"print(f\"Your upload speed: {round(st.upload() / 1000 / 1000, 1)} Mbit/s\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7c73894e-3b5e-4268-9f83-ed89bd4569f2",
|
||||
|
@ -51,10 +73,72 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3d9b0db8-c2b1-4f0a-b835-b6b2ef527019",
|
||||
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### HuggingFace Login\n",
|
||||
"# Start Training\n",
|
||||
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
|
||||
"\n",
|
||||
"You may wish to consider adding \"sd1\" or \"sd2v\" or similar to remember what the base was, as you'll also have to tell your inference app what you were using, as its difficult for programs to know what inference YAML to use automatically. For instance, Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file, tough it assumes SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
|
||||
"\n",
|
||||
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
|
||||
"\n",
|
||||
"Next cell runs training. This will take a while depending on your number of images, repeats, and max_epochs.\n",
|
||||
"\n",
|
||||
"You can watch for test images in the logs folder.\n",
|
||||
"\n",
|
||||
"#### Weights and Balanaces\n",
|
||||
"I you pass the `--wandb` flag you will be prompted for your W&B `API Key`. W&B is a free online logging utility. If you don't have a W&B account, you can create one for free at https://wandb.ai/site. Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
|
||||
"metadata": {
|
||||
"scrolled": true,
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config train.json \\\n",
|
||||
"--resume_ckpt \"panopstor/EveryDream\" \\\n",
|
||||
"--project_name \"sd1_mymodel\" \\\n",
|
||||
"--data_root \"input\" \\\n",
|
||||
"--max_epochs 200 \\\n",
|
||||
"--sample_steps 150 \\\n",
|
||||
"--save_every_n_epochs 35 \\\n",
|
||||
"--lr 1.2e-6 \\\n",
|
||||
"--lr_scheduler constant \\\n",
|
||||
"--save_full_precision\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"panopstor/EveryDream\"\n",
|
||||
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
|
||||
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3c506e79-bf03-4e34-bf06-9371963d4d7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# HuggingFace download (Optional)\n",
|
||||
"Run the cell below and paste your token into the prompt. You can get your token from your [huggingface account page](https://huggingface.co/settings/tokens).\n",
|
||||
"\n",
|
||||
"The token will not show on the screen, just press enter after you paste it."
|
||||
|
@ -74,10 +158,10 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7a96f2af-8c93-4460-aa9e-2ff795fb06ea",
|
||||
"id": "b252a308-49cf-443f-abbb-d08b471411fb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Then run the following cell to download the base checkpoint (may take a minute)."
|
||||
"Then run the following cell to download the base checkpoint (may take a minute)."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -111,68 +195,6 @@
|
|||
"print(\"DONE\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Start Training\n",
|
||||
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
|
||||
"\n",
|
||||
"You may wish to consider adding \"sd1\" or \"sd2v\" or similar to remember what the base was, as you'll also have to tell your inference app what you were using, as its difficult for programs to know what inference YAML to use automatically. For instance, Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file, tough it assumes SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
|
||||
"\n",
|
||||
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
|
||||
"\n",
|
||||
"Next cell runs training. This will take a while depending on your number of images, repeats, and max_epochs.\n",
|
||||
"\n",
|
||||
"You can watch for test images in the logs folder.\n",
|
||||
"\n",
|
||||
"## Weights and Balanaces\n",
|
||||
"I you pass the `--wandb` flag you will be prompted for your W&B `API Key`. W&B is a free online logging utility. If you don't have a W&B account, you can create one for free at https://wandb.ai/site. Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
|
||||
"metadata": {
|
||||
"scrolled": true,
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config train.json \\\n",
|
||||
"--resume_ckpt \"sd_v1-5_vae\" \\\n",
|
||||
"--project_name \"sd1_mymodel\" \\\n",
|
||||
"--data_root \"input\" \\\n",
|
||||
"--max_epochs 200 \\\n",
|
||||
"--sample_steps 150 \\\n",
|
||||
"--save_every_n_epochs 35 \\\n",
|
||||
"--lr 1.2e-6 \\\n",
|
||||
"--lr_scheduler constant \\\n",
|
||||
"--save_full_precision\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"{ckpt_name}\"\n",
|
||||
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
|
||||
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f24eee3d-f5df-45f3-9acc-ee0206cfe6b1",
|
||||
|
@ -351,7 +373,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -365,7 +387,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.10"
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
Copyright [2022-2023] Victor C Hall
|
||||
|
||||
Licensed under the GNU Affero General Public License;
|
||||
You may not use this code except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
import argparse
|
||||
import requests
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
|
||||
|
||||
import torch
|
||||
from pynvml import *
|
||||
|
||||
import time
|
||||
from colorama import Fore, Style
|
||||
|
||||
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
|
||||
|
||||
def get_gpu_memory_map():
|
||||
"""Get the current gpu usage.
|
||||
Returns
|
||||
-------
|
||||
usage: dict
|
||||
Keys are device ids as integers.
|
||||
Values are memory usage as integers in MB.
|
||||
"""
|
||||
nvmlInit()
|
||||
handle = nvmlDeviceGetHandleByIndex(0)
|
||||
info = nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used/1024/1024
|
||||
|
||||
def create_blip2_processor(model_name, device, dtype=torch.float16):
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
args.model, torch_dtype=dtype
|
||||
)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
print(f"BLIP2 Model loaded: {model_name}")
|
||||
return processor, model
|
||||
|
||||
def create_git_processor(model_name, device, dtype=torch.float16):
|
||||
processor = GitProcessor.from_pretrained(model_name)
|
||||
model = GitForCausalLM.from_pretrained(
|
||||
args.model, torch_dtype=dtype
|
||||
)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
print(f"GIT Model loaded: {model_name}")
|
||||
return processor, model
|
||||
|
||||
def create_auto_processor(model_name, device, dtype=torch.float16):
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
model = AutoModel.from_pretrained(
|
||||
args.model, torch_dtype=dtype
|
||||
)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
print("Auto Model loaded")
|
||||
return processor, model
|
||||
|
||||
def main(args):
|
||||
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
|
||||
dtype = torch.float32 if args.force_cpu else torch.float16
|
||||
|
||||
if "salesforce/blip2-" in args.model.lower():
|
||||
print(f"Using BLIP2 model: {args.model}")
|
||||
processor, model = create_blip2_processor(args.model, device, dtype)
|
||||
elif "microsoft/git-" in args.model.lower():
|
||||
print(f"Using GIT model: {args.model}")
|
||||
processor, model = create_git_processor(args.model, device, dtype)
|
||||
else:
|
||||
# try to use auto model? doesn't work with blip/git
|
||||
processor, model = create_auto_processor(args.model, device, dtype)
|
||||
|
||||
print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")
|
||||
|
||||
# os.walk all files in args.data_root recursively
|
||||
for root, dirs, files in os.walk(args.data_root):
|
||||
for file in files:
|
||||
#get file extension
|
||||
ext = os.path.splitext(file)[1]
|
||||
if ext.lower() in SUPPORTED_EXT:
|
||||
full_file_path = os.path.join(root, file)
|
||||
image = Image.open(full_file_path)
|
||||
start_time = time.time()
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt", max_new_tokens=args.max_new_tokens).to(device, dtype)
|
||||
|
||||
generated_ids = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
print(f"file: {file}, caption: {generated_text}")
|
||||
exec_time = time.time() - start_time
|
||||
print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB")
|
||||
|
||||
# get bare name
|
||||
name = os.path.splitext(full_file_path)[0]
|
||||
#name = os.path.join(root, name)
|
||||
if not os.path.exists(name):
|
||||
with open(f"{name}.txt", "w") as f:
|
||||
f.write(generated_text)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"{Fore.CYAN}** Current supported models:{Style.RESET_ALL}")
|
||||
print(" microsoft/git-base-textcaps")
|
||||
print(" microsoft/git-large-textcaps")
|
||||
print(" microsoft/git-large-r-textcaps")
|
||||
print(" Salesforce/blip2-opt-2.7b - (9GB VRAM or recommend 32GB sys RAM)")
|
||||
print(" Salesforce/blip2-opt-2.7b-coco - (9GB VRAM or recommend 32GB sys RAM)")
|
||||
print(" Salesforce/blip2-opt-6.7b - (16.5GB VRAM or recommend 64GB sys RAM)")
|
||||
print(" Salesforce/blip2-opt-6.7b-coco - (16.5GB VRAM or recommend 64GB sys RAM)")
|
||||
print()
|
||||
print(f"{Fore.CYAN} * The following will likely not work on any consumer GPUs or require huge sys RAM on CPU:{Style.RESET_ALL}")
|
||||
print(" salesforce/blip2-flan-t5-xl")
|
||||
print(" salesforce/blip2-flan-t5-xl-coco")
|
||||
print(" salesforce/blip2-flan-t5-xxl")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data_root", type=str, default="input", help="Path to images")
|
||||
parser.add_argument("--model", type=str, default="salesforce/blip2-opt-2.7b", help="model from huggingface, ex. 'salesforce/blip2-opt-2.7b'")
|
||||
parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available, may be useful to run huge models if you have a lot of system memory")
|
||||
parser.add_argument("--max_new_tokens", type=int, default=24, help="max length for generated captions")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"** Using model: {args.model}")
|
||||
print(f"** Captioning files in: {args.data_root}")
|
||||
main(args)
|
|
@ -13,6 +13,49 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
"""
|
||||
Notes:
|
||||
this is generated from an excel sheet and actual ratios are hand picked to
|
||||
spread out the ratios evenly to avoid having super-finely defined buckets
|
||||
Too many buckets means more "runt" steps with repeated images to fill batch than necessary
|
||||
ex. we do not need both 1.0:1 and 1.125:1, they're almost identical ratios
|
||||
Try to keep around <20 ratio buckets per resolution, should be plenty coverage everything between 1:1 and 4:1
|
||||
More finely defined buckets will reduce cropping at the expense of more runt steps
|
||||
"""
|
||||
|
||||
ASPECTS_1536 = [[1536,1536], # 2359296 1:1
|
||||
[1728,1344],[1344,1728], # 2322432 1.286:1
|
||||
[1792,1280],[1280,1792], # 2293760 1.4:1
|
||||
[2048,1152],[1152,2048], # 2359296 1.778:1
|
||||
[2304,1024],[1024,2304], # 2359296 2.25:1
|
||||
[2432,960],[960,2432], # 2334720 2.53:1
|
||||
[2624,896],[896,2624], # 2351104 2.929:1
|
||||
[2816,832],[832,2816], # 2342912 3.385:1
|
||||
[3072,768],[768,3072], # 2359296 4:1
|
||||
]
|
||||
|
||||
ASPECTS_1408 = [[1408,1408], # 1982464 1:1
|
||||
[1536,1280],[1280,1536], # 1966080 1.2:1
|
||||
[1664,1152],[1152,1664], # 1916928 1.444:1
|
||||
[1920,1024],[1024,1920], # 1966080 1.875:1
|
||||
[2048,960],[960,2048], # 1966080 2.133:1
|
||||
[2368,832],[832,2368], # 1970176 2.846:1
|
||||
[2560,768],[768,2560], # 1966080 3.333:1
|
||||
[2816,704],[704,3072], # 1982464 4:1
|
||||
]
|
||||
|
||||
ASPECTS_1280 = [[1280,1280], # 1638400 1:1
|
||||
[1408,1152],[1408,1344], # 1622016 1.222:1
|
||||
[1600,1024],[1024,1600], # 1638400 1.563:1
|
||||
[1792,896],[896,1792], # 1605632 2:1
|
||||
[1920,832],[832,1920], # 1597440 2.308:1
|
||||
[2112,768],[768,2112], # 1585152 2.75:1
|
||||
[2304,704],[704,2304], # 1622016 3.27:1
|
||||
[2560,640],[640,2560], # 1638400 4:1
|
||||
]
|
||||
|
||||
ASPECTS_1152 = [[1152,1152], # 1327104 1:1
|
||||
#[1216,1088],[1088,1216], # 1323008 1.118:1
|
||||
[1280,1024],[1024,1280], # 1310720 1.25:1
|
||||
|
@ -48,7 +91,7 @@ ASPECTS_1024 = [[1024,1024], # 1048576 1:1
|
|||
]
|
||||
|
||||
ASPECTS_960 = [[960,960], # 921600 1:1
|
||||
[1024,896],[896,1024], # 917504 1.143:1
|
||||
#[1024,896],[896,1024], # 917504 1.143:1
|
||||
[1088,832],[832,1088], # 905216 1.308:1
|
||||
[1152,768],[768,1152], # 884736 1.5:1
|
||||
[1280,704],[704,1280], # 901120 1.818:1
|
||||
|
@ -56,11 +99,11 @@ ASPECTS_960 = [[960,960], # 921600 1:1
|
|||
[1680,576],[576,1680], # 921600 2.778:1
|
||||
#[1728,512],[512,1728], # 884736 3.375:1
|
||||
[1792,512],[512,1792], # 917504 3.5:1
|
||||
[2048,448],[448,2048], # 917504 4.714:1
|
||||
[2048,448],[448,2048], # 917504 4.57:1
|
||||
]
|
||||
|
||||
ASPECTS_896 = [[896,896], # 802816 1:1
|
||||
[960,832],[832,960], # 798720 1.153:1
|
||||
#[960,832],[832,960], # 798720 1.153:1
|
||||
[1024,768],[768,1024], # 786432 1.333:1
|
||||
[1088,704],[704,1088], # 765952 1.545:1
|
||||
[1216,640],[640,1216], # 778240 1.9:1
|
||||
|
@ -155,7 +198,7 @@ ASPECTS_384 = [[384,384], # 147456 1:1
|
|||
ASPECTS_256 = [[256,256], # 65536 1:1
|
||||
[384,192],[192,384], # 73728 2:1
|
||||
[512,128],[128,512], # 65536 4:1
|
||||
]
|
||||
] # very few buckets available for 256 with 64 pixel increments
|
||||
|
||||
def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
|
||||
if resolution < 256:
|
||||
|
@ -174,6 +217,10 @@ def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
|
|||
print(f" *** Value must be between 512 and 1024")
|
||||
raise e
|
||||
|
||||
def get_supported_resolutions():
|
||||
all_image_sizes = __get_all_aspects()
|
||||
return list(map(lambda sizes: sizes[0][0], all_image_sizes))
|
||||
|
||||
def __get_all_aspects():
|
||||
return [ASPECTS_256,
|
||||
ASPECTS_384,
|
||||
|
@ -188,5 +235,43 @@ def __get_all_aspects():
|
|||
ASPECTS_960,
|
||||
ASPECTS_1024,
|
||||
ASPECTS_1088,
|
||||
ASPECTS_1152
|
||||
ASPECTS_1152,
|
||||
ASPECTS_1280,
|
||||
ASPECTS_1536,
|
||||
]
|
||||
|
||||
|
||||
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
|
||||
def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
|
||||
if x <= 1:
|
||||
return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
|
||||
else:
|
||||
b,a = farey_aspect_ratio_pair_lt1(1/x, max_denominator_value)
|
||||
return a,b
|
||||
|
||||
# adapted from https://www.johndcook.com/blog/2010/10/20/best-rational-approximation/
|
||||
def farey_aspect_ratio_pair_lt1(x: float, max_denominator_value: int):
|
||||
if x > 1:
|
||||
raise ValueError("x must be <1")
|
||||
a, b = 0, 1
|
||||
c, d = 1, 1
|
||||
while (b <= max_denominator_value and d <= max_denominator_value):
|
||||
mediant = float(a+c)/(b+d)
|
||||
if x == mediant:
|
||||
if b + d <= max_denominator_value:
|
||||
return a+c, b+d
|
||||
elif d > b:
|
||||
return c, d
|
||||
else:
|
||||
return a, b
|
||||
elif x > mediant:
|
||||
a, b = a+c, b+d
|
||||
else:
|
||||
c, d = a+c, b+d
|
||||
|
||||
if (b > max_denominator_value):
|
||||
return c, d
|
||||
else:
|
||||
return a, b
|
||||
|
||||
return farey_aspect_ratio_pair(bucket_wh[0]/bucket_wh[1], 32)
|
||||
|
|
|
@ -0,0 +1,256 @@
|
|||
import os
|
||||
import logging
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from functools import total_ordering
|
||||
from attrs import define, field, Factory
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
from utils.fs_helpers import *
|
||||
from typing import Iterable
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||
|
||||
def overlay(overlay, base):
|
||||
return overlay if overlay is not None else base
|
||||
|
||||
def safe_set(val):
|
||||
if isinstance(val, str):
|
||||
return dict.fromkeys([val]) if val else dict()
|
||||
|
||||
if isinstance(val, Iterable):
|
||||
return dict.fromkeys((i for i in val if i is not None))
|
||||
|
||||
return val or dict()
|
||||
|
||||
@define(frozen=True)
|
||||
class Tag:
|
||||
value: str
|
||||
weight: float = field(default=1.0, converter=lambda x: x if x is not None else 1.0)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
if isinstance(data, str):
|
||||
return Tag(data)
|
||||
|
||||
if isinstance(data, dict):
|
||||
value = str(data.get("tag"))
|
||||
weight = data.get("weight")
|
||||
if value:
|
||||
return Tag(value, weight)
|
||||
|
||||
return None
|
||||
|
||||
@define
|
||||
class ImageConfig:
|
||||
# Captions
|
||||
main_prompts: dict[str, None] = field(factory=dict, converter=safe_set)
|
||||
rating: float = None
|
||||
max_caption_length: int = None
|
||||
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
|
||||
|
||||
# Options
|
||||
multiply: float = None
|
||||
cond_dropout: float = None
|
||||
flip_p: float = None
|
||||
|
||||
def merge(self, other):
|
||||
if other is None:
|
||||
return self
|
||||
|
||||
return ImageConfig(
|
||||
main_prompts=other.main_prompts | self.main_prompts,
|
||||
rating=overlay(other.rating, self.rating),
|
||||
max_caption_length=overlay(other.max_caption_length, self.max_caption_length),
|
||||
tags= other.tags | self.tags,
|
||||
multiply=overlay(other.multiply, self.multiply),
|
||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
# Parse standard yaml tag file (with options)
|
||||
parsed_cfg = ImageConfig(
|
||||
main_prompts=safe_set(data.get("main_prompt")),
|
||||
rating=data.get("rating"),
|
||||
max_caption_length=data.get("max_caption_length"),
|
||||
tags=safe_set(map(Tag.parse, data.get("tags", []))),
|
||||
multiply=data.get("multiply"),
|
||||
cond_dropout=data.get("cond_dropout"),
|
||||
flip_p=data.get("flip_p"),
|
||||
)
|
||||
|
||||
# Alternatively parse from dedicated `caption` attribute
|
||||
if cap_attr := data.get('caption'):
|
||||
parsed_cfg = parsed_cfg.merge(ImageConfig.parse(cap_attr))
|
||||
|
||||
return parsed_cfg
|
||||
|
||||
@classmethod
|
||||
def fold(cls, configs):
|
||||
acc = ImageConfig()
|
||||
for cfg in configs:
|
||||
acc = acc.merge(cfg)
|
||||
return acc
|
||||
|
||||
def ensure_caption(self):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_caption_text(cls, text: str):
|
||||
if not text:
|
||||
return ImageConfig()
|
||||
if os.path.isfile(text):
|
||||
return ImageConfig.from_file(text)
|
||||
|
||||
split_caption = list(map(str.strip, text.split(",")))
|
||||
return ImageConfig(
|
||||
main_prompts=split_caption[0],
|
||||
tags=map(Tag.parse, split_caption[1:])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: str):
|
||||
match ext(file):
|
||||
case '.jpg' | '.jpeg' | '.png' | '.bmp' | '.webp' | '.jfif':
|
||||
return ImageConfig(image=file)
|
||||
case ".json":
|
||||
return ImageConfig.from_dict(json.load(read_text(file)))
|
||||
case ".yaml" | ".yml":
|
||||
return ImageConfig.from_dict(yaml.safe_load(read_text(file)))
|
||||
case ".txt" | ".caption":
|
||||
return ImageConfig.from_caption_text(read_text(file))
|
||||
case _:
|
||||
return logging.warning(" *** Unrecognized config extension {ext}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, input):
|
||||
if isinstance(input, str):
|
||||
if os.path.isfile(input):
|
||||
return ImageConfig.from_file(input)
|
||||
else:
|
||||
return ImageConfig.from_caption_text(input)
|
||||
elif isinstance(input, dict):
|
||||
return ImageConfig.from_dict(input)
|
||||
|
||||
|
||||
@define()
|
||||
class Dataset:
|
||||
image_configs: dict[str, ImageConfig]
|
||||
|
||||
def __global_cfg(fileset):
|
||||
cfgs = []
|
||||
|
||||
for cfgfile in ['global.yaml', 'global.yml']:
|
||||
if cfgfile in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset[cfgfile]))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
def __local_cfg(fileset):
|
||||
cfgs = []
|
||||
if 'multiply.txt' in fileset:
|
||||
cfgs.append(ImageConfig(multiply=read_float(fileset['multiply.txt'])))
|
||||
if 'cond_dropout.txt' in fileset:
|
||||
cfgs.append(ImageConfig(cond_dropout=read_float(fileset['cond_dropout.txt'])))
|
||||
if 'flip_p.txt' in fileset:
|
||||
cfgs.append(ImageConfig(flip_p=read_float(fileset['flip_p.txt'])))
|
||||
if 'local.yaml' in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
|
||||
if 'local.yml' in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
def __sidecar_cfg(imagepath, fileset):
|
||||
cfgs = []
|
||||
for cfgext in ['.txt', '.caption', '.yml', '.yaml']:
|
||||
cfgfile = barename(imagepath) + cfgext
|
||||
if cfgfile in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset[cfgfile]))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
# Use file name for caption only as a last resort
|
||||
@classmethod
|
||||
def __ensure_caption(cls, cfg: ImageConfig, file: str):
|
||||
if cfg.main_prompts:
|
||||
return cfg
|
||||
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
||||
return cfg.merge(cap_cfg)
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, data_root):
|
||||
# Create a visitor that maintains global config stack
|
||||
# and accumulates image configs as it traverses dataset
|
||||
image_configs = {}
|
||||
def process_dir(files, parent_globals):
|
||||
fileset = {os.path.basename(f): f for f in files}
|
||||
global_cfg = parent_globals.merge(Dataset.__global_cfg(fileset))
|
||||
local_cfg = Dataset.__local_cfg(fileset)
|
||||
for img in filter(is_image, files):
|
||||
img_cfg = Dataset.__sidecar_cfg(img, fileset)
|
||||
resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
|
||||
image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
|
||||
return global_cfg
|
||||
|
||||
walk_and_visit(data_root, process_dir, ImageConfig())
|
||||
return Dataset(image_configs)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_path):
|
||||
"""
|
||||
Import a dataset definition from a JSON file
|
||||
"""
|
||||
image_configs = {}
|
||||
with open(json_path, encoding='utf-8', mode='r') as stream:
|
||||
for data in json.load(stream):
|
||||
img = data.get("image")
|
||||
cfg = Dataset.__ensure_caption(ImageConfig.parse(data), img)
|
||||
if not img:
|
||||
logging.warning(f" *** Error parsing json image entry in {json_path}: {data}")
|
||||
continue
|
||||
image_configs[img] = cfg
|
||||
return Dataset(image_configs)
|
||||
|
||||
def image_train_items(self, aspects):
|
||||
items = []
|
||||
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
||||
config = self.image_configs[image]
|
||||
|
||||
if len(config.main_prompts) > 1:
|
||||
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
||||
|
||||
if len(config.main_prompts) < 1:
|
||||
logging.warning(f" *** No main_prompts for image {image}")
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
|
||||
tags.append(tag.value)
|
||||
tag_weights.append(tag.weight)
|
||||
use_weights = len(set(tag_weights)) > 1
|
||||
|
||||
try:
|
||||
caption = ImageCaption(
|
||||
main_prompt=next(iter(config.main_prompts)),
|
||||
rating=config.rating or 1.0,
|
||||
tags=tags,
|
||||
tag_weights=tag_weights,
|
||||
max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
|
||||
use_weights=use_weights)
|
||||
|
||||
item = ImageTrainItem(
|
||||
image=None,
|
||||
caption=caption,
|
||||
aspects=aspects,
|
||||
pathname=os.path.abspath(image),
|
||||
flip_p=config.flip_p or 0.0,
|
||||
multiplier=config.multiply or 1.0,
|
||||
cond_dropout=config.cond_dropout
|
||||
)
|
||||
items.append(item)
|
||||
except Exception as e:
|
||||
logging.error(f" *** Error preloading image or caption for: {image}, error: {e}")
|
||||
raise e
|
||||
return items
|
|
@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
|
@ -104,7 +106,7 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
example["image"] = image_transforms(train_item["image"])
|
||||
|
||||
if random.random() > self.conditional_dropout:
|
||||
if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
|
||||
example["tokens"] = self.tokenizer(example["caption"],
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
|
@ -132,6 +134,8 @@ class EveryDreamBatch(Dataset):
|
|||
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
|
||||
image_train_tmp.image = None # hack for now to avoid memory leak
|
||||
example["caption"] = image_train_tmp.caption
|
||||
if image_train_tmp.cond_dropout is not None:
|
||||
example["cond_dropout"] = image_train_tmp.cond_dropout
|
||||
example["runt_size"] = image_train_tmp.runt_size
|
||||
|
||||
return example
|
||||
|
@ -142,9 +146,9 @@ class EveryDreamBatch(Dataset):
|
|||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
batch_size= batch_size,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
num_workers=min(batch_size, os.cpu_count()),
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return dataloader
|
||||
|
|
|
@ -2,6 +2,7 @@ import json
|
|||
import logging
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Any, Optional, Generator
|
||||
from argparse import Namespace
|
||||
|
||||
|
@ -20,6 +21,8 @@ from data import aspects
|
|||
from data.image_train_item import ImageTrainItem
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
|
||||
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
|
||||
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
|
||||
|
@ -35,17 +38,35 @@ def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageT
|
|||
for i in items:
|
||||
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationDataset:
|
||||
name: str
|
||||
dataloader: torch.utils.data.DataLoader
|
||||
loss_history: list[float] = field(default_factory=list)
|
||||
val_loss_window_size: Optional[int] = 5 # todo: arg for this?
|
||||
|
||||
def track_loss_trend(self, mean_loss: float):
|
||||
if self.val_loss_window_size is None:
|
||||
return
|
||||
self.loss_history.append(mean_loss)
|
||||
|
||||
if len(self.loss_history) > ((self.val_loss_window_size * 2) + 1):
|
||||
dy = np.diff(self.loss_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss for {self.name} shows diverging. Check your loss/{self.name} graph.")
|
||||
|
||||
|
||||
class EveryDreamValidator:
|
||||
def __init__(self,
|
||||
val_config_path: Optional[str],
|
||||
default_batch_size: int,
|
||||
resolution: int,
|
||||
log_writer: SummaryWriter):
|
||||
self.val_dataloader = None
|
||||
self.train_overlapping_dataloader = None
|
||||
|
||||
self.log_writer = log_writer
|
||||
log_writer: SummaryWriter,
|
||||
):
|
||||
self.validation_datasets = []
|
||||
self.resolution = resolution
|
||||
self.log_writer = log_writer
|
||||
|
||||
self.config = {
|
||||
'batch_size': default_batch_size,
|
||||
|
@ -54,20 +75,38 @@ class EveryDreamValidator:
|
|||
|
||||
'validate_training': True,
|
||||
'val_split_mode': 'automatic',
|
||||
'val_split_proportion': 0.15,
|
||||
'auto_split_proportion': 0.15,
|
||||
|
||||
'stabilize_training_loss': False,
|
||||
'stabilize_split_proportion': 0.15
|
||||
'stabilize_split_proportion': 0.15,
|
||||
|
||||
'use_relative_loss': False,
|
||||
|
||||
'extra_manual_datasets': {
|
||||
# name: path pairs
|
||||
# eg "santa suit": "/path/to/captioned_santa_suit_images", will be logged to tensorboard as "loss/santa suit"
|
||||
}
|
||||
}
|
||||
if val_config_path is not None:
|
||||
with open(val_config_path, 'rt') as f:
|
||||
self.config.update(json.load(f))
|
||||
|
||||
self.train_overlapping_dataloader_loss_offset = None
|
||||
self.val_loss_offset = None
|
||||
if 'val_data_root' in self.config:
|
||||
logging.warning(f" * {Fore.YELLOW}using old name 'val_data_root' for 'manual_data_root' - please "
|
||||
f"update your validation config json{Style.RESET_ALL}")
|
||||
self.config.update({'manual_data_root': self.config['val_data_root']})
|
||||
|
||||
if self.config.get('val_split_mode') == 'manual':
|
||||
if 'manual_data_root' not in self.config:
|
||||
raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
|
||||
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
|
||||
|
||||
if 'val_split_proportion' in self.config:
|
||||
logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
|
||||
f"update your validation config json{Style.RESET_ALL}")
|
||||
self.config.update({'auto_split_proportion': self.config['val_split_proportion']})
|
||||
|
||||
|
||||
self.loss_val_history = []
|
||||
self.val_loss_window_size = 4 # todo: arg for this?
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
|
@ -81,6 +120,10 @@ class EveryDreamValidator:
|
|||
def seed(self):
|
||||
return self.config['seed']
|
||||
|
||||
@property
|
||||
def use_relative_loss(self):
|
||||
return self.config['use_relative_loss']
|
||||
|
||||
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Build the validation splits as requested by the config passed at init.
|
||||
|
@ -90,11 +133,20 @@ class EveryDreamValidator:
|
|||
"""
|
||||
with isolate_rng():
|
||||
random.seed(self.seed)
|
||||
self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
|
||||
|
||||
auto_dataset, remaining_train_items = self._build_automatic_validation_dataset_if_required(train_items, tokenizer)
|
||||
# order is important - if we're removing images from train, this needs to happen before making
|
||||
# the overlapping dataloader
|
||||
self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
|
||||
train_overlapping_dataset = self._build_train_stabilizer_dataloader_if_required(
|
||||
remaining_train_items, tokenizer)
|
||||
|
||||
if auto_dataset is not None:
|
||||
self.validation_datasets.append(auto_dataset)
|
||||
if train_overlapping_dataset is not None:
|
||||
self.validation_datasets.append(train_overlapping_dataset)
|
||||
manual_splits = self._build_manual_validation_datasets(tokenizer)
|
||||
self.validation_datasets.extend(manual_splits)
|
||||
|
||||
return remaining_train_items
|
||||
|
||||
def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
||||
|
@ -104,6 +156,7 @@ class EveryDreamValidator:
|
|||
return [epoch_length_steps-1]
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
||||
# validation happens after training:
|
||||
|
@ -114,30 +167,15 @@ class EveryDreamValidator:
|
|||
def do_validation(self, global_step: int,
|
||||
get_model_prediction_and_target_callable: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
if self.train_overlapping_dataloader is not None:
|
||||
mean_loss = self._calculate_validation_loss('stabilize-train',
|
||||
self.train_overlapping_dataloader,
|
||||
for i, dataset in enumerate(self.validation_datasets):
|
||||
mean_loss = self._calculate_validation_loss(dataset.name,
|
||||
dataset.dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.train_overlapping_dataloader_loss_offset is None:
|
||||
self.train_overlapping_dataloader_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/stabilize-train",
|
||||
scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
|
||||
self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
|
||||
scalar_value=mean_loss,
|
||||
global_step=global_step)
|
||||
if self.val_dataloader is not None:
|
||||
mean_loss = self._calculate_validation_loss('val',
|
||||
self.val_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.val_loss_offset is None:
|
||||
self.val_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/val",
|
||||
scalar_value=self.val_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
self.loss_val_history.append(mean_loss)
|
||||
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
|
||||
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss shows diverging")
|
||||
# todo: signal stop?
|
||||
dataset.track_loss_trend(mean_loss)
|
||||
|
||||
|
||||
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
|
||||
|
@ -168,31 +206,35 @@ class EveryDreamValidator:
|
|||
return loss_validation_local
|
||||
|
||||
|
||||
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||
def _build_automatic_validation_dataset_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||
-> tuple[Optional[ValidationDataset], list[ImageTrainItem]]:
|
||||
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
|
||||
val_split_proportion = self.config['val_split_proportion']
|
||||
remaining_train_items = image_train_items
|
||||
if val_split_mode is None or val_split_mode == 'none':
|
||||
if val_split_mode is None or val_split_mode == 'none' or val_split_mode == 'manual':
|
||||
# manual is handled by _build_manual_validation_datasets
|
||||
return None, image_train_items
|
||||
elif val_split_mode == 'automatic':
|
||||
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
|
||||
auto_split_proportion = self.config['auto_split_proportion']
|
||||
val_items, remaining_train_items = get_random_split(image_train_items, auto_split_proportion, batch_size=self.batch_size)
|
||||
val_items = list(disable_multiplier_and_flip(val_items))
|
||||
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
|
||||
elif val_split_mode == 'manual':
|
||||
val_data_root = self.config.get('val_data_root', None)
|
||||
if val_data_root is None:
|
||||
raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
|
||||
val_items = self._load_manual_val_split(val_data_root)
|
||||
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
|
||||
val_ed_batch = self._build_ed_batch(val_items, tokenizer=tokenizer, name='val')
|
||||
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
||||
return ValidationDataset(name='val', dataloader=val_dataloader), remaining_train_items
|
||||
else:
|
||||
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
||||
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
|
||||
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
||||
return val_dataloader, remaining_train_items
|
||||
|
||||
def _build_manual_validation_datasets(self, tokenizer) -> list[ValidationDataset]:
|
||||
datasets = []
|
||||
for name, root in self.config.get('extra_manual_datasets', {}).items():
|
||||
items = self._load_manual_val_split(root)
|
||||
logging.info(f" * Loaded {len(items)} validation images for validation set '{name}' from {root}")
|
||||
ed_batch = self._build_ed_batch(items, tokenizer=tokenizer, name=name)
|
||||
dataloader = build_torch_dataloader(ed_batch, batch_size=self.batch_size)
|
||||
datasets.append(ValidationDataset(name=name, dataloader=dataloader))
|
||||
return datasets
|
||||
|
||||
def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
||||
-> Optional[torch.utils.data.DataLoader]:
|
||||
-> Optional[ValidationDataset]:
|
||||
stabilize_training_loss = self.config['stabilize_training_loss']
|
||||
if not stabilize_training_loss:
|
||||
return None
|
||||
|
@ -200,10 +242,9 @@ class EveryDreamValidator:
|
|||
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
||||
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
||||
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
|
||||
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
|
||||
name='stabilize-train')
|
||||
stabilize_ed_batch = self._build_ed_batch(stabilize_items, tokenizer=tokenizer, name='stabilize-train')
|
||||
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
||||
return stabilize_dataloader
|
||||
return ValidationDataset(name='stabilize-train', dataloader=stabilize_dataloader, val_loss_window_size=None)
|
||||
|
||||
def _load_manual_val_split(self, val_data_root: str):
|
||||
args = Namespace(
|
||||
|
@ -216,7 +257,7 @@ class EveryDreamValidator:
|
|||
random.shuffle(val_items)
|
||||
return val_items
|
||||
|
||||
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
|
||||
def _build_ed_batch(self, items: list[ImageTrainItem], tokenizer, name='val'):
|
||||
batch_size = self.batch_size
|
||||
seed = self.seed
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
|
|
|
@ -23,12 +23,12 @@ import yaml
|
|||
|
||||
import PIL
|
||||
import PIL.Image as Image
|
||||
import PIL.ImageOps as ImageOps
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
_RANDOM_TRIM = 0.04
|
||||
|
||||
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||
|
||||
OptionalImageCaption = typing.Optional['ImageCaption']
|
||||
|
||||
|
@ -36,7 +36,6 @@ class ImageCaption:
|
|||
"""
|
||||
Represents the various parts of an image caption
|
||||
"""
|
||||
|
||||
def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
|
||||
"""
|
||||
:param main_prompt: The part of the caption which should always be included
|
||||
|
@ -49,7 +48,7 @@ class ImageCaption:
|
|||
self.__rating = rating
|
||||
self.__tags = tags
|
||||
self.__tag_weights = tag_weights
|
||||
self.__max_target_length = max_target_length
|
||||
self.__max_target_length = max_target_length or 2048
|
||||
self.__use_weights = use_weights
|
||||
if use_weights and len(tags) > len(tag_weights):
|
||||
self.__tag_weights.extend([1.0] * (len(tags) - len(tag_weights)))
|
||||
|
@ -67,7 +66,13 @@ class ImageCaption:
|
|||
:return: generated caption string
|
||||
"""
|
||||
if self.__tags:
|
||||
max_target_tag_length = self.__max_target_length - len(self.__main_prompt)
|
||||
try:
|
||||
max_target_tag_length = self.__max_target_length - len(self.__main_prompt or 0)
|
||||
except Exception as e:
|
||||
print()
|
||||
logging.error(f"Error determining length for: {e} on {self.__main_prompt}")
|
||||
print()
|
||||
max_target_tag_length = 2048
|
||||
|
||||
if self.__use_weights:
|
||||
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
||||
|
@ -113,137 +118,6 @@ class ImageCaption:
|
|||
random.Random(seed).shuffle(tags)
|
||||
return ", ".join(tags)
|
||||
|
||||
@staticmethod
|
||||
def parse(string: str) -> 'ImageCaption':
|
||||
"""
|
||||
Parses a string to get the caption.
|
||||
|
||||
:param string: String to parse.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
split_caption = list(map(str.strip, string.split(",")))
|
||||
main_prompt = split_caption[0]
|
||||
tags = split_caption[1:]
|
||||
tag_weights = [1.0] * len(tags)
|
||||
|
||||
return ImageCaption(main_prompt, 1.0, tags, tag_weights, DEFAULT_MAX_CAPTION_LENGTH, False)
|
||||
|
||||
@staticmethod
|
||||
def from_file_name(file_path: str) -> 'ImageCaption':
|
||||
"""
|
||||
Parses the file name to get the caption.
|
||||
|
||||
:param file_path: Path to the image file.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
(file_name, _) = os.path.splitext(os.path.basename(file_path))
|
||||
caption = file_name.split("_")[0]
|
||||
return ImageCaption.parse(caption)
|
||||
|
||||
@staticmethod
|
||||
def from_text_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Parses a text file to get the caption. Returns the default caption if
|
||||
the file does not exist or is invalid.
|
||||
|
||||
:param file_path: Path to the text file.
|
||||
:param default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, encoding='utf-8', mode='r') as caption_file:
|
||||
caption_text = caption_file.read()
|
||||
return ImageCaption.parse(caption_text)
|
||||
except:
|
||||
logging.error(f" *** Error reading {file_path} to get caption")
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Parses a yaml file to get the caption. Returns the default caption if
|
||||
the file does not exist or is invalid.
|
||||
|
||||
:param file_path: path to the yaml file
|
||||
:param default_caption: caption to return if the file does not exist or is invalid
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "r") as stream:
|
||||
file_content = yaml.safe_load(stream)
|
||||
main_prompt = file_content.get("main_prompt", "")
|
||||
rating = file_content.get("rating", 1.0)
|
||||
unparsed_tags = file_content.get("tags", [])
|
||||
|
||||
max_caption_length = file_content.get("max_caption_length", DEFAULT_MAX_CAPTION_LENGTH)
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
last_weight = None
|
||||
weights_differ = False
|
||||
for unparsed_tag in unparsed_tags:
|
||||
tag = unparsed_tag.get("tag", "").strip()
|
||||
if len(tag) == 0:
|
||||
continue
|
||||
|
||||
tags.append(tag)
|
||||
tag_weight = unparsed_tag.get("weight", 1.0)
|
||||
tag_weights.append(tag_weight)
|
||||
|
||||
if last_weight is not None and weights_differ is False:
|
||||
weights_differ = last_weight != tag_weight
|
||||
|
||||
last_weight = tag_weight
|
||||
|
||||
return ImageCaption(main_prompt, rating, tags, tag_weights, max_caption_length, weights_differ)
|
||||
except:
|
||||
logging.error(f" *** Error reading {file_path} to get caption")
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str, default_caption: OptionalImageCaption=None) -> OptionalImageCaption:
|
||||
"""
|
||||
Try to resolve a caption from a file path or return `default_caption`.
|
||||
|
||||
:string: The path to the file to parse.
|
||||
:default_caption: Optional `ImageCaption` to return if the file does not exist or is invalid.
|
||||
:return: `ImageCaption` object or `None`.
|
||||
"""
|
||||
if os.path.exists(file_path):
|
||||
(file_path_without_ext, ext) = os.path.splitext(file_path)
|
||||
match ext:
|
||||
case ".yaml" | ".yml":
|
||||
return ImageCaption.from_yaml_file(file_path, default_caption)
|
||||
|
||||
case ".txt" | ".caption":
|
||||
return ImageCaption.from_text_file(file_path, default_caption)
|
||||
|
||||
case '.jpg'| '.jpeg'| '.png'| '.bmp'| '.webp'| '.jfif':
|
||||
for ext in [".yaml", ".yml", ".txt", ".caption"]:
|
||||
file_path = file_path_without_ext + ext
|
||||
image_caption = ImageCaption.from_file(file_path)
|
||||
if image_caption is not None:
|
||||
return image_caption
|
||||
return ImageCaption.from_file_name(file_path)
|
||||
|
||||
case _:
|
||||
return default_caption
|
||||
else:
|
||||
return default_caption
|
||||
|
||||
@staticmethod
|
||||
def resolve(string: str) -> 'ImageCaption':
|
||||
"""
|
||||
Try to resolve a caption from a string. If the string is a file path,
|
||||
the caption will be read from the file, otherwise the string will be
|
||||
parsed as a caption.
|
||||
|
||||
:string: The string to resolve.
|
||||
:return: `ImageCaption` object.
|
||||
"""
|
||||
return ImageCaption.from_file(string, None) or ImageCaption.parse(string)
|
||||
|
||||
|
||||
class ImageTrainItem:
|
||||
"""
|
||||
image: PIL.Image
|
||||
|
@ -253,7 +127,7 @@ class ImageTrainItem:
|
|||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||
"""
|
||||
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0):
|
||||
def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], pathname: str, flip_p=0.0, multiplier: float=1.0, cond_dropout=None):
|
||||
self.caption = caption
|
||||
self.aspects = aspects
|
||||
self.pathname = pathname
|
||||
|
@ -261,6 +135,7 @@ class ImageTrainItem:
|
|||
self.cropped_img = None
|
||||
self.runt_size = 0
|
||||
self.multiplier = multiplier
|
||||
self.cond_dropout = cond_dropout
|
||||
|
||||
self.image_size = None
|
||||
if image is None or len(image) == 0:
|
||||
|
@ -274,6 +149,22 @@ class ImageTrainItem:
|
|||
self.error = None
|
||||
self.__compute_target_width_height()
|
||||
|
||||
def load_image(self):
|
||||
try:
|
||||
image = PIL.Image.open(self.pathname).convert('RGB')
|
||||
image = self._try_transpose(image, print_error=False)
|
||||
except SyntaxError as e:
|
||||
pass
|
||||
return image
|
||||
|
||||
def _try_transpose(self, image, print_error=False):
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except Exception as e:
|
||||
logging.warning(F"Error rotating image: {e} on {self.pathname}, image will be loaded as is, EXIF may be corrupt") if print_error else None
|
||||
pass
|
||||
return image
|
||||
|
||||
def hydrate(self, crop=False, save=False, crop_jitter=20):
|
||||
"""
|
||||
crop: hard center crop to 512x512
|
||||
|
@ -283,7 +174,7 @@ class ImageTrainItem:
|
|||
# print(self.pathname, self.image)
|
||||
try:
|
||||
# if not hasattr(self, 'image'):
|
||||
self.image = PIL.Image.open(self.pathname).convert('RGB')
|
||||
self.image = self.load_image()
|
||||
|
||||
width, height = self.image.size
|
||||
if crop:
|
||||
|
@ -353,7 +244,8 @@ class ImageTrainItem:
|
|||
def __compute_target_width_height(self):
|
||||
self.target_wh = None
|
||||
try:
|
||||
with Image.open(self.pathname) as image:
|
||||
with PIL.Image.open(self.pathname) as image:
|
||||
image = self._try_transpose(image, print_error=True).convert('RGB')
|
||||
width, height = image.size
|
||||
image_aspect = width / height
|
||||
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
||||
|
|
109
data/resolver.py
109
data/resolver.py
|
@ -4,6 +4,7 @@ import os
|
|||
import typing
|
||||
import zipfile
|
||||
import argparse
|
||||
from data.dataset import Dataset
|
||||
|
||||
import tqdm
|
||||
from colorama import Fore, Style
|
||||
|
@ -27,16 +28,6 @@ class DataResolver:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def image_train_item(self, image_path: str, caption: ImageCaption, multiplier: float=1) -> ImageTrainItem:
|
||||
return ImageTrainItem(
|
||||
image=None,
|
||||
caption=caption,
|
||||
aspects=self.aspects,
|
||||
pathname=image_path,
|
||||
flip_p=self.flip_p,
|
||||
multiplier=multiplier
|
||||
)
|
||||
|
||||
class JSONResolver(DataResolver):
|
||||
def image_train_items(self, json_path: str) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -45,61 +36,7 @@ class JSONResolver(DataResolver):
|
|||
|
||||
:param json_path: The path to the JSON file.
|
||||
"""
|
||||
items = []
|
||||
with open(json_path, encoding='utf-8', mode='r') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
for data in tqdm.tqdm(json_data):
|
||||
caption = JSONResolver.image_caption(data)
|
||||
if caption:
|
||||
image_value = JSONResolver.get_image_value(data)
|
||||
item = self.image_train_item(image_value, caption)
|
||||
if item:
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_image_value(json_data: dict) -> typing.Optional[str]:
|
||||
"""
|
||||
Get the image from the json data if possible.
|
||||
|
||||
:param json_data: The json data, a dict.
|
||||
:return: The image, or None if not found.
|
||||
"""
|
||||
image_value = json_data.get("image", None)
|
||||
if isinstance(image_value, str):
|
||||
image_value = image_value.strip()
|
||||
if os.path.exists(image_value):
|
||||
return image_value
|
||||
|
||||
@staticmethod
|
||||
def get_caption_value(json_data: dict) -> typing.Optional[str]:
|
||||
"""
|
||||
Get the caption from the json data if possible.
|
||||
|
||||
:param json_data: The json data, a dict.
|
||||
:return: The caption, or None if not found.
|
||||
"""
|
||||
caption_value = json_data.get("caption", None)
|
||||
if isinstance(caption_value, str):
|
||||
return caption_value.strip()
|
||||
|
||||
@staticmethod
|
||||
def image_caption(json_data: dict) -> typing.Optional[ImageCaption]:
|
||||
"""
|
||||
Get the caption from the json data if possible.
|
||||
|
||||
:param json_data: The json data, a dict.
|
||||
:return: The `ImageCaption`, or None if not found.
|
||||
"""
|
||||
image_value = JSONResolver.get_image_value(json_data)
|
||||
caption_value = JSONResolver.get_caption_value(json_data)
|
||||
if image_value:
|
||||
if caption_value:
|
||||
return ImageCaption.resolve(caption_value)
|
||||
return ImageCaption.from_file(image_value)
|
||||
|
||||
return Dataset.from_json(json_path).image_train_items(self.aspects)
|
||||
|
||||
class DirectoryResolver(DataResolver):
|
||||
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
|
||||
|
@ -111,32 +48,7 @@ class DirectoryResolver(DataResolver):
|
|||
:param data_root: The root directory to recurse through
|
||||
"""
|
||||
DirectoryResolver.unzip_all(data_root)
|
||||
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
|
||||
items = []
|
||||
multipliers = {}
|
||||
|
||||
for pathname in tqdm.tqdm(image_paths):
|
||||
current_dir = os.path.dirname(pathname)
|
||||
|
||||
if current_dir not in multipliers:
|
||||
multiply_txt_path = os.path.join(current_dir, "multiply.txt")
|
||||
if os.path.exists(multiply_txt_path):
|
||||
try:
|
||||
with open(multiply_txt_path, 'r') as f:
|
||||
val = float(f.read().strip())
|
||||
multipliers[current_dir] = val
|
||||
logging.info(f" - multiply.txt in '{current_dir}' set to {val}")
|
||||
except Exception as e:
|
||||
logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}")
|
||||
multipliers[current_dir] = 1.0
|
||||
else:
|
||||
multipliers[current_dir] = 1.0
|
||||
|
||||
caption = ImageCaption.resolve(pathname)
|
||||
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
return Dataset.from_path(data_root).image_train_items(self.aspects)
|
||||
|
||||
@staticmethod
|
||||
def unzip_all(path):
|
||||
|
@ -150,21 +62,6 @@ class DirectoryResolver(DataResolver):
|
|||
except Exception as e:
|
||||
logging.error(f"Error unzipping files {e}")
|
||||
|
||||
@staticmethod
|
||||
def recurse_data_root(recurse_root):
|
||||
for f in os.listdir(recurse_root):
|
||||
current = os.path.join(recurse_root, f)
|
||||
|
||||
if os.path.isfile(current):
|
||||
ext = os.path.splitext(f)[1].lower()
|
||||
if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']:
|
||||
yield current
|
||||
|
||||
for d in os.listdir(recurse_root):
|
||||
current = os.path.join(recurse_root, d)
|
||||
if os.path.isdir(current):
|
||||
yield from DirectoryResolver.recurse_data_root(current)
|
||||
|
||||
def strategy(data_root: str) -> typing.Type[DataResolver]:
|
||||
"""
|
||||
Determine the strategy to use for resolving the data.
|
||||
|
|
|
@ -44,21 +44,25 @@ The value is defaulted at 0.04, which means 4% conditional dropout. You can set
|
|||
|
||||
## LR tweaking
|
||||
|
||||
Learning rate adjustment is a very important part of training. You can use the default settings, or you can tweak it. You should consider increasing this further if you increase your batch size further (10+) using [gradient checkpointing](#gradient_checkpointing).
|
||||
Learning rate adjustment is a very important part of training.
|
||||
|
||||
--lr 1.5e-6 ^
|
||||
--lr 1.0e-6 ^
|
||||
|
||||
By default, the learning rate is constant for the entire training session. However, if you want it to change by itself during training, you can use cosine.
|
||||
|
||||
General suggestion is 1e-6 for training SD1.5 at 512 resolution. For SD2.1 at 768, try a much lower value, such as 2e-7. [Validation](VALIDATION.md) can be helpful to tune learning rate.
|
||||
|
||||
## Clip skip
|
||||
|
||||
Aka "penultimate layer", this takes the output from the text encoder not from its last output layer, but layers before.
|
||||
Clip skip counts back from the last hidden layer of the text encoder output for use as the text embedding.
|
||||
|
||||
*Note: since EveryDream2 uses HuggingFace Diffusers library, the penultimate layer is already selected when training and running inference on SD2.x models.* This is defined in the text_encoder/config.json by the "num_hidden_layers" property of 23, which is penultimate out of the 24 layers and set by default in all diffusers SD2.x models.
|
||||
|
||||
--clip_skip 2 ^
|
||||
|
||||
A value of "2" is the canonical form of "penultimate layer" useed by various webuis, but 1 to 4 are accepted as well if you wish to experiment. Default is "0" which takes the "last hidden layer" or standard output of the text encoder as Stable Diffusion 1.X was originally designed. Training with this setting may necessititate or work better when also using the same setting in your webui/inference program.
|
||||
A value of "2" will count back one additional layer. For SD1.x, "2" would be "penultimate" layer as commonly referred to in the community. For SD2.x, it would be an *additional* layer back.
|
||||
|
||||
Values of 0 to 3 are valid and working. The number indicates how many extra layers to go "back" into the CLIP embedding output. 0 is the last layer and the default behavior. 1 is the layer before that, etc.
|
||||
*A value of "0" or "1" does nothing.*
|
||||
|
||||
### Cosine LR scheduler
|
||||
Cosine LR scheduler will "taper off" your learning rate over time. It will reach a peak value of your ```--lr``` value then taper off following a cosine curve. In other words, it allows you to set a high initial learning rate which lowers as training progresses. This *may* help speed up training without overfitting. If you wish to use this, I would set a slightly higher initial [learning rate](#lr-tweaking), maybe by 25-50% than you might use with a normal constant LR schedule.
|
||||
|
@ -125,6 +129,8 @@ To use a random seed, use -1:
|
|||
|
||||
Default behavior is to use a fixed seed of 555. The seed you set is fixed for all samples if you set a value other than -1. If you set a seed it is also incrememted for shuffling your training data every epoch (i.e. 555, 556, 557, etc). This makes training more deterministic. I suggest a fixed seed when you are trying A/B test tweaks to your general training setup, or when you want all your test samples to use the same seed.
|
||||
|
||||
Fixed seed should be using when performing A/B tests or hyperparameter sweeps. Random seed (-1) may be better if you are stopping and resuming training often so every restart is using random values for all of the various randomness sources used in training such as noising and data shuffling.
|
||||
|
||||
## Shuffle tags
|
||||
|
||||
For those training booru tagged models, you can use this arg to randomly (but deterministicly unless you use `--seed -1`) all the CSV tags in your captions
|
||||
|
@ -143,6 +149,8 @@ Based on [Nicholas Guttenberg's blog post](https://www.crosslabs.org//blog/diffu
|
|||
|
||||
Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/zero_freq_test_biggs.webp
|
||||
|
||||
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time.
|
||||
|
||||
# Stuff you probably don't need to mess with, but well here it is:
|
||||
|
||||
|
||||
|
@ -176,8 +184,8 @@ The files will be in ```logs/[your project folder]/ep[N]_batch_schedule.txt``` a
|
|||
|
||||
## clip_grad_norm
|
||||
|
||||
Clips the gradient normals to a maximum value. This is an experimental feature, you can read online about gradient clipping. Default is None (no clipping). This is typically used for gradient explosion problems, which are not an issue with EveryDream, but might be a fun thing to experiment with?
|
||||
Clips the gradient normals to a maximum value. Default is None (no clipping). This is typically used for gradient explosion problems, which are generally not an issue with EveryDream and the grad scaler in AMP mode keeps this from being too much of an issue, but it may be worth experimenting with.
|
||||
|
||||
--clip_grad_norm 1.0 ^
|
||||
--clip_grad_norm 100000.0 ^
|
||||
|
||||
This may drastically reduce training speed or have other undesirable effects. My brief toying was mostly unsuccessful. I would not recommend using this unless you know what you're doing or are bored, but you might discover something cool or interesting.
|
||||
Early indications seem to show high values such as 100000 may be helpful. Low values like 1.0 will drastically reduce training speed. Default is no gradient normal clipping. There are also other ways to deal with gradient explosion, such as increasing optimizer epsilon.
|
|
@ -1,30 +1,57 @@
|
|||
# RunPod
|
||||
The simplest approach for RunPod is to use the [EveryDream2 template](https://runpod.io/gsc?template=d1v63jb36t&ref=bbp9dh8x) to load a fully configured docker image.
|
||||
|
||||
## JupyterLab
|
||||
JupterLab will run on the pod by default. When opening JupyterLab `[Port 8888]` you will be prompted for a password. The default password is `EveryDream`. This can be changed by editing the pod's environment variables.
|
||||
**[Runpod Video Tutorial](https://www.youtube.com/watch?v=XAULP-4hsnA)**
|
||||
|
||||
Click here -> [EveryDream2 template](https://runpod.io/gsc?template=d1v63jb36t&ref=bbp9dh8x) to load a fully configured Docker image. Both Tensorboard and Jupyter lab are automatically started for you and you can simply click the links to connect.
|
||||
|
||||
If you wish to sign up for Runpod, please consider using this [referral link](https://runpod.io?ref=oko38cd0) to help support the project. 2% of your spend is given back in the form of credits back to the project and costs you nothing.
|
||||
|
||||
## SSH
|
||||
You can also [enable full SSH support](https://www.runpod.io/blog/how-to-achieve-true-ssh-on-runpod) by setting the PUBLIC_KEY environment variable
|
||||
|
||||
## Tensorboard
|
||||
Tensorboard will run automatically, and can be viewed on `[Port 6006]`
|
||||
# Vast.ai
|
||||
|
||||
# Vast
|
||||
The EveryDream2 docker image is also compatible with [vast.ai](https://console.vast.ai/).
|
||||
**[Vast.ai Video Tutorial](https://www.youtube.com/watch?v=PKQesb4om9I)**
|
||||
|
||||
The EveryDream2 Docker image makes running [vast.ai](https://console.vast.ai/) fairly easy.
|
||||
|
||||
`ghcr.io/victorchall/everydream2trainer:main`
|
||||
|
||||
## JupyterLab
|
||||
You can enable JupyterLab as part of the Vast.ai instance configuration. No JupyterLab password is required.
|
||||
Watch the video for a full setup example. Once the template is configured you can simply launch into it using any rented GPU instance by selecting the EveryDream2 docker template.
|
||||
|
||||
## Tensorboard
|
||||
You can specify tensorboard to run at startup as part of your instance config.
|
||||
|
||||
Open the tensorboard port via docker
|
||||
```tensorboard --logdir /workspace/EveryDream2trainer/logs --host 0.0.0.0 &```
|
||||
|
||||
![Config](vastai_config.jpg)
|
||||
Make sure to copy the IP:PORT to a new browser tab to connect to Tensorboard and Jupyter. You can see the ports by clicking the IP:PORT-RANGE on your instance once it is started.
|
||||
![Config](vastai_ports.jpg)
|
||||
# Once your instance is up and running
|
||||
Run the `Train_JupyterLab.ipynb` notebook
|
||||
The line with "6006/tcp" will be Tensorboard and the line with "8888/tcp" will be Jupyter. Click one to select, copy, then paste into a new browser tab.
|
||||
|
||||
## Password for Jupyter Lab (all platforms that use Docker)
|
||||
|
||||
The default password is `EveryDream`. This can be changed by editing environment variables or start parameters depending on what platform you use, or for local use, modify the docker run command.
|
||||
|
||||
# Instance concerns
|
||||
|
||||
## Bandwith
|
||||
|
||||
Make sure to select an instance with high bandwidth as you will need to wait to download your base model then later upload the finished checkpoint down to your own computer or up to Hugginface. 500mbps+ is good, closer to 1gbit is better. If you are uploading a lot of training data or wish to download your finished checkpoints directly to your computer you may also want to make sure the instance is closer to your physical location for improved transfer speed. You pay for rental while uploading and downloading, not just during training!
|
||||
|
||||
## GPU Selection
|
||||
|
||||
EveryDream2 requires a minimum 12GB Nvidia instance.
|
||||
|
||||
Hosts such as Vast and Runpod offer 3090 and 4090 instances which are good choices. The 3090 24 GB is a very good choice here, leaving plenty of room for running higher resolutions (768+) with a good batch size at a reasonable cost.
|
||||
|
||||
As of writing, the 4090 is now going to run quite a bit faster (~60-80%) than the 3090 due to Torch2 and Cuda 11.8 support, but costs more than the 3090. You will need to decide if it is cost effective when you go to rent something.
|
||||
|
||||
Common major cloud providers like AWS and GCP offer the T4 16GB and A10G 24GB which are suitable. A100 is generally overkill and not economical, and the 4090 may actually be faster unless you really need the 40/80 GB VRAM to run extremely high resolution training (1280+). 24GB cards can run 1024+ by enabling gradient checkpointing and using smaller batch sizes.
|
||||
|
||||
If you plan on running a significant amount of training over the course of many months, purchasing your own 3090 may be another option instead of renting at anything, assuming your electricity prices are not a major concern. However, renting may be a good entry point to see if the hobby interests you first.
|
||||
|
||||
I do not recommend V100 GPUs or any other older architectures (K80, Titan, etc). Many of them will not support FP16 natively and are simply very slow. Almost no consumer cards prior to 30xx series have enough VRAM.
|
||||
|
||||
## Shutdown
|
||||
|
||||
Make sure to delete the instance when you are done. Runpod and Vast use a trash icon for this. Just stopping the instance isn't enough, and you pay pay for storage or rental until you completely delete it.
|
||||
|
||||
# Other Platforms
|
||||
|
||||
The Docker container should enable running on any host that supports using Docker containers, including GCP or AWS, and potentially with lifecycle management services such as GKE and ECS.
|
||||
|
||||
Most people looking to use GCP or AWS will likely already understand how to manage instances, but as a warning, make sure you know how to manage the instance lifecycles so you don't end up with a surprise bill at the end of the month. Leaving an instance running all month can get expensive!
|
|
@ -0,0 +1,61 @@
|
|||
# Contribution guide
|
||||
|
||||
Thank you for your interest in contributing to EveryDream!
|
||||
|
||||
## Way to contribute without code
|
||||
|
||||
* Join Discord and help other users with advice and troubleshooting.
|
||||
|
||||
* Report bugs. Use the github template or report on discord in #help along with your logs.
|
||||
|
||||
* Documentation. Is something confusing or missing? Contibute an update to documentation.
|
||||
|
||||
* Spread the word. Share your experience with others. Tell your friends. Write a blog post. Make a video. Post on social media. Every little bit helps!
|
||||
|
||||
* Share that you used EveryDream2 to make your model. Have a popular model on civitai or other sites? Leave a mention that you use EveryDream2 to train it.
|
||||
|
||||
* Share your training settings. Did you find a good set of settings for a particular dataset? Share it with others.
|
||||
|
||||
* Run A/B experiments. Try different hyperparameters and share your results on socials or Discord.
|
||||
|
||||
## Contributor License Agreement
|
||||
|
||||
Please review the [CLA](EveryDream_CLA.txt) before issuing a PR. You will be asked on your first submission to post your agreement for any code changes.
|
||||
|
||||
This is not required for simple documentation changes.
|
||||
|
||||
## Contributing code
|
||||
|
||||
EveryDream 2 trainer is a fairly complex piece of machinery with many options, and supports several runtime environments such as local Windows/Linux via venv, Docker/WSL, and Google Colab. One of the primary challenges is to ensure nothing breaks across platforms. EveryDream has users across all these platforms, and indeed, it is a bit of a chore keeping everything working, but one of the primary goals of the project to keep it as accessible as possible whether you have a home PC to use or are renting a cloud instance.
|
||||
|
||||
The most important thing to do when contributing is to make sure to *run training* on your preferred platform with your changes. A short training session and confirming that the first sample image or two after a couple hundred steps or so will ensure *most* functionality is working even if you can't test every possible runtime environment or combination of arguments. While automated tests can help with identifying regression, there's no replacement for actually running the whole software package. A quick 10-15 minute test on a small dataset is sufficient!
|
||||
|
||||
**If you cannot test every possible platform, your contribution is still welcome.** Please note waht you can and did test in your PR and we can work together to ensure it works across platforms, either by analysis for trivial changes, or help with testing. Some changes are small isolated and may not require full testing across all platforms, but noting how you tested will help us ensure it works for everyone.
|
||||
|
||||
## Code style and linting
|
||||
|
||||
**WIP** (there's a .pylint at least...)
|
||||
|
||||
### Running tests
|
||||
|
||||
**WIP** There is a small suite of automated unit tests. **WIP**
|
||||
|
||||
## Documentation
|
||||
|
||||
Please update the appropriate document file in `/doc` for your changes. If you are adding a new feature, please add a new section for users to read in order to understand how to use it, or if it is a significant feature, add a new document and link it from the main README.md or from the appropriate existing document.
|
||||
|
||||
## A few questions to ask yourself before working on enhancements
|
||||
|
||||
There is no formal process for contributing to EveryDream, but please consider the following before submitting a PR:
|
||||
|
||||
* Consider if the change is general enough to be useful to others, or is more specific to your project. Changes should provide value to a broad audience. Sometimes specific project needs can be served by a script for your specific data instead of a change to the trainer behavior, for instance.
|
||||
|
||||
* Please be mindful of adding new primary CLI arguments. New args should provide significant value to weigh the lengthening of the arg list. The current list is already daunting, and the project is meant to remain at least *somewhat* accessible for a machine learning project. There may be ways to expose new functionality in other ways for advanced users without making primary CLI args more complex.
|
||||
|
||||
* It's best to bring up any changes to default behavior in the Discord first.
|
||||
|
||||
* If you add or update dependencies make sure to update the [Docker requirements](../docker/requirements.txt), [windows_setup.cmd](../windows_setup.cmd), and [Colab dependencies install cell](../Train_Colab.ipynb). Please note that in your PR what platforms you were able to test or ask for help on Discord.
|
||||
|
||||
* Please consider checking in on the Discord #help channel after release to spot possible bugs encountered by users after your PR is merged.
|
||||
|
||||
|
102
doc/DATA.md
102
doc/DATA.md
|
@ -1,36 +1,38 @@
|
|||
# Data organization
|
||||
# Selecting and preparing images
|
||||
|
||||
Since this trainer relies on having captions for your training images you will need to decide how you want deal with this.
|
||||
## Number of images
|
||||
|
||||
There are two currently supported methods to retrieve captions:
|
||||
You should probably start with less than 100 images, until you get a feel for training. When you are ready, ED2 supports up to hundreds of thousands of images.
|
||||
|
||||
1. Name the files with the caption. Underscore marks the end of the caption (ex. "john smith on a boat_999.jpg")
|
||||
2. Put your captions for each image in a .txt file with the same name as the image. All UTF-8 text is supported with no reserved or special case characters. (ex. 00001.jpg, 00001.txt)
|
||||
## Image size and quality
|
||||
ED2 supports `.jpg`, `.jpeg`, `.png`, `.bmp`, `.webp`, and `.jfif` image formats.
|
||||
|
||||
You will need to place all your images and captions into a folder. Inside that folder, you can use subfolders to organize data as you please. The trainer will recursively search for images and captions. It may be useful, for instance, to split each character into a subfolder, and have other subfolders for cityscapes, etc.
|
||||
Current recommendation is _at least_ 1 megapixel (ex 1024x1024, 1100x900, 1300x800, etc). That being said, technology continues to advance rapidly. ED2 has no problem handling 4K images, so it's up to you to pick the appropriate trade-off with disk and network costs.
|
||||
|
||||
When you train, you will use "--data_root" to point to the root folder of your data. All images in that folder and its subfolders will be used for training.
|
||||
Bare minimum recommended size is 512x512. Scaling images up is not a great idea though, though it may be tolerable as a very small percentage of your data set. If you only have 512x512 images, don't try to train at 768.
|
||||
|
||||
If you wish to boost training on a particular folder of images, put a "multiply.txt" in that folder with a whole number in it, such as 2. This will multiply the number of times images in that specific folder are used for training per epoch. This is useful if you have two characters you want to train, separated into separate folders, but one character has fewer images.
|
||||
Use high quality, in-focus, low-noise, images, capturing the concept(s) under training with high fidelity wherever possible.
|
||||
|
||||
# Data preparation
|
||||
## Cropping and Aspect Ratios
|
||||
|
||||
## Image size
|
||||
You can crop your images in an image editor __if it highlights the concept under training__, e.g. to get good close ups of things like faces, or to split images up that contain multiple characters.
|
||||
|
||||
The trainer will automatically fit your images to the best possible size. It is best to leave your images larger tham you may think for typical Stable Diffusion training. Even 4K images will be handled fine so just don't sweat it if you have large images. The only downside is they take a bit more disk space. There is almost no performance impact for having higher resolution images.
|
||||
**You do not need to crop to square**
|
||||
|
||||
Current recommendation is 1 megapixel (ex 1024x1024, 1100x900, 1300x800, etc) or larger, but thinking ahead to future technology advancements you may wish to keep them at even larger resolutions. Again, don't worry about the trainer squeezing or cropping, it will handle it!
|
||||
Aspect ratios between 4:1 and 1:4 are supported; the trainer will handle bucketing and resizing your images as needed.
|
||||
|
||||
Aspect ratios up to 4:1 or 1:4 are supported.
|
||||
It is ok to use a full shot of two characters in one image and also a cropped version of each character separately, but make sure every image is captioned appropriately for what is actually present in each image.
|
||||
|
||||
## Cropping
|
||||
## Caption Design
|
||||
|
||||
You can crop your images in an image editor *if you need, in order to get good close ups of things like faces, or to split images up that contain multiple characters.* As above, make sure **after** cropping your images are still fairly large. It is ok to use a full shot of two characters in one image and also a cropped version of each character separately, but make sure every image is captioned appropriately for what is actually present in each image.
|
||||
|
||||
## Captions
|
||||
A caption consists of a main prompt, followed by one or more comma-separated tags.
|
||||
|
||||
For most use cases, use a sane English sentence to describe the image. Try to put your character or main object name close to the start.
|
||||
|
||||
**If you are training on images of humans, there is little benefit in using "unique" names most of the time**. Don't worry so much about using a "rare" toking, or making up gibberish words. Just try generating a few images using your concept names, and make sure there are no serious conflicts.
|
||||
|
||||
**Use normal names for people and characters, such as "cloud strife" or "john david washington" instead of making up weird names like "cldstrf" or "jhndvdwshtn". There's no need for this and it just makes inference less natural and shifts a burden on the user to remember magic names.**
|
||||
|
||||
Those training anime models can use booru tags as well using other utilities to generate the captions.
|
||||
|
||||
### Styles
|
||||
|
@ -45,6 +47,68 @@ Include the surroundings and context in your captions. Ex. "cloud strife standi
|
|||
|
||||
Also consider some basic mention of pose. ex. "clouds strife sitting on a blue wooden bench in front of a concrete wall" or "barrett wallace holding his fist in front of his face with an angry look on his face, looking at the camera." Captions can capture value not only for the character's look, but also for the pose, the background scene, and the camera angle. You can be creative here, there is a lot of potential!
|
||||
|
||||
### Further reading
|
||||
|
||||
The [Data Balancing](BALANCING.md) guide has some more information on how to balance your data and what to consider for model preservation and mixing in ground truth data.
|
||||
# Constructing a dataset
|
||||
A dataset consists of image files coupled to captions and other configuration.
|
||||
|
||||
You are welcome to use any folder structure that makes sense for your project, but you should know that there are configuration tricks that rely on data being partitioned into folders and subfolders.
|
||||
|
||||
## Assigning captions
|
||||
### by Filename
|
||||
The simplest and least powerful way to caption images is by filename. The name of the file, without extension, and excluding any characters after an _ (underscore).
|
||||
|
||||
```
|
||||
a photo of ted bennet, sitting on a green armchair_1.jpg
|
||||
a photo of ted bennet, laying down_1.jpg
|
||||
a photo of ted bennet, laying down_2.jpg
|
||||
```
|
||||
### by Caption file
|
||||
If you are running in a Windows environment, you may not be able to fit your whole caption in the file name.
|
||||
|
||||
Instead you can create a text file with the same name as your image file, but with a `.txt` or `.caption` extension, and the content of the text file will be used as the caption, **ignoring the name of the file**.
|
||||
|
||||
### by Caption yaml
|
||||
You can capture a more complex caption structure by using a `.yaml` sidecar instead of a text file. Specifically you can assign weights to tags for [shuffling](SHUFFLING_TAGS.md).
|
||||
|
||||
The format for `.yaml` captions:
|
||||
```
|
||||
main_prompt: a photo of ted bennet
|
||||
tags:
|
||||
- "sitting on a green armchair" # The tag can be a simple string
|
||||
- tag: "wearing a tuxedo" # or it can be a tag string
|
||||
weight: 1.5 # optionally paired with a shuffle weight
|
||||
```
|
||||
|
||||
|
||||
### Assigning captions to entire folders
|
||||
As mentioned above, a caption is a main prompt accompanied by zero or more tags.
|
||||
Currently it is not possible for a caption to have more than one main tag, although this limitation may be removed in the future.
|
||||
|
||||
But, in some cases it may make sense to add the same tag to all images in a folder. You can place any configuration that should apply to all images in a local folder (ignoring anything in any subfolders) by adding a file called `local.yaml` to the folder. In this file you can, for example, add:
|
||||
```
|
||||
tags:
|
||||
- tag: "in the style of xyz"
|
||||
```
|
||||
And this tag will be appended to any tags specified at the image level.
|
||||
|
||||
If you want this tag, or any other configuration, to to apply to images in subfolders as well you can create a file called `global.yaml` and it will apply to all images in the local folder **and** to any images in any subfolders, recursively.
|
||||
|
||||
## Other image configuration
|
||||
In addition to captions, you can also specify the frequency with which each image should show up in training (`multiply`), or the frequency in which the trainer should be given a flipped version of the image (`flip_p`), or the frequency in which the caption should be dropped completely focusing the training on the image alone, ignoring the caption (`cond_dropout`).
|
||||
|
||||
For simple cases you can create a file called `multiply.txt`, `flip_p.txt`, and/or `cond_dropout.txt`, containing the single numeric value for that configuration parameter that should be applied to all images in the local folder.
|
||||
|
||||
Alternatively you can add these properties to any of the supported `.yaml` configuration files, image-level, `local.yaml`, and/or `global.yaml`
|
||||
|
||||
```
|
||||
main_prompt: a photo of ted bennet
|
||||
tags:
|
||||
- sitting on a green armchair
|
||||
multiply: 2
|
||||
flip_p: 0.5
|
||||
cond_droput: 0.05
|
||||
```
|
||||
|
||||
See [Advanced Tweaking](ADVANCED_TWEAKING.md) for more information on image flipping and conditional dropout.
|
||||
|
||||
The [Data Balancing](BALANCING.md) guide has some more information on how to balance your data using multipliers, and what to consider for model preservation and mixing in ground truth data.
|
||||
|
|
|
@ -4,13 +4,19 @@ Logs are important to review to track your training and make sure your settings
|
|||
|
||||
Everydream2 uses the Tensorboard library to log performance metrics. (more options coming!)
|
||||
|
||||
You should launch tensorboard while your training is running and watch along.
|
||||
You should launch tensorboard while your training is running and watch along. Open a separate command window, activate venv like you would for training, then run this:
|
||||
|
||||
tensorboard --logdir logs --samples_per_plugin images=100
|
||||
|
||||
You can leave Tensorboard running in the background as long as you wish. The `samples_per_plugin` arg will make sure Tensorboard gives finer control on the slider bar for looking through samples, but remember ALL samples are always in your logs, even if you don't see a particular expected sample step in Tensorboard.
|
||||
|
||||
VS Code can also launch Tensorboard by installing the extension, then CTRL-SHIFT-P, start typing "tensorboard" and select "Python: Launch Tensorboard", "select another folder", and select the "logs" folder under your EveryDream2trainer folder.
|
||||
|
||||
## Sample images
|
||||
|
||||
By default, the trainer produces sample images from `sample_prompts.txt` with a fixed seed every so many steps as defined by your `sample_steps` argument. These are saved in the logs directory and can be viewed in tensorboard as well if you prefer. If you have a ton of them, the slider bar in tensorboard may not select them all (unless you launch tensorboard with the `--samples_per_plugin` argument as shown above), but the actual files are still stored in your logs as well for review.
|
||||
Sample images are generated periodically by the trainer to give visual feedback on training progress. **It's very important to keep an eye on your samples.** They are available in Tensorboard (and WandB if enabled), or in your logs folder.
|
||||
|
||||
By default, the trainer produces sample images from `sample_prompts.txt` with a fixed seed every so many steps as defined by your `sample_steps` argument. If you have a ton of them, the slider bar in tensorboard may not select them all (unless you launch tensorboard with the `--samples_per_plugin` argument as shown above), but the actual files are still stored in your logs as well for review.
|
||||
|
||||
Samples are produced at CFG scales of 1, 4, and 7. You may find this very useful to see how your model is progressing.
|
||||
|
||||
|
@ -58,12 +64,20 @@ Individual samples are defined under the `samples` key. Each sample can have a `
|
|||
|
||||
The lr curve is useful to make sure your learning rate curve looks as expected when using something other than constant. If you hand-tweak the decay steps you may cause issues with the curve, going down and then back up again for instance, in which case you may just wish to remove lr_decay_steps from your command to let the trainer set that for you.
|
||||
|
||||
Unet and Text encoder LR are logged separately because the text encoder can be set to ratio of the primary LR. See [Optimizer](OPTIMIZER.md) for more details. You can use the logs to confirm the behavior you expect is occurring.
|
||||
|
||||
## Loss
|
||||
|
||||
To be perfectly honest, loss on stable diffusion training just jumps around a lot. It's not a great metric to use to judge your training. It's better to look at the samples and see if they are improving.
|
||||
Standard loss metrics on Stable Diffusion training jumps around a lot in the scope of the fine tuning the community is doing. It's not a great metric to use to judge your training unless youa re shooting for a significant shift in the entire model (i.e. training on thousands, tens of thousands, or hundreds of thousands of images in an effort to make a broad shift in what the model generates).
|
||||
|
||||
For most users, it's better to look at the samples to subjectively judge if they are improving, or enable [Validation](VALIDATION.md). Validation adds the metric `val/loss` which show meaningful trends. Read the validation documentation for more information and hints on how to intrepet trends in `val/loss`.
|
||||
|
||||
## Grad scaler
|
||||
|
||||
`hyperparameters/grad scale` is logged for troubleshooting purposes. If the value trends down to a *negative power* (ex 5e-10), something is wrong with training, such as a wildly inappropriate setting or an issue with your installation. Otherwise, it bounces around, typically around Ne+3 to Ne+8 and is not much concern.
|
||||
|
||||
## Performance
|
||||
|
||||
Images per second will show you when you start a youtube video and your performance tanks. So, keep an eye on it if you start doing something else on your computer, particularly anything that uses GPU, even playing a video. Note that the initial performance has a ramp up time, once it gets going it should maintain as long as you don't do anything else that uses GPU. I have occasionally had issues with my GPU getting "locked" into "slow mode" after trying to play a video, so watch out for that.
|
||||
Images per second will show you when you start a youtube video and your performance tanks. So, keep an eye on it if you start doing something else on your computer, particularly anything that uses GPU, even playing a video.
|
||||
|
||||
Minutes per epoch is inverse, but you'll see it go up (slower, more minutes per epoch) when there are samples being generated that epoch. This is normal, but will give you an idea on how your sampling (``--sample_steps``) is affecting your training time. If you set the sample_steps low, you'll see your minutes per epoch spike more due to the delay involved in generating. It's still very important to generate samples, but you can weight the cost in speed vs the number of samples.
|
||||
Minutes per epoch is inverse, but you'll see it go up (slower, more minutes per epoch) when there are samples being generated that epoch. This is normal, but will give you an idea on how your sampling (`--sample_steps`) is affecting your training time. If you set the sample_steps low, you'll see your minutes per epoch spike more due to the delay involved in generating the samples. It's still very important to generate samples, but you can weight the cost in speed vs the number of samples.
|
|
@ -4,6 +4,8 @@ EveryDream is a *general case fine tuner*. It does not explicitly implement the
|
|||
|
||||
That means there is no "class" or "token" or "regularization images". It simply trains image and caption pairs, much more like the original training of Stable Diffusion, just at a much smaller "at home" scale.
|
||||
|
||||
For the sake of those experienced in machine learning, foregive me for stretching and demarking some terms, as this is voiced for the typical user coming from Dreambooth training with the vocabulary as commonly used there.
|
||||
|
||||
## What is "regularization" and "preservation"?
|
||||
|
||||
The Dreambooth technique uses the concept of adding *generated images from the model itself* to try to keep training from "veering too off course" and "damaging" the model while fine tuning a specific subject with just a handful of images. It served the purpose of "preserving" the integrity of the model. Early on in Dreambooth's lifecycle, people would train 5-20 images of their face, and use a few hundred or maybe a thousand "regularization" images along with the 5-20 training images of their new subject. Since then, many people want to scale to larger training, but more on that later...
|
||||
|
@ -18,7 +20,7 @@ I instead propose you replace images generated out of SD itself with original "g
|
|||
|
||||
### Enter ground truth
|
||||
|
||||
"Ground truth" for the sake of this document means real images not generated by AI. It's very easy to get publicly available ML data sets to serve this purpose and replace genarated "regularization" images with real photos or otherwise.
|
||||
"Ground truth" for the sake of this document means real images not generated by AI. It's very easy to get publicly available ML data sets to serve this purpose and replace generated "regularization" images with real photos or otherwise.
|
||||
|
||||
Sources include [FFHQ](https://github.com/NVlabs/ffhq-dataset), [Coco](https://cocodataset.org/#home), and [Laion](https://huggingface.co/datasets/laion/laion2B-en-aesthetic/tree/main). There is a simple scraper to search Laion parquet files in the tools repo, and the Laion dataset was used by [Compvis](https://github.com/CompVis/stable-diffusion#weights) and [Stability.ai](https://github.com/Stability-AI/stablediffusion#news) themselves to train the base model.
|
||||
|
||||
|
@ -30,7 +32,7 @@ Using ground truth images for the general purpose of "presevation" will, instead
|
|||
|
||||
"Preservation" images and "training" images have no special distinction in EveryDream. All images are treated the same and the trainer does not know the difference. It is all in how you use them.
|
||||
|
||||
Any preservation images still need a caption of some sort. Just "person" may be sufficient, afterall we're just trying to simulate Dreambooth for this example. This can be as easy as selecting all the images, F2 rename, type `person_` (with the underscore) and press enter. Windows will append (x) to every file to make sure the filenames are unique, and EveryDream interprets the underscore as the end of the caption when present in the filename, thus all the images will be read as having a caption of simply `person`, similar to how many people train Dreambooth.
|
||||
Any preservation images still need a caption of some sort. Just "person" may be sufficient, for the sake of this particular exmaple we're just trying to *simulate* Dreambooth. This can be as easy as selecting all the images, F2 rename, type `person_` (with the underscore) and press enter. Windows will append (x) to every file to make sure the filenames are unique, and EveryDream interprets the underscore as the end of the caption when present in the filename, thus all the images will be read as having a caption of simply `person`, similar to how many people train Dreambooth.
|
||||
|
||||
You could also generate "person" regularization images out of any Stable Diffusion inference application or download one of the premade regularization sets, *but I find this is less than ideal*. For small training, regularization or preservation is simply not needed. For longer term training you're much better off mixing in real "ground truth" images into your data instead of generated data. "Ground truth" meaning images not generated from an AI. Training back on generated data will reinforce the errors in the model, like extra limbs, weird fingers, watermarks, etc. Using real ground truth data can actually help improve the model.
|
||||
|
||||
|
|
|
@ -10,13 +10,19 @@ or in train.json
|
|||
|
||||
A default `optimizer.json` is supplied which you can modify.
|
||||
|
||||
This has expanded tweaking. This doc is incomplete, but there is information on the web on betas and weight decay setting you can search for.
|
||||
This extra json file allows expanded tweaking.
|
||||
|
||||
If you do not set optimizer_config, the defaults are `adamw8bit` with standard betas of `(0.9,0.999)`, weight decay `0.01`, and epsilon `1e-8`. The hyperparameters are originally from XavierXiao's Dreambooth code and based off Compvis Stable Diffusion code.
|
||||
If you do not set `optimizer_config` at all or set it to `null` in train.json, the defaults are `adamw8bit` with standard betas of `(0.9,0.999)`, weight decay `0.01`, and epsilon `1e-8`.
|
||||
|
||||
## Optimizers
|
||||
|
||||
In `optimizer.json` the `optimizer` value is the type of optimizer to use. Below are the supported optimizers.
|
||||
In `optimizer.json` the you can set independent optimizer settings for both the text encoder and unet. If you want shared settings, just fill out the `base` section and leave `text_encoder_overrides` properties null an they will be copied from the `base` section.
|
||||
|
||||
If you set the `text_encder_lr_scale` property, the text encoder will be trained with a multiple of the Unet learning rate if it the LR is being copied. If you explicitly set the text encoder LR, the `text_encder_lr_scale` is ignored. `text_encder_lr_scale` is likely to be deprecated in the future, but now is left for backwards compatibility.
|
||||
|
||||
For each of the `unet` and `text_encoder` sections, you can set the following properties:
|
||||
|
||||
`optimizer` value is the type of optimizer to use. Below are the supported optimizers.
|
||||
|
||||
* adamw
|
||||
|
||||
|
@ -28,14 +34,25 @@ Tim Dettmers / bitsandbytes AdamW 8bit optimizer. This is the default and recom
|
|||
|
||||
* lion
|
||||
|
||||
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click links to read more. Unknown what hyperparameters will work well, but paper shows potentially quicker learning. *Highly experimental, but tested and works.*
|
||||
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click links to read more. `Epsilon` is not used by lion.
|
||||
|
||||
Recommended settings for lion based on the paper are as follows:
|
||||
|
||||
"optimizer": "adamw8bit",
|
||||
"lr": 1e-7,
|
||||
"lr_scheduler": "constant",
|
||||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.10
|
||||
|
||||
The recommendations are based on "1/10th LR" but "10x the weight decay" compared to AdamW when training diffusion models. There are no known recommendations for the CLIP text encoder. Lion converges quickly, so take care with the learning rate, and even lower learning rates may be effective.
|
||||
|
||||
## Optimizer parameters
|
||||
|
||||
LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg.
|
||||
|
||||
The text encoder LR can run at a different value to the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and Unet are trained with the same LR.
|
||||
The text encoder LR can run at a different value to the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json` or set the `text_encoder: lr` to its own value (not null). For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `0.5`, meaning the text encoder will be trained at half the learning rate of the unet.
|
||||
|
||||
Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here.
|
||||
## General Beta, weight decay, epsilon, etc tuning
|
||||
|
||||
Note `lion` does not use epsilon.
|
||||
Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak.
|
||||
|
|
|
@ -48,7 +48,6 @@ Double check your python version again after setup by running these two commands
|
|||
|
||||
Again, this should show 3.10.x
|
||||
|
||||
## Linux, Linux containers, WSL, Runpod, etc
|
||||
|
||||
TBD
|
||||
## Docker container
|
||||
|
||||
`docker run -it -p 8888:8888 -p 6006:6006 --gpus all -e JUPYTER_PASSWORD=test1234 -t ghcr.io/victorchall/everydream2trainer:nightly`
|
||||
|
|
|
@ -37,8 +37,7 @@ Resuming from a checkpoint, 50 epochs, 6 batch size, 3e-6 learning rate, constan
|
|||
--sample_steps 200 ^
|
||||
--lr 3e-6 ^
|
||||
--ckpt_every_n_minutes 10 ^
|
||||
--useadam8bit ^
|
||||
--ed1_mode
|
||||
--useadam8bit
|
||||
|
||||
Training from SD2 512 base model, 18 epochs, 4 batch size, 1.2e-6 learning rate, constant LR, generate samples evern 100 steps, 30 minute checkpoint interval, adam8bit, using imagesin the x:\mydata folder, training at resolution class of 640:
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ If you are training a huge dataset (20k+) then saving every 1 epoch may not be v
|
|||
|
||||
*A "last" checkpoint is always saved at the end of training.*
|
||||
|
||||
Diffusers copies of checkpoints are saved in your /logs/[project_name]/ckpts folder, and can be used to continue training if you want to pick up where you left off. CKPT files are saved in the root training folder by default. These folders can be changed. See [Advanced Tweaking](ATWEAKING.md) for more info.
|
||||
Diffusers copies of checkpoints are saved in your /logs/[project_name]/ckpts folder, and can be used to continue training if you want to pick up where you left off. CKPT files are saved in the root training folder by default. These folders can be changed. See [Advanced Tweaking](ADVANCED_TWEAKING.md) for more info.
|
||||
|
||||
### _Delay saving checkpoints_
|
||||
|
||||
|
@ -94,7 +94,7 @@ If you want to resume training from a previous run, you can do so by pointing to
|
|||
|
||||
## __Learning Rate__
|
||||
|
||||
The learning rate affects how much "training" is done on the model per training step. It is a very careful balance to select a value that will learn your data. See [Advanced Tweaking](ATWEAKING.md) for more info. Once you have started, the learning rate is a good first knob to turn as you move into more advanced tweaking.
|
||||
The learning rate affects how much "training" is done on the model per training step. It is a very careful balance to select a value that will learn your data. See [Advanced Tweaking](ADVANCED_TWEAKING.md) for more info. Once you have started, the learning rate is a good first knob to turn as you move into more advanced tweaking.
|
||||
|
||||
## __Batch Size__
|
||||
|
||||
|
@ -102,7 +102,7 @@ Batch size is also another "hyperparamter" of itself and there are tradeoffs. It
|
|||
|
||||
--batch_size 4 ^
|
||||
|
||||
While very small batch sizes can impact performance negatively, at some point larger sizes have little impact on overall speed as well, so shooting for the moon is not always advisable. Changing batch size may also impact what learning rate you use, with typically larger batch_size requiring a slightly higher learning rate. More info is provided in the [Advanced Tweaking](ATWEAKING.md) document.
|
||||
While very small batch sizes can impact performance negatively, at some point larger sizes have little impact on overall speed as well, so shooting for the moon is not always advisable. Changing batch size may also impact what learning rate you use, with typically larger batch_size requiring a slightly higher learning rate. More info is provided in the [Advanced Tweaking](ADVANCED_TWEAKING.md) document.
|
||||
|
||||
## __LR Scheduler__
|
||||
|
||||
|
@ -110,7 +110,7 @@ A learning rate scheduler can change your learning rate as training progresses.
|
|||
|
||||
At this time, ED2.0 supports constant or cosine scheduler.
|
||||
|
||||
The constant scheduler is the default and keeps your LR set to the value you set in the command line. That's really it for constant! I recommend sticking with it until you are comfortable with general training. More info in the [Advanced Tweaking](ATWEAKING.md) document.
|
||||
The constant scheduler is the default and keeps your LR set to the value you set in the command line. That's really it for constant! I recommend sticking with it until you are comfortable with general training. More info in the [Advanced Tweaking](ADVANCED_TWEAKING.md) document.
|
||||
|
||||
## __Sampling__
|
||||
|
||||
|
|
|
@ -91,10 +91,11 @@ The config file has the following options:
|
|||
#### Validation settings
|
||||
* `validate_training`: If `true`, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.
|
||||
* `val_split_mode`: Either `automatic` or `manual`, ignored if validate_training is false.
|
||||
* `automatic` val_split_mode picks a random subset of the training set (the number of items is controlled by `val_split_proportion`) and removes them from training to use as a validation set.
|
||||
* `manual` val_split_mode lets you provide your own folder of validation items (images and captions), specified using `val_data_root`.
|
||||
* `val_split_proportion`: For `automatic` val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.
|
||||
* `val_data_root`: For `manual` val_split_mode, the path to a folder containing validation items.
|
||||
* `automatic` val_split_mode picks a random subset of the training set (the number of items is controlled by `auto_split_proportion`) and removes them from training to use as a validation set.
|
||||
* `manual` val_split_mode lets you provide your own folder of validation items (images and captions), specified using `manual_data_root`.
|
||||
* `auto_split_proportion`: For `automatic` val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.
|
||||
* `manual_data_root`: For `manual` val_split_mode, the path to a folder containing validation items.
|
||||
* `extra_manual_datasets`: Dictionary specifying additional folders containing validation datasets - see "Extra manual datasets" below.
|
||||
|
||||
#### Train loss graph stabilization settings
|
||||
|
||||
|
@ -105,3 +106,29 @@ The config file has the following options:
|
|||
|
||||
* `every_n_epochs`: How often to run validation (1=every epoch).
|
||||
* `seed`: The seed to use when running validation passes, and also for picking subsets of the data to use with `automatic` val_split_mode and/or `stabilize_training_loss`.
|
||||
|
||||
#### Extra manual datasets
|
||||
|
||||
If you're building a model with multiple training subjects, you may want to specify additional validation datasets so you can check the progress of each part of your model separately. You can do this using the `extra_manual_datasets` property of the validation config .json file.
|
||||
|
||||
For example, suppose you're training a model for different dog breeds, and you're especially interested in how well it's training huskies and puggles. To do this, take some of your husky and puggle training data and put it into two separate folders, outside of the data root. For example, suppose you have 100 husky images and 100 puggle images, like this:
|
||||
```commandline
|
||||
/workspace/dogs-model-training/data_root/husky <- contains 100 images for training
|
||||
/workspace/dogs-model-training/data_root/puggle <- contains 100 images for training
|
||||
```
|
||||
Take about 15 images from each of the `husky` and `puggle` folders and put them in their own `validation` folder, outside of the `data_root`:
|
||||
```commandline
|
||||
/workspace/dogs-model-training/validation/husky <- contains 15 images for validation
|
||||
/workspace/dogs-model-training/validation/puggle <- contains 15 images for validation
|
||||
/workspace/dogs-model-training/data_root/husky <- contains the remaining 85 images for training
|
||||
/workspace/dogs-model-training/data_root/puggle <- contains the remaining 85 images for training
|
||||
```
|
||||
Then update your `validation_config.json` file by adding entries to `extra_manual_datasets` to point to these folders:
|
||||
```commandline
|
||||
"extra_manual_datasets": {
|
||||
"husky": "/workspace/dogs-model-training/validation/husky",
|
||||
"puggle": "/workspace/dogs-model-training/validation/puggle"
|
||||
}
|
||||
```
|
||||
When you run training, you'll now get two additional graphs, `loss/husky` and `loss/puggle` that show the progress for your `husky` and `puggle` training data.
|
||||
When you run training, you'll now get two additional graphs, `loss/husky` and `loss/puggle` that show the progress for your `husky` and `puggle` training data.
|
||||
|
|
25
doc/VRAM.md
25
doc/VRAM.md
|
@ -1,18 +1,26 @@
|
|||
# WTF is a CUDA out of memory error?
|
||||
|
||||
Training models is very intense on GPU resources, and CUDA out of memory error is quite common and to be expected as you figure out what you can get away with.
|
||||
Training models is very intense on GPU resources, and `CUDA out of memory error` is quite common and to be expected as you figure out what you can get away with inside the constraints of your GPU VRAM limit.
|
||||
|
||||
## Stuff you want on
|
||||
VRAM use depends on the model being trained (SD1.5 vs SD2.1 base), batch size, resolution, and a number of other settings.
|
||||
|
||||
Make sure you use the following settings in your json config or command line:
|
||||
## Stuff you want on for 12GB cards
|
||||
|
||||
`--amp` on CLI, or in json `"amp": true`
|
||||
AMP and AdamW8bit are now defaulted to on. These are VRAM efficient, produce high quality results, and should be on for all training.
|
||||
|
||||
AMP is a significant VRAM savings (and performance increase as well). It saves several GB and increases performance by 80-100% on Ampere class GPUs.
|
||||
Gradient checkpointing can still be turned on and off, and is not on by default. Turning it on will greatly reduce VRAM use at the expense of some performance. It is suggested to turn it on for any GPU with less than 16GB VRAM and *is definitely required for 12GB cards*.
|
||||
|
||||
`--useadam8bit` in CLI or in json `"useadam8bit": true`
|
||||
If you are using a customized `optimizer.json`, make sure `adamw8bit` is set as the optimizer. `AdamW` is significantly more VRAM intensive. `lion` is another option that is VRAM efficient, but is still fairly experimental in terms of understanding the best LR, betas, and weight decay settings. See [Optimizer docs](OPTIMIZER.md) for more information on advanced optimizer config if you want to try `lion` optimizer. *`adamw8bit` is the recommended and also the default.*
|
||||
|
||||
Tim Dettmers' AdamW 8bit optimizer (aka "bitsandbytes") is a significant VRAM savings (and performance increase as well). Highly recommended, even for high VRAM GPUs. It saves about 1.5GB and offers a performance boost.
|
||||
SD2.1 with the larger text encoder model may not train on 12GB cards. SD1.5 should work fine.
|
||||
|
||||
Batch size of 1 or 2 may be all you can use on 12GB.
|
||||
|
||||
Resolution of 512 may be all you can use on 12GB. You could try 576 or 640 at batch size 1.
|
||||
|
||||
Due to other things running on any given users' systems, precise advice cannot be given on what will run, though 12GB certainly can and does work.
|
||||
|
||||
Close all other programs and processes that are using GPU resources. Apps like Chrome and Discord can use many hundreds of megabytes of VRAM, and can add up quickly. You can also try disabling "hardware acceleration" in some apps which will shift the resources to CPU and system RAM, and save VRAM.
|
||||
|
||||
## I really want to train higher resolution, what do I do?
|
||||
|
||||
|
@ -20,6 +28,5 @@ Gradient checkpointing is pretty useful even on "high" VRAM GPUs like a 24GB 309
|
|||
|
||||
`--gradient_checkpointing` in CLI or in json `"gradient_checkpointing": true`
|
||||
|
||||
It is not suggested on 24GB GPUs at 704 or lower resolutoon. I would keep it off and reduce batch size instead.
|
||||
It is not suggested on 24GB GPUs at 704 or lower resolutoon. I would keep it off and reduce batch size instead to fit your training into VRAM.
|
||||
|
||||
Gradient checkpointing is also critical for lower VRAM GPUs like 16 GB T4 (Colab free tier) or 3060 12GB, 2080 Ti 11gb, etc. You most likely should keep it on for any GPU with less than 24GB and adjust batch size up or down to fit your VRAM.
|
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
Binary file not shown.
After Width: | Height: | Size: 167 KiB |
|
@ -1,7 +1,7 @@
|
|||
###################
|
||||
# Builder Stage
|
||||
FROM nvidia/cuda:11.7.1-devel-ubuntu22.04 AS builder
|
||||
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS builder
|
||||
LABEL org.opencontainers.image.licenses="AGPL-1.0-only"
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Don't write .pyc bytecode
|
||||
|
@ -23,21 +23,16 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
|||
ENV VIRTUAL_ENV=/workspace/venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
|
||||
ADD requirements.txt /build
|
||||
ADD requirements-build.txt /build
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m venv ${VIRTUAL_ENV} && \
|
||||
pip install -U -I torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url "https://download.pytorch.org/whl/cu117" && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install --pre --no-deps xformers==0.0.17.dev451
|
||||
# In case of emergency, build xformers from scratch
|
||||
# export FORCE_CUDA=1 && export TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6" && export CUDA_VISIBLE_DEVICES=0 && \
|
||||
# pip install --no-deps git+https://github.com/facebookresearch/xformers.git@48a77cc#egg=xformers
|
||||
|
||||
pip install -U -I torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url "https://download.pytorch.org/whl/cu118" && \
|
||||
pip install -r requirements-build.txt && \
|
||||
pip install --no-deps xformers==0.0.18
|
||||
|
||||
###################
|
||||
# Runtime Stage
|
||||
FROM nvidia/cuda:11.7.1-runtime-ubuntu22.04 as runtime
|
||||
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 as runtime
|
||||
|
||||
# Use bash shell
|
||||
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
|
||||
|
@ -76,12 +71,17 @@ RUN echo "source ${VIRTUAL_ENV}/bin/activate" >> /root/.bashrc
|
|||
# Workaround for:
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/62
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/73
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda-11.7/targets/x86_64-linux/lib"
|
||||
RUN ln /usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so.11.0 /usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so
|
||||
RUN pip install bitsandbytes==0.37.0
|
||||
ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
|
||||
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudart.so
|
||||
RUN ln /usr/local/cuda/targets/x86_64-linux/lib/libnvrtc.so.11.8.89 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvrtc.so
|
||||
|
||||
ADD requirements-runtime.txt /
|
||||
RUN pip install --no-cache-dir -r requirements-runtime.txt
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CACHEBUST=1
|
||||
RUN git clone https://github.com/victorchall/EveryDream2trainer
|
||||
|
||||
WORKDIR /workspace/EveryDream2trainer
|
||||
RUN python utils/get_yamls.py && \
|
||||
mkdir -p logs && mkdir -p input
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
diffusers[torch]>=0.14.0
|
||||
ninja
|
||||
numpy
|
||||
omegaconf==2.2.3
|
||||
protobuf==3.20.3
|
||||
pyre-extensions==0.0.23
|
||||
pytorch-lightning==1.9.2
|
||||
transformers==4.27.1
|
|
@ -0,0 +1,16 @@
|
|||
aiohttp==3.8.4
|
||||
bitsandbytes==0.37.2
|
||||
colorama==0.4.6
|
||||
compel~=1.1.3
|
||||
ftfy==6.1.1
|
||||
ipyevents
|
||||
ipywidgets
|
||||
jupyter-archive
|
||||
jupyterlab
|
||||
lion-pytorch
|
||||
piexif==1.1.3
|
||||
pyfakefs
|
||||
pynvml==11.5.0
|
||||
speedtest-cli
|
||||
tensorboard==2.12.0
|
||||
wandb
|
|
@ -1,19 +0,0 @@
|
|||
aiohttp==3.8.4
|
||||
colorama==0.4.6
|
||||
diffusers[torch]>=0.13.0
|
||||
ftfy==6.1.1
|
||||
ipyevents
|
||||
ipywidgets
|
||||
jupyter-archive
|
||||
jupyterlab
|
||||
ninja
|
||||
omegaconf==2.2.3
|
||||
piexif==1.1.3
|
||||
protobuf==3.20.3
|
||||
pynvml==11.5.0
|
||||
pyre-extensions==0.0.30
|
||||
pytorch-lightning==1.9.2
|
||||
tensorboard==2.11.0
|
||||
transformers==4.25.1
|
||||
triton>=2.0.0a2
|
||||
wandb
|
|
@ -13,11 +13,13 @@ then
|
|||
service ssh start
|
||||
fi
|
||||
|
||||
tensorboard --logdir /workspace/EveryDream2trainer/logs --host 0.0.0.0 &
|
||||
|
||||
# RunPod JupyterLab
|
||||
if [[ $JUPYTER_PASSWORD ]]
|
||||
then
|
||||
tensorboard --logdir /workspace/EveryDream2trainer/logs --host 0.0.0.0 &
|
||||
jupyter nbextension enable --py widgetsnbextension
|
||||
jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
|
||||
jupyter lab --allow-root --no-browser --port=8888 --ip=* --ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' --ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace/EveryDream2trainer
|
||||
else
|
||||
echo "Container Started"
|
||||
|
|
|
@ -1,17 +1,37 @@
|
|||
{
|
||||
"doc": {
|
||||
"base": "base optimizer configuration for unet and text encoder",
|
||||
"text_encoder_overrides": "text encoder config overrides",
|
||||
"text_encoder_lr_scale": "if LR not set on text encoder, sets the Lr to a multiple of the Base LR. for example, if base `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`.",
|
||||
"-----------------": "-----------------",
|
||||
"optimizer": "adamw, adamw8bit, lion",
|
||||
"optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
|
||||
"lr": "learning rate, if null wil use CLI or main JSON config value",
|
||||
"lr": "learning rate, if null will use CLI or main JSON config value",
|
||||
"lr_scheduler": "'constant' or 'cosine'",
|
||||
"lr_warmup_steps": "number of steps to warmup LR to target LR, if null will use CLI or default a value based on max epochs",
|
||||
"lr_decay_steps": "number of steps to decay LR to zero for cosine, if null will use CLI or default a value based on max epochs",
|
||||
"betas": "exponential decay rates for the moment estimates",
|
||||
"epsilon": "value added to denominator for numerical stability, unused for lion",
|
||||
"weight_decay": "weight decay (L2 penalty)",
|
||||
"text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
|
||||
"weight_decay": "weight decay (L2 penalty)"
|
||||
},
|
||||
"base": {
|
||||
"optimizer": "adamw8bit",
|
||||
"lr": 1e-6,
|
||||
"lr_scheduler": "constant",
|
||||
"lr_decay_steps": null,
|
||||
"lr_warmup_steps": null,
|
||||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.010,
|
||||
"text_encoder_lr_scale": 1.0
|
||||
"weight_decay": 0.010
|
||||
},
|
||||
"text_encoder_overrides": {
|
||||
"optimizer": null,
|
||||
"lr": null,
|
||||
"lr_scheduler": null,
|
||||
"lr_decay_steps": null,
|
||||
"lr_warmup_steps": null,
|
||||
"betas": null,
|
||||
"epsilon": null,
|
||||
"weight_decay": null
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,312 @@
|
|||
"""
|
||||
Copyright [2022-2023] Victor C Hall
|
||||
|
||||
Licensed under the GNU Affero General Public License;
|
||||
You may not use this code except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
from colorama import Fore, Style
|
||||
import pprint
|
||||
|
||||
BETAS_DEFAULT = [0.9, 0.999]
|
||||
EPSILON_DEFAULT = 1e-8
|
||||
WEIGHT_DECAY_DEFAULT = 0.01
|
||||
LR_DEFAULT = 1e-6
|
||||
OPTIMIZER_TE_STATE_FILENAME = "optimizer_te.pt"
|
||||
OPTIMIZER_UNET_STATE_FILENAME = "optimizer_unet.pt"
|
||||
|
||||
class EveryDreamOptimizer():
|
||||
"""
|
||||
Wrapper to manage optimizers
|
||||
resume_ckpt_path: path to resume checkpoint, will try to load state (.pt) files if they exist
|
||||
optimizer_config: config for the optimizers
|
||||
text_encoder: text encoder model parameters
|
||||
unet: unet model parameters
|
||||
"""
|
||||
def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
|
||||
del optimizer_config["doc"]
|
||||
print(f"\n raw optimizer_config:")
|
||||
pprint.pprint(optimizer_config)
|
||||
self.epoch_len = epoch_len
|
||||
self.te_config, self.base_config = self.get_final_optimizer_configs(args, optimizer_config)
|
||||
print(f"final unet optimizer config:")
|
||||
pprint.pprint(self.base_config)
|
||||
print(f"final text encoder optimizer config:")
|
||||
pprint.pprint(self.te_config)
|
||||
|
||||
self.grad_accum = args.grad_accum
|
||||
self.clip_grad_norm = args.clip_grad_norm
|
||||
self.text_encoder_params = text_encoder_params
|
||||
self.unet_params = unet_params
|
||||
|
||||
self.optimizers = []
|
||||
self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, text_encoder_params, unet_params)
|
||||
self.optimizers.append(self.optimizer_te) if self.optimizer_te is not None else None
|
||||
self.optimizers.append(self.optimizer_unet) if self.optimizer_unet is not None else None
|
||||
|
||||
self.lr_schedulers = []
|
||||
schedulers = self.create_lr_schedulers(args, optimizer_config)
|
||||
self.lr_schedulers.extend(schedulers)
|
||||
print(self.lr_schedulers)
|
||||
|
||||
self.load(args.resume_ckpt)
|
||||
|
||||
self.scaler = GradScaler(
|
||||
enabled=args.amp,
|
||||
init_scale=2**17.5,
|
||||
growth_factor=2,
|
||||
backoff_factor=1.0/2,
|
||||
growth_interval=25,
|
||||
)
|
||||
|
||||
logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")
|
||||
|
||||
def step(self, loss, step, global_step):
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
if self.clip_grad_norm is not None:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
|
||||
if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
|
||||
for optimizer in self.optimizers:
|
||||
self.scaler.step(optimizer)
|
||||
|
||||
self.scaler.update()
|
||||
self._zero_grad(set_to_none=True)
|
||||
|
||||
for scheduler in self.lr_schedulers:
|
||||
scheduler.step()
|
||||
|
||||
self._update_grad_scaler(global_step)
|
||||
|
||||
def _zero_grad(self, set_to_none=False):
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.zero_grad(set_to_none=set_to_none)
|
||||
|
||||
def get_scale(self):
|
||||
return self.scaler.get_scale()
|
||||
|
||||
def get_unet_lr(self):
|
||||
return self.optimizer_unet.param_groups[0]['lr'] if self.optimizer_unet is not None else 0
|
||||
|
||||
def get_textenc_lr(self):
|
||||
return self.optimizer_te.param_groups[0]['lr'] if self.optimizer_te is not None else 0
|
||||
|
||||
def save(self, ckpt_path: str):
|
||||
"""
|
||||
Saves the optimizer states to path
|
||||
"""
|
||||
self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME)) if self.optimizer_te is not None else None
|
||||
self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME)) if self.optimizer_unet is not None else None
|
||||
|
||||
def load(self, ckpt_path: str):
|
||||
"""
|
||||
Loads the optimizer states from path
|
||||
"""
|
||||
te_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME)
|
||||
unet_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME)
|
||||
if os.path.exists(te_optimizer_state_path) and self.optimizer_unet is not None:
|
||||
self._load_optimizer(self.optimizer_unet, te_optimizer_state_path)
|
||||
if os.path.exists(unet_optimizer_state_path) and self.optimizer_te is not None:
|
||||
self._load_optimizer(self.optimizer_te, unet_optimizer_state_path)
|
||||
|
||||
def create_optimizers(self, args, text_encoder_params, unet_params):
|
||||
"""
|
||||
creates optimizers from config and args for unet and text encoder
|
||||
returns (optimizer_te, optimizer_unet)
|
||||
"""
|
||||
|
||||
if args.disable_textenc_training:
|
||||
optimizer_te = None
|
||||
else:
|
||||
optimizer_te = self._create_optimizer(args, self.te_config, text_encoder_params)
|
||||
if args.disable_unet_training:
|
||||
optimizer_unet = None
|
||||
else:
|
||||
optimizer_unet = self._create_optimizer(args, self.base_config, unet_params)
|
||||
|
||||
return optimizer_te, optimizer_unet
|
||||
|
||||
def get_final_optimizer_configs(self, args, global_optimizer_config):
|
||||
"""
|
||||
defautls and overrides based on priority of 'primary cli args > base config > text encoder overrides'
|
||||
"""
|
||||
base_config = global_optimizer_config.get("base")
|
||||
te_config = global_optimizer_config.get("text_encoder_overrides")
|
||||
|
||||
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
|
||||
args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
|
||||
|
||||
if args.lr_warmup_steps is None:
|
||||
args.lr_warmup_steps = int(args.lr_decay_steps / 50)
|
||||
|
||||
if args.lr is not None:
|
||||
base_config["lr"] = args.lr
|
||||
|
||||
base_config["optimizer"] = base_config.get("optimizer", None) or "adamw8bit"
|
||||
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
|
||||
base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
|
||||
base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
|
||||
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
|
||||
base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
|
||||
base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
|
||||
|
||||
te_config["lr"] = te_config.get("lr", None) or base_config["lr"]
|
||||
te_config["optimizer"] = te_config.get("optimizer", None) or base_config["optimizer"]
|
||||
te_config["lr_scheduler"] = te_config.get("lr_scheduler", None) or base_config["lr_scheduler"]
|
||||
te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", None) or base_config["lr_warmup_steps"]
|
||||
te_config["lr_decay_steps"] = te_config.get("lr_decay_steps", None) or base_config["lr_decay_steps"]
|
||||
te_config["weight_decay"] = te_config.get("weight_decay", None) or base_config["weight_decay"]
|
||||
te_config["betas"] = te_config.get("betas", None) or base_config["betas"]
|
||||
te_config["epsilon"] = te_config.get("epsilon", None) or base_config["epsilon"]
|
||||
|
||||
return te_config, base_config
|
||||
|
||||
def create_lr_schedulers(self, args, optimizer_config):
|
||||
unet_config = optimizer_config["base"]
|
||||
te_config = optimizer_config["text_encoder_overrides"]
|
||||
|
||||
ret_val = []
|
||||
|
||||
if self.optimizer_te is not None:
|
||||
lr_scheduler = get_scheduler(
|
||||
te_config.get("lr_scheduler", args.lr_scheduler),
|
||||
optimizer=self.optimizer_te,
|
||||
num_warmup_steps=te_config.get("lr_warmup_steps", None),
|
||||
num_training_steps=unet_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
|
||||
)
|
||||
ret_val.append(lr_scheduler)
|
||||
|
||||
if self.optimizer_unet is not None:
|
||||
unet_config = optimizer_config["base"]
|
||||
lr_scheduler = get_scheduler(
|
||||
unet_config["lr_scheduler"],
|
||||
optimizer=self.optimizer_unet,
|
||||
num_warmup_steps=int(unet_config["lr_warmup_steps"]),
|
||||
num_training_steps=int(unet_config["lr_decay_steps"]),
|
||||
)
|
||||
ret_val.append(lr_scheduler)
|
||||
return ret_val
|
||||
|
||||
def _update_grad_scaler(self, global_step):
|
||||
if global_step == 500:
|
||||
factor = 1.8
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(100)
|
||||
if global_step == 1000:
|
||||
factor = 1.6
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(200)
|
||||
if global_step == 2000:
|
||||
factor = 1.3
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(500)
|
||||
if global_step == 4000:
|
||||
factor = 1.15
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(2000)
|
||||
|
||||
@staticmethod
|
||||
def _save_optimizer(optimizer, path: str):
|
||||
"""
|
||||
Saves the optimizer state to specific path/filename
|
||||
"""
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
@staticmethod
|
||||
def _load_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
Loads the optimizer state to an Optimizer object
|
||||
optimizer: torch.optim.Optimizer
|
||||
path: .pt file
|
||||
"""
|
||||
try:
|
||||
optimizer.load_state_dict(torch.load(path))
|
||||
logging.info(f" Loaded optimizer state from {path}")
|
||||
except Exception as e:
|
||||
logging.warning(f"{Fore.LIGHTYELLOW_EX}**Failed to load optimizer state from {path}, optimizer state will not be loaded, \n * Exception: {e}{Style.RESET_ALL}")
|
||||
pass
|
||||
|
||||
def _create_optimizer(self, args, local_optimizer_config, parameters):
|
||||
betas = BETAS_DEFAULT
|
||||
epsilon = EPSILON_DEFAULT
|
||||
weight_decay = WEIGHT_DECAY_DEFAULT
|
||||
opt_class = None
|
||||
optimizer = None
|
||||
|
||||
default_lr = 1e-6
|
||||
curr_lr = args.lr
|
||||
text_encoder_lr_scale = 1.0
|
||||
|
||||
if local_optimizer_config is not None:
|
||||
betas = local_optimizer_config["betas"] or betas
|
||||
epsilon = local_optimizer_config["epsilon"] or epsilon
|
||||
weight_decay = local_optimizer_config["weight_decay"] or weight_decay
|
||||
optimizer_name = local_optimizer_config["optimizer"] or "adamw8bit"
|
||||
curr_lr = local_optimizer_config.get("lr", curr_lr)
|
||||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
|
||||
|
||||
print(f" * Using text encoder LR scale {text_encoder_lr_scale}")
|
||||
|
||||
if curr_lr is None:
|
||||
curr_lr = default_lr
|
||||
logging.warning(f"No LR setting found, defaulting to {default_lr}")
|
||||
|
||||
if optimizer_name:
|
||||
if optimizer_name == "lion":
|
||||
from lion_pytorch import Lion
|
||||
opt_class = Lion
|
||||
optimizer = opt_class(
|
||||
itertools.chain(parameters),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
elif optimizer_name == "adamw":
|
||||
opt_class = torch.optim.AdamW
|
||||
else:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
|
||||
if not optimizer:
|
||||
optimizer = opt_class(
|
||||
itertools.chain(parameters),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=False,
|
||||
)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
|
||||
return optimizer
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
|
@ -0,0 +1,19 @@
|
|||
diffusers[torch]>=0.14.0
|
||||
ninja
|
||||
numpy
|
||||
omegaconf==2.2.3
|
||||
protobuf==3.20.3
|
||||
pyre-extensions==0.0.23
|
||||
pytorch-lightning==1.9.2
|
||||
transformers==4.27.1
|
||||
aiohttp==3.8.4
|
||||
bitsandbytes==0.37.2
|
||||
colorama==0.4.6
|
||||
compel~=1.1.3
|
||||
ftfy==6.1.1
|
||||
lion-pytorch
|
||||
piexif==1.1.3
|
||||
pyfakefs
|
||||
pynvml==11.5.0
|
||||
tensorboard==2.12.0
|
||||
wandb
|
|
@ -58,13 +58,13 @@ class TestResolve(unittest.TestCase):
|
|||
|
||||
def test_directory_resolve_with_str(self):
|
||||
items = resolver.resolve(DATA_PATH, ARGS)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_paths = set(item.pathname for item in items)
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
captions = set(caption.get_caption() for caption in image_captions)
|
||||
|
||||
self.assertEqual(len(items), 3)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH])
|
||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3'])
|
||||
self.assertEqual(image_paths, {IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH})
|
||||
self.assertEqual(captions, {'caption for test1', 'test2', 'test3'})
|
||||
|
||||
undersized_images = list(filter(lambda i: i.is_undersized, items))
|
||||
self.assertEqual(len(undersized_images), 1)
|
||||
|
@ -75,7 +75,7 @@ class TestResolve(unittest.TestCase):
|
|||
'path': DATA_PATH,
|
||||
}
|
||||
|
||||
items = resolver.resolve(data_root_spec, ARGS)
|
||||
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -88,7 +88,7 @@ class TestResolve(unittest.TestCase):
|
|||
self.assertEqual(len(undersized_images), 1)
|
||||
|
||||
def test_json_resolve_with_str(self):
|
||||
items = resolver.resolve(JSON_ROOT_PATH, ARGS)
|
||||
items = sorted(resolver.resolve(JSON_ROOT_PATH, ARGS), key=lambda i: i.pathname)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
@ -124,14 +124,14 @@ class TestResolve(unittest.TestCase):
|
|||
JSON_ROOT_PATH,
|
||||
]
|
||||
|
||||
items = resolver.resolve(data_root_spec, ARGS)
|
||||
items = sorted(resolver.resolve(data_root_spec, ARGS), key=lambda i: i.pathname)
|
||||
image_paths = [item.pathname for item in items]
|
||||
image_captions = [item.caption for item in items]
|
||||
captions = [caption.get_caption() for caption in image_captions]
|
||||
|
||||
self.assertEqual(len(items), 6)
|
||||
self.assertEqual(image_paths, [IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2)
|
||||
self.assertEqual(captions, ['caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'])
|
||||
self.assertEqual(set(image_paths), set([IMAGE_1_PATH, IMAGE_2_PATH, IMAGE_3_PATH] * 2))
|
||||
self.assertEqual(set(captions), {'caption for test1', 'test2', 'test3', 'caption for test1', 'caption for test2', 'test3'})
|
||||
|
||||
undersized_images = list(filter(lambda i: i.is_undersized, items))
|
||||
self.assertEqual(len(undersized_images), 2)
|
|
@ -0,0 +1,380 @@
|
|||
import os
|
||||
from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH
|
||||
|
||||
from textwrap import dedent
|
||||
from pyfakefs.fake_filesystem_unittest import TestCase
|
||||
|
||||
class TestDataset(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
self.setUpPyfakefs()
|
||||
|
||||
def test_a_caption_is_generated_from_image_given_no_other_config(self):
|
||||
self.fs.create_file("image, tag1, tag2.jpg")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image, tag1, tag2.jpg": ImageConfig(main_prompts="image", tags=frozenset([Tag("tag1"), Tag("tag2")]))
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_several_image_formats_are_supported(self):
|
||||
self.fs.create_file("image.JPG")
|
||||
self.fs.create_file("image.jpeg")
|
||||
self.fs.create_file("image.png")
|
||||
self.fs.create_file("image.webp")
|
||||
self.fs.create_file("image.jfif")
|
||||
self.fs.create_file("image.bmp")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
common_cfg = ImageConfig(main_prompts="image")
|
||||
expected = {
|
||||
"./image.JPG": common_cfg,
|
||||
"./image.jpeg": common_cfg,
|
||||
"./image.png": common_cfg,
|
||||
"./image.webp": common_cfg,
|
||||
"./image.jfif": common_cfg,
|
||||
"./image.bmp": common_cfg,
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_captions_can_be_read_from_txt_or_caption_sidecar(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.txt", contents="an image, test, from .txt")
|
||||
self.fs.create_file("image_2.jpg")
|
||||
self.fs.create_file("image_2.caption", contents="an image, test, from .caption")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image_1.jpg": ImageConfig(main_prompts="an image", tags=frozenset([Tag("test"), Tag("from .txt")])),
|
||||
"./image_2.jpg": ImageConfig(main_prompts="an image", tags=frozenset([Tag("test"), Tag("from .caption")]))
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
def test_captions_and_options_can_be_read_from_yaml_sidecar(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.yaml",
|
||||
contents=dedent("""
|
||||
multiply: 2
|
||||
cond_dropout: 0.05
|
||||
flip_p: 0.5
|
||||
caption: "A simple caption, from .yaml"
|
||||
"""))
|
||||
self.fs.create_file("image_2.jpg")
|
||||
self.fs.create_file("image_2.yml",
|
||||
contents=dedent("""
|
||||
flip_p: 0.0
|
||||
caption:
|
||||
main_prompt: A complex caption
|
||||
rating: 1.1
|
||||
max_caption_length: 1024
|
||||
tags:
|
||||
- tag: from .yml
|
||||
- tag: with weight
|
||||
weight: 0.5
|
||||
- tag: 1234.5
|
||||
"""))
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image_1.jpg": ImageConfig(
|
||||
multiply=2,
|
||||
cond_dropout=0.05,
|
||||
flip_p=0.5,
|
||||
main_prompts="A simple caption",
|
||||
tags= { Tag("from .yaml") }
|
||||
),
|
||||
"./image_2.jpg": ImageConfig(
|
||||
flip_p=0.0,
|
||||
rating=1.1,
|
||||
max_caption_length=1024,
|
||||
main_prompts="A complex caption",
|
||||
tags= { Tag("from .yml"), Tag("with weight", weight=0.5), Tag("1234.5") }
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
def test_captions_are_read_from_filename_if_no_main_prompt(self):
|
||||
self.fs.create_file("filename main prompt, filename tag.jpg")
|
||||
self.fs.create_file("filename main prompt, filename tag.yaml",
|
||||
contents=dedent("""
|
||||
caption:
|
||||
tags:
|
||||
- tag: standalone yaml tag
|
||||
"""))
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./filename main prompt, filename tag.jpg": ImageConfig(
|
||||
main_prompts="filename main prompt",
|
||||
tags= [ Tag("filename tag"), Tag("standalone yaml tag") ]
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_multiple_prompts_and_tags_from_multiple_sidecars_are_supported(self):
|
||||
self.fs.create_file("image_1.jpg")
|
||||
self.fs.create_file("image_1.yaml", contents=dedent("""
|
||||
main_prompt:
|
||||
- unique prompt
|
||||
- dupe prompt
|
||||
tags:
|
||||
- from .yaml
|
||||
- dupe tag
|
||||
"""))
|
||||
self.fs.create_file("image_1.txt", contents="also unique prompt, from .txt, dupe tag")
|
||||
self.fs.create_file("image_1.caption", contents="dupe prompt, from .caption")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./image_1.jpg": ImageConfig(
|
||||
main_prompts={ "unique prompt", "also unique prompt", "dupe prompt" },
|
||||
tags={ Tag("from .yaml"), Tag("from .txt"), Tag("from .caption"), Tag("dupe tag") }
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_sidecars_can_also_be_attached_to_local_and_recursive_folders(self):
|
||||
self.fs.create_file("./global.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: global prompt
|
||||
tags:
|
||||
- global tag
|
||||
flip_p: 0.0
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./local.yaml",
|
||||
contents=dedent("""
|
||||
main_prompt: local prompt
|
||||
tags:
|
||||
- tag: local tag
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./arbitrary filename.png")
|
||||
self.fs.create_file("./sub/sub arbitrary filename.png")
|
||||
self.fs.create_file("./sub/sidecar.png")
|
||||
self.fs.create_file("./sub/sidecar.txt",
|
||||
contents="sidecar prompt, sidecar tag")
|
||||
|
||||
self.fs.create_file("./optfile/optfile.png")
|
||||
self.fs.create_file("./optfile/flip_p.txt",
|
||||
contents="0.1234")
|
||||
|
||||
self.fs.create_file("./sub/sub2/global.yaml",
|
||||
contents=dedent("""
|
||||
tags:
|
||||
- tag: sub global tag
|
||||
"""))
|
||||
self.fs.create_file("./sub/sub2/local.yaml",
|
||||
contents=dedent("""
|
||||
tags:
|
||||
- This tag wil not apply to any files
|
||||
"""))
|
||||
self.fs.create_file("./sub/sub2/sub3/xyz.png")
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = {
|
||||
"./arbitrary filename.png": ImageConfig(
|
||||
main_prompts={ 'global prompt', 'local prompt' },
|
||||
tags=[ Tag("global tag"), Tag("local tag") ],
|
||||
flip_p=0.0
|
||||
),
|
||||
"./sub/sub arbitrary filename.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags=[ Tag("global tag") ],
|
||||
flip_p=0.0
|
||||
),
|
||||
"./sub/sidecar.png": ImageConfig(
|
||||
main_prompts={ 'global prompt', 'sidecar prompt' },
|
||||
tags=[ Tag("global tag"), Tag("sidecar tag") ],
|
||||
flip_p=0.0
|
||||
),
|
||||
"./optfile/optfile.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags=[ Tag("global tag") ],
|
||||
flip_p=0.1234
|
||||
),
|
||||
"./sub/sub2/sub3/xyz.png": ImageConfig(
|
||||
main_prompts={ 'global prompt' },
|
||||
tags=[ Tag("global tag"), Tag("sub global tag") ],
|
||||
flip_p=0.0
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_can_load_dataset_from_json_manifest(self):
|
||||
self.fs.create_file("./stuff/image_1.jpg")
|
||||
self.fs.create_file("./stuff/default.caption", contents= "default caption")
|
||||
self.fs.create_file("./other/image_1.jpg")
|
||||
self.fs.create_file("./other/image_2.jpg")
|
||||
self.fs.create_file("./other/image_3.jpg")
|
||||
self.fs.create_file("./manifest.json", contents=dedent("""
|
||||
[
|
||||
{ "image": "./stuff/image_1.jpg", "caption": "./stuff/default.caption" },
|
||||
{ "image": "./other/image_1.jpg", "caption": "other caption" },
|
||||
{
|
||||
"image": "./other/image_2.jpg",
|
||||
"caption": {
|
||||
"main_prompt": "complex caption",
|
||||
"rating": 0.1,
|
||||
"max_caption_length": 1000,
|
||||
"tags": [
|
||||
{"tag": "including"},
|
||||
{"tag": "weighted tag", "weight": 999.9}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"image": "./other/image_3.jpg",
|
||||
"multiply": 2,
|
||||
"flip_p": 0.5,
|
||||
"cond_dropout": 0.01,
|
||||
"main_prompt": [
|
||||
"first caption",
|
||||
"second caption"
|
||||
]
|
||||
}
|
||||
]
|
||||
"""))
|
||||
|
||||
actual = Dataset.from_json("./manifest.json").image_configs
|
||||
expected = {
|
||||
"./stuff/image_1.jpg": ImageConfig( main_prompts={"default caption"} ),
|
||||
"./other/image_1.jpg": ImageConfig( main_prompts={"other caption"} ),
|
||||
"./other/image_2.jpg": ImageConfig(
|
||||
main_prompts={ "complex caption" },
|
||||
rating=0.1,
|
||||
max_caption_length=1000,
|
||||
tags={
|
||||
Tag("including"),
|
||||
Tag("weighted tag", 999.9)
|
||||
}
|
||||
),
|
||||
"./other/image_3.jpg": ImageConfig(
|
||||
main_prompts={ "first caption", "second caption" },
|
||||
multiply=2,
|
||||
flip_p=0.5,
|
||||
cond_dropout=0.01
|
||||
)
|
||||
}
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_original_tag_order_is_retained_in_dataset(self):
|
||||
def get_random_string(length):
|
||||
letters = string.ascii_lowercase
|
||||
return ''.join(random.choice(letters) for _ in range(length))
|
||||
|
||||
import uuid
|
||||
tags=[str(uuid.uuid4()) for _ in range(10000)]
|
||||
caption='main_prompt,'+", ".join(tags)
|
||||
self.fs.create_file("image.png")
|
||||
self.fs.create_file("image.txt", contents=caption)
|
||||
|
||||
actual = Dataset.from_path(".").image_configs
|
||||
|
||||
expected = { "./image.png": ImageConfig( main_prompts="main_prompt", tags=map(Tag.parse, tags)) }
|
||||
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
def test_tag_order_is_retained_in_train_item(self):
|
||||
dataset = Dataset({
|
||||
"1.jpg": ImageConfig(
|
||||
main_prompts="main_prompt",
|
||||
tags=[
|
||||
Tag("xyz"),
|
||||
Tag("abc"),
|
||||
Tag("ijk")
|
||||
])
|
||||
})
|
||||
|
||||
aspects = []
|
||||
actual = dataset.image_train_items(aspects)
|
||||
|
||||
self.assertEqual(len(actual), 1)
|
||||
self.assertEqual(actual[0].pathname, os.path.abspath('1.jpg'))
|
||||
self.assertEqual(actual[0].caption.get_caption(), "main_prompt, xyz, abc, ijk")
|
||||
|
||||
def test_dataset_can_produce_train_items(self):
|
||||
self.fs.create_file("./sub/global.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: global prompt
|
||||
tags:
|
||||
- low prio global tag
|
||||
- tag: high prio global tag
|
||||
weight: 10.0
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./sub/nested/local.yaml",
|
||||
contents=dedent("""
|
||||
tags:
|
||||
- tag: local tag
|
||||
"""))
|
||||
|
||||
self.fs.create_file("./sub/sub.jpg")
|
||||
self.fs.create_file("./sub/sub.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: sub.jpg prompt
|
||||
tags:
|
||||
- sub.jpg tag
|
||||
- another tag
|
||||
- last tag
|
||||
rating: 1.1
|
||||
max_caption_length: 1024
|
||||
multiply: 2
|
||||
flip_p: 0.1
|
||||
cond_dropout: 0.01
|
||||
"""))
|
||||
self.fs.create_file("./sub/nested/nested.jpg")
|
||||
self.fs.create_file("./sub/nested/nested.yaml",
|
||||
contents=dedent("""\
|
||||
main_prompt: nested.jpg prompt
|
||||
tags:
|
||||
- tag: nested.jpg tag
|
||||
weight: 0.1
|
||||
"""))
|
||||
self.fs.create_file("./root.jpg")
|
||||
self.fs.create_file("./root.txt", contents="root.jpg prompt, root.jpg tag")
|
||||
|
||||
aspects = []
|
||||
dataset = Dataset.from_path(".")
|
||||
actual = dataset.image_train_items(aspects)
|
||||
|
||||
self.assertEqual(len(actual), 3)
|
||||
|
||||
|
||||
self.assertEqual(actual[0].pathname, os.path.abspath('root.jpg'))
|
||||
self.assertEqual(actual[0].multiplier, 1.0)
|
||||
self.assertEqual(actual[0].flip.p, 0.0)
|
||||
self.assertIsNone(actual[0].cond_dropout)
|
||||
self.assertEqual(actual[0].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[0].caption.get_caption(), "root.jpg prompt, root.jpg tag")
|
||||
self.assertFalse(actual[0].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[0].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
||||
|
||||
self.assertEqual(actual[1].pathname, os.path.abspath('sub/sub.jpg'))
|
||||
self.assertEqual(actual[1].multiplier, 2.0)
|
||||
self.assertEqual(actual[1].flip.p, 0.1)
|
||||
self.assertEqual(actual[1].cond_dropout, 0.01)
|
||||
self.assertEqual(actual[1].caption.rating(), 1.1)
|
||||
self.assertEqual(actual[1].caption.get_caption(), "sub.jpg prompt, high prio global tag, sub.jpg tag, another tag, last tag, low prio global tag")
|
||||
self.assertTrue(actual[1].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[1].caption._ImageCaption__max_target_length, 1024)
|
||||
|
||||
self.assertEqual(actual[2].pathname, os.path.abspath('sub/nested/nested.jpg'))
|
||||
self.assertEqual(actual[2].multiplier, 1.0)
|
||||
self.assertEqual(actual[2].flip.p, 0.0)
|
||||
self.assertIsNone(actual[2].cond_dropout)
|
||||
self.assertEqual(actual[2].caption.rating(), 1.0)
|
||||
self.assertEqual(actual[2].caption.get_caption(), "nested.jpg prompt, high prio global tag, local tag, low prio global tag, nested.jpg tag")
|
||||
self.assertTrue(actual[2].caption._ImageCaption__use_weights)
|
||||
self.assertEqual(actual[2].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
|
@ -33,39 +33,3 @@ class TestImageCaption(unittest.TestCase):
|
|||
|
||||
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
||||
|
||||
def test_parse(self):
|
||||
caption = ImageCaption.parse("hello world, one, two, three")
|
||||
|
||||
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
|
||||
|
||||
def test_from_file_name(self):
|
||||
caption = ImageCaption.from_file_name("foo bar_1_2_3.jpg")
|
||||
self.assertEqual(caption.get_caption(), "foo bar")
|
||||
|
||||
def test_from_text_file(self):
|
||||
caption = ImageCaption.from_text_file("test/data/test1.txt")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
def test_from_file(self):
|
||||
caption = ImageCaption.from_file("test/data/test1.txt")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
caption = ImageCaption.from_file("test/data/test_caption.caption")
|
||||
self.assertEqual(caption.get_caption(), "caption for test2")
|
||||
|
||||
def test_resolve(self):
|
||||
caption = ImageCaption.resolve("test/data/test1.txt")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
caption = ImageCaption.resolve("test/data/test_caption.caption")
|
||||
self.assertEqual(caption.get_caption(), "caption for test2")
|
||||
|
||||
caption = ImageCaption.resolve("hello world")
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
||||
|
||||
caption = ImageCaption.resolve("test/data/test1.jpg")
|
||||
self.assertEqual(caption.get_caption(), "caption for test1")
|
||||
|
||||
caption = ImageCaption.resolve("test/data/test2.jpg")
|
||||
self.assertEqual(caption.get_caption(), "test2")
|
|
@ -1,11 +1,11 @@
|
|||
{
|
||||
"amp": true,
|
||||
"batch_size": 10,
|
||||
"ckpt_every_n_minutes": null,
|
||||
"clip_grad_norm": null,
|
||||
"clip_skip": 0,
|
||||
"cond_dropout": 0.04,
|
||||
"data_root": "X:\\my_project_data\\project_abc",
|
||||
"disable_amp": false,
|
||||
"disable_textenc_training": false,
|
||||
"disable_xformers": false,
|
||||
"flip_p": 0.0,
|
||||
|
@ -35,7 +35,7 @@
|
|||
"scale_lr": false,
|
||||
"seed": 555,
|
||||
"shuffle_tags": false,
|
||||
"validation_config": null,
|
||||
"validation_config": "validation_default.json",
|
||||
"wandb": false,
|
||||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
|
|
330
train.py
330
train.py
|
@ -27,10 +27,9 @@ import gc
|
|||
import random
|
||||
import traceback
|
||||
import shutil
|
||||
import importlib
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from colorama import Fore, Style
|
||||
import numpy as np
|
||||
|
@ -38,6 +37,7 @@ import itertools
|
|||
import torch
|
||||
import datetime
|
||||
import json
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, \
|
||||
DPMSolverMultistepScheduler
|
||||
|
@ -49,6 +49,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||
from accelerate.utils import set_seed
|
||||
|
||||
import wandb
|
||||
import webbrowser
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
|
||||
|
@ -58,6 +59,8 @@ from data.image_train_item import ImageTrainItem
|
|||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.isolate_rng import isolate_rng
|
||||
from utils.check_git import check_git
|
||||
from optimizer.optimizers import EveryDreamOptimizer
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
|
@ -120,29 +123,26 @@ def setup_local_logger(args):
|
|||
format="%(asctime)s %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
|
||||
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message="UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images")
|
||||
#from PIL import Image
|
||||
|
||||
return datetimestamp
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
||||
# def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
# """
|
||||
# Saves the optimizer state
|
||||
# """
|
||||
# torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
Saves the optimizer state
|
||||
"""
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def load_optimizer(optimizer, path: str):
|
||||
"""
|
||||
Loads the optimizer state
|
||||
"""
|
||||
optimizer.load_state_dict(torch.load(path))
|
||||
# def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
# """
|
||||
# Loads the optimizer state
|
||||
# """
|
||||
# optimizer.load_state_dict(torch.load(path))
|
||||
|
||||
def get_gpu_memory(nvsmi):
|
||||
"""
|
||||
|
@ -178,12 +178,12 @@ def set_args_12gb(args):
|
|||
if not args.gradient_checkpointing:
|
||||
logging.info(" - Overiding gradient checkpointing to True")
|
||||
args.gradient_checkpointing = True
|
||||
if args.batch_size != 1:
|
||||
logging.info(" - Overiding batch size to 1")
|
||||
args.batch_size = 1
|
||||
if args.batch_size > 2:
|
||||
logging.info(" - Overiding batch size to max 2")
|
||||
args.batch_size = 2
|
||||
args.grad_accum = 1
|
||||
if args.resolution > 512:
|
||||
logging.info(" - Overiding resolution to 512")
|
||||
logging.info(" - Overiding resolution to max 512")
|
||||
args.resolution = 512
|
||||
|
||||
def find_last_checkpoint(logdir):
|
||||
|
@ -214,6 +214,12 @@ def setup_args(args):
|
|||
Sets defaults for missing args (possible if missing from json config)
|
||||
Forces some args to be set based on others for compatibility reasons
|
||||
"""
|
||||
if args.disable_amp:
|
||||
logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}")
|
||||
args.amp = False
|
||||
else:
|
||||
args.amp = True
|
||||
|
||||
if args.disable_unet_training and args.disable_textenc_training:
|
||||
raise ValueError("Both unet and textenc are disabled, nothing to train")
|
||||
|
||||
|
@ -255,11 +261,6 @@ def setup_args(args):
|
|||
|
||||
total_batch_size = args.batch_size * args.grad_accum
|
||||
|
||||
if args.scale_lr is not None and args.scale_lr:
|
||||
tmp_lr = args.lr
|
||||
args.lr = args.lr * (total_batch_size**0.55)
|
||||
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
|
||||
|
||||
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
|
||||
os.makedirs(args.save_ckpt_dir)
|
||||
|
||||
|
@ -272,55 +273,56 @@ def setup_args(args):
|
|||
|
||||
return args
|
||||
|
||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||
if global_step == 500:
|
||||
factor = 1.8
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 1000:
|
||||
factor = 1.6
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 2000:
|
||||
factor = 1.3
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
if global_step == 4000:
|
||||
factor = 1.15
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem]) -> None:
|
||||
for item in items:
|
||||
if item.error is not None:
|
||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
logging.error(f" *** exception: {item.error}")
|
||||
|
||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem], batch_size) -> None:
|
||||
undersized_items = [item for item in items if item.is_undersized]
|
||||
|
||||
if len(undersized_items) > 0:
|
||||
underized_log_path = os.path.join(log_folder, "undersized_images.txt")
|
||||
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
|
||||
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
|
||||
with open(underized_log_path, "w") as undersized_images_file:
|
||||
with open(underized_log_path, "w", encoding='utf-8') as undersized_images_file:
|
||||
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
|
||||
for undersized_item in undersized_items:
|
||||
message = f" *** {undersized_item.pathname} with size: {undersized_item.image_size} is smaller than target size: {undersized_item.target_wh}\n"
|
||||
undersized_images_file.write(message)
|
||||
|
||||
def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list[ImageTrainItem]:
|
||||
|
||||
# warn on underfilled aspect ratio buckets
|
||||
|
||||
# Intuition: if there are too few images to fill a batch, duplicates will be appended.
|
||||
# this is not a problem for large image counts but can seriously distort training if there
|
||||
# are just a handful of images for a given aspect ratio.
|
||||
|
||||
# at a dupe ratio of 0.5, all images in this bucket have effective multiplier 1.5,
|
||||
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
|
||||
warn_bucket_dupe_ratio = 0.5
|
||||
|
||||
ar_buckets = set([tuple(i.target_wh) for i in items])
|
||||
for ar_bucket in ar_buckets:
|
||||
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
|
||||
runt_size = batch_size - (count % batch_size)
|
||||
bucket_dupe_ratio = runt_size / count
|
||||
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
|
||||
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
|
||||
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
|
||||
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
|
||||
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
|
||||
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
|
||||
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
|
||||
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
|
||||
|
||||
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
logging.info(" Preloading images...")
|
||||
|
||||
resolved_items = resolver.resolve(args.data_root, args)
|
||||
report_image_train_item_problems(log_folder, resolved_items)
|
||||
image_paths = set(map(lambda item: item.pathname, resolved_items))
|
||||
|
||||
# Remove erroneous items
|
||||
for item in resolved_items:
|
||||
if item.error is not None:
|
||||
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
|
||||
logging.error(f" *** exception: {item.error}")
|
||||
image_train_items = [item for item in resolved_items if item.error is None]
|
||||
print (f" * Found {len(image_paths)} files in '{args.data_root}'")
|
||||
|
||||
|
@ -357,11 +359,8 @@ def main(args):
|
|||
"""
|
||||
log_time = setup_local_logger(args)
|
||||
args = setup_args(args)
|
||||
|
||||
if args.notebook:
|
||||
from tqdm.notebook import tqdm
|
||||
else:
|
||||
from tqdm.auto import tqdm
|
||||
print(f" Args:")
|
||||
pprint.pprint(vars(args))
|
||||
|
||||
if args.seed == -1:
|
||||
args.seed = random.randint(0, 2**30)
|
||||
|
@ -383,7 +382,7 @@ def main(args):
|
|||
os.makedirs(log_folder)
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir, yaml_name, save_full_precision=False):
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
|
@ -421,13 +420,12 @@ def main(args):
|
|||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
# optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||
# if self.save_optimizer_flag:
|
||||
# logging.info(f" Saving optimizer state to {save_path}")
|
||||
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
|
||||
if save_optimizer_flag:
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
ed_optimizer.save(save_path)
|
||||
|
||||
optimizer_state_path = None
|
||||
try:
|
||||
|
||||
# check for a local file
|
||||
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
||||
if os.path.exists(hf_cache_path) or os.path.exists(args.resume_ckpt):
|
||||
|
@ -435,6 +433,10 @@ def main(args):
|
|||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||
|
||||
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
|
||||
if not os.path.exists(optimizer_state_path):
|
||||
optimizer_state_path = None
|
||||
else:
|
||||
# try to download from HF using resume_ckpt as a repo id
|
||||
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
||||
|
@ -468,8 +470,11 @@ def main(args):
|
|||
logging.warning("failed to load xformers, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
pass
|
||||
elif (not args.amp and is_sd1attn):
|
||||
logging.info("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
|
||||
unet.set_attention_slice("auto")
|
||||
else:
|
||||
logging.info("xformers disabled, using attention slicing instead")
|
||||
logging.info("xformers disabled via arg, using attention slicing instead")
|
||||
unet.set_attention_slice("auto")
|
||||
|
||||
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
|
||||
|
@ -485,94 +490,25 @@ def main(args):
|
|||
with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
|
||||
optimizer_config = json.load(f)
|
||||
|
||||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name,
|
||||
sync_tensorboard=True,
|
||||
dir=args.logdir,
|
||||
config={"main":args, "optimizer":optimizer_config},
|
||||
if args.wandb:
|
||||
wandb.tensorboard.patch(root_logdir=log_folder, pytorch=False, tensorboard_x=False, save=False)
|
||||
wandb_run = wandb.init(
|
||||
project=args.project_name,
|
||||
config={"main_cfg": vars(args), "optimizer_cfg": optimizer_config},
|
||||
name=args.run_name,
|
||||
)
|
||||
try:
|
||||
if webbrowser.get():
|
||||
webbrowser.open(wandb_run.url, new=2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
log_writer = SummaryWriter(log_dir=log_folder,
|
||||
flush_secs=10,
|
||||
comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
|
||||
flush_secs=20,
|
||||
comment=args.run_name if args.run_name is not None else log_time,
|
||||
)
|
||||
|
||||
betas = [0.9, 0.999]
|
||||
epsilon = 1e-8
|
||||
weight_decay = 0.01
|
||||
opt_class = None
|
||||
optimizer = None
|
||||
|
||||
default_lr = 1e-6
|
||||
curr_lr = args.lr
|
||||
text_encoder_lr_scale = 1.0
|
||||
|
||||
if optimizer_config is not None:
|
||||
betas = optimizer_config["betas"]
|
||||
epsilon = optimizer_config["epsilon"]
|
||||
weight_decay = optimizer_config["weight_decay"]
|
||||
optimizer_name = optimizer_config["optimizer"]
|
||||
curr_lr = optimizer_config.get("lr", curr_lr)
|
||||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
|
||||
|
||||
text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
|
||||
if text_encoder_lr_scale != 1.0:
|
||||
logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
|
||||
|
||||
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
|
||||
|
||||
if curr_lr is None:
|
||||
curr_lr = default_lr
|
||||
logging.warning(f"No LR setting found, defaulting to {default_lr}")
|
||||
|
||||
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
|
||||
|
||||
if args.disable_textenc_training:
|
||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters())
|
||||
elif args.disable_unet_training:
|
||||
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
|
||||
if text_encoder_lr_scale != 1:
|
||||
logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
|
||||
f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(text_encoder.parameters())
|
||||
else:
|
||||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||
params_to_train = [{'params': unet.parameters()},
|
||||
{'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
|
||||
|
||||
if optimizer_name:
|
||||
if optimizer_name == "lion":
|
||||
from lion_pytorch import Lion
|
||||
opt_class = Lion
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
elif optimizer_name in ["adamw"]:
|
||||
opt_class = torch.optim.AdamW
|
||||
else:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
|
||||
if not optimizer:
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=False,
|
||||
)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
image_train_items = resolve_image_train_items(args)
|
||||
|
||||
validator = None
|
||||
if args.validation_config is not None:
|
||||
|
@ -584,6 +520,8 @@ def main(args):
|
|||
# the validation dataset may need to steal some items from image_train_items
|
||||
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
|
||||
|
||||
report_image_train_item_problems(log_folder, image_train_items, batch_size=args.batch_size)
|
||||
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
|
@ -605,17 +543,7 @@ def main(args):
|
|||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
|
||||
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
|
||||
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
|
||||
|
||||
log_args(log_writer, args)
|
||||
|
||||
|
@ -624,7 +552,9 @@ def main(args):
|
|||
config_file_path=args.sample_prompts,
|
||||
batch_size=max(1,args.batch_size//2),
|
||||
default_sample_steps=args.sample_steps,
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers)
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers,
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2)
|
||||
)
|
||||
|
||||
"""
|
||||
Train the model
|
||||
|
@ -655,7 +585,7 @@ def main(args):
|
|||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, args.save_full_precision)
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
else:
|
||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||
|
@ -687,15 +617,6 @@ def main(args):
|
|||
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
||||
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
||||
|
||||
scaler = GradScaler(
|
||||
enabled=args.amp,
|
||||
init_scale=2**17.5,
|
||||
growth_factor=2,
|
||||
backoff_factor=1.0/2,
|
||||
growth_interval=25,
|
||||
)
|
||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
|
||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||
epoch_times = []
|
||||
|
@ -753,6 +674,7 @@ def main(args):
|
|||
del noise, latents, cuda_caption
|
||||
|
||||
with autocast(enabled=args.amp):
|
||||
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
return model_pred, target
|
||||
|
@ -818,20 +740,7 @@ def main(args):
|
|||
loss_scale = batch["runt_size"] / args.batch_size
|
||||
loss = loss * loss_scale
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if args.clip_grad_norm is not None:
|
||||
if not args.disable_unet_training:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||
if not args.disable_textenc_training:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
||||
|
||||
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
lr_scheduler.step()
|
||||
ed_optimizer.step(loss, step, global_step)
|
||||
|
||||
loss_step = loss.detach().item()
|
||||
|
||||
|
@ -845,23 +754,23 @@ def main(args):
|
|||
loss_epoch.append(loss_step)
|
||||
|
||||
if (global_step + 1) % args.log_step == 0:
|
||||
curr_lr = lr_scheduler.get_last_lr()[0]
|
||||
loss_local = sum(loss_log_step) / len(loss_log_step)
|
||||
lr_unet = ed_optimizer.get_unet_lr()
|
||||
lr_textenc = ed_optimizer.get_textenc_lr()
|
||||
loss_log_step = []
|
||||
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
|
||||
if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
|
||||
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
|
||||
else:
|
||||
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
|
||||
curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
|
||||
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
|
||||
|
||||
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=lr_unet, global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
|
||||
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
sum_img = sum(images_per_sec_log_step)
|
||||
avg = sum_img / len(images_per_sec_log_step)
|
||||
images_per_sec_log_step = []
|
||||
if args.amp:
|
||||
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
|
||||
|
||||
logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
|
||||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -877,16 +786,15 @@ def main(args):
|
|||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
|
||||
del batch
|
||||
global_step += 1
|
||||
update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
|
||||
# end of step
|
||||
|
||||
steps_pbar.close()
|
||||
|
@ -909,7 +817,7 @@ def main(args):
|
|||
# end of training
|
||||
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/last-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -919,7 +827,7 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
|
||||
raise ex
|
||||
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
@ -928,8 +836,8 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
||||
supported_precisions = ['fp16', 'fp32']
|
||||
check_git()
|
||||
supported_resolutions = aspects.get_supported_resolutions()
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||
args, argv = argparser.parse_known_args()
|
||||
|
@ -944,13 +852,14 @@ if __name__ == "__main__":
|
|||
print("No config file specified, using command line args")
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=False, help="Enables automatic mixed precision compute, recommended on")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
|
||||
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
|
||||
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
|
||||
argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
||||
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
||||
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
|
||||
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
|
||||
|
@ -966,7 +875,6 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
|
||||
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
|
||||
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
|
||||
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
||||
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
|
||||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||
|
@ -979,7 +887,6 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--save_ckpts_from_n_epochs", type=int, default=0, help="Only saves checkpoints starting an N epochs, def: 0 (disabled)")
|
||||
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
|
||||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
|
@ -992,6 +899,5 @@ if __name__ == "__main__":
|
|||
|
||||
# load CLI args to overwrite existing config args
|
||||
args = argparser.parse_args(args=argv, namespace=args)
|
||||
print(f" Args:")
|
||||
pprint.pprint(vars(args))
|
||||
|
||||
main(args)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"batch_size": 6,
|
||||
"ckpt_every_n_minutes": null,
|
||||
"clip_grad_norm": null,
|
||||
"clip_skip": 2,
|
||||
"clip_skip": 0,
|
||||
"cond_dropout": 0.04,
|
||||
"data_root": "X:\\my_project_data\\project_abc",
|
||||
"disable_textenc_training": false,
|
||||
|
@ -21,20 +21,23 @@
|
|||
"lr_warmup_steps": null,
|
||||
"max_epochs": 30,
|
||||
"notebook": false,
|
||||
"optimizer_config": "optimizer.json",
|
||||
"project_name": "project_abc_sd21",
|
||||
"resolution": 768,
|
||||
"resume_ckpt": "v2-1_768-nonema-pruned",
|
||||
"sample_prompts": "sample_prompts.txt",
|
||||
"sample_steps": 300,
|
||||
"save_ckpt_dir": null,
|
||||
"save_ckpts_from_n_epochs": 0,
|
||||
"save_every_n_epochs": 20,
|
||||
"save_optimizer": false,
|
||||
"scale_lr": false,
|
||||
"seed": 555,
|
||||
"shuffle_tags": false,
|
||||
"useadam8bit": true,
|
||||
"validation_config": "validation_default.json",
|
||||
"wandb": false,
|
||||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
"rated_dataset_target_dropout_rate": 50
|
||||
"rated_dataset_target_dropout_percent": 50,
|
||||
"zero_frequency_noise_ratio": 0.02
|
||||
}
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
def check_git():
|
||||
import subprocess
|
||||
try:
|
||||
result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
|
||||
branch = result.stdout.strip()
|
||||
|
||||
result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
|
||||
ahead, behind = map(int, result.stdout.split())
|
||||
|
||||
if behind > 0:
|
||||
print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
|
||||
elif ahead > 0:
|
||||
print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
|
||||
else:
|
||||
print(f"** Your branch '{branch}' is up to date with the remote")
|
||||
except:
|
||||
pass
|
|
@ -430,6 +430,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||
else:
|
||||
output_block_list[layer_id] = [layer_name]
|
||||
|
||||
output_block_list = {x : sorted(y) for x, y in output_block_list.items()}
|
||||
|
||||
if len(output_block_list) > 1:
|
||||
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||
|
@ -442,8 +444,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
import logging
|
||||
|
||||
def barename(file):
|
||||
(val, _) = os.path.splitext(os.path.basename(file))
|
||||
return val
|
||||
|
||||
def ext(file):
|
||||
(_, val) = os.path.splitext(os.path.basename(file))
|
||||
return val.lower()
|
||||
|
||||
def same_barename(lhs, rhs):
|
||||
return barename(lhs) == barename(rhs)
|
||||
|
||||
def is_image(file):
|
||||
return ext(file) in {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'}
|
||||
|
||||
def read_text(file):
|
||||
try:
|
||||
with open(file, encoding='utf-8', mode='r') as stream:
|
||||
return stream.read().strip()
|
||||
except Exception as e:
|
||||
logging.warning(f" *** Error reading text file {file}: {e}")
|
||||
|
||||
def read_float(file):
|
||||
try:
|
||||
return float(read_text(file))
|
||||
except Exception as e:
|
||||
logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
|
||||
|
||||
import os
|
||||
|
||||
def walk_and_visit(path, visit_fn, context=None):
|
||||
names = [entry.name for entry in os.scandir(path)]
|
||||
|
||||
dirs = []
|
||||
files = []
|
||||
for name in names:
|
||||
fullname = os.path.join(path, name)
|
||||
|
||||
if str(name).startswith('.'):
|
||||
continue
|
||||
|
||||
if os.path.isdir(fullname):
|
||||
dirs.append(fullname)
|
||||
else:
|
||||
files.append(fullname)
|
||||
|
||||
subcontext = visit_fn(files, context)
|
||||
|
||||
for subdir in dirs:
|
||||
walk_and_visit(subdir, visit_fn, subcontext)
|
|
@ -9,10 +9,12 @@ import torch
|
|||
from PIL import Image, ImageDraw, ImageFont
|
||||
from colorama import Fore, Style
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
|
||||
from torch import FloatTensor
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from compel import Compel
|
||||
|
||||
|
||||
def clean_filename(filename):
|
||||
|
@ -84,7 +86,8 @@ class SampleGenerator:
|
|||
batch_size: int,
|
||||
default_seed: int,
|
||||
default_sample_steps: int,
|
||||
use_xformers: bool):
|
||||
use_xformers: bool,
|
||||
use_penultimate_clip_layer: bool):
|
||||
self.log_folder = log_folder
|
||||
self.log_writer = log_writer
|
||||
self.batch_size = batch_size
|
||||
|
@ -92,6 +95,7 @@ class SampleGenerator:
|
|||
self.use_xformers = use_xformers
|
||||
self.show_progress_bars = False
|
||||
self.generate_pretrain_samples = False
|
||||
self.use_penultimate_clip_layer = use_penultimate_clip_layer
|
||||
|
||||
self.default_resolution = default_resolution
|
||||
self.default_seed = default_seed
|
||||
|
@ -198,6 +202,9 @@ class SampleGenerator:
|
|||
compatibility_test=sample_compatibility_test))
|
||||
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
|
||||
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||
compel = Compel(tokenizer=pipe.tokenizer,
|
||||
text_encoder=pipe.text_encoder,
|
||||
use_penultimate_clip_layer=self.use_penultimate_clip_layer)
|
||||
for batch in batches:
|
||||
prompts = [p.prompt for p in batch]
|
||||
negative_prompts = [p.negative_prompt for p in batch]
|
||||
|
@ -211,8 +218,10 @@ class SampleGenerator:
|
|||
for cfg in self.cfgs:
|
||||
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
||||
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
||||
images = pipe(prompt=prompts,
|
||||
negative_prompt=negative_prompts,
|
||||
prompt_embeds = compel(prompts)
|
||||
negative_prompt_embeds = compel(negative_prompts)
|
||||
images = pipe(prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
num_images_per_prompt=1,
|
||||
guidance_scale=cfg,
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
{
|
||||
"documentation": {
|
||||
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
|
||||
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
|
||||
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
|
||||
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by auto_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'manual_data_root'.",
|
||||
"auto_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"manual_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
|
||||
"extra_manual_datasets": "Dictionary of 'name':'path' pairs defining additional validation datasets to load and log. eg { 'santa_suit': '/path/to/captioned_santa_suit_images', 'flamingo_suit': '/path/to/flamingo_suit_images' }",
|
||||
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
|
||||
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"every_n_epochs": "How often to run validation (1=every epoch).",
|
||||
"seed": "The seed to use when running validation and stabilization passes."
|
||||
"seed": "The seed to use when running validation and stabilization passes.",
|
||||
"use_relative_loss": "logs val/loss as negative relative to first pre-train val/loss value"
|
||||
},
|
||||
"validate_training": true,
|
||||
"val_split_mode": "automatic",
|
||||
"val_data_root": null,
|
||||
"val_split_proportion": 0.15,
|
||||
"auto_split_proportion": 0.15,
|
||||
"manual_data_root": null,
|
||||
"extra_manual_datasets" : {},
|
||||
"stabilize_training_loss": false,
|
||||
"stabilize_split_proportion": 0.15,
|
||||
"every_n_epochs": 1,
|
||||
"seed": 555
|
||||
"seed": 555,
|
||||
"use_relative_loss": false
|
||||
}
|
|
@ -4,7 +4,7 @@ echo should be in venv here
|
|||
cd .
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
|
||||
pip install transformers==4.25.1
|
||||
pip install transformers==4.27.1
|
||||
pip install diffusers[torch]==0.13.0
|
||||
pip install pynvml==11.4.1
|
||||
pip install bitsandbytes==0.35.0
|
||||
|
@ -13,7 +13,7 @@ pip install ftfy==6.1.1
|
|||
pip install aiohttp==3.8.3
|
||||
pip install tensorboard>=2.11.0
|
||||
pip install protobuf==3.20.1
|
||||
pip install wandb==0.13.6
|
||||
pip install wandb==0.14.0
|
||||
pip install pyre-extensions==0.0.23
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall
|
||||
|
@ -22,6 +22,7 @@ pip install OmegaConf==2.2.3
|
|||
pip install numpy==1.23.5
|
||||
pip install keyboard
|
||||
pip install lion-pytorch
|
||||
pip install compel~=1.1.3
|
||||
python utils/patch_bnb.py
|
||||
python utils/get_yamls.py
|
||||
GOTO :eof
|
||||
|
|
Loading…
Reference in New Issue