Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer
This commit is contained in:
commit
0e186a2a12
|
@ -14,3 +14,5 @@
|
|||
.ssh_config
|
||||
*inference*.yaml
|
||||
.idea
|
||||
/.cache
|
||||
/models
|
||||
|
|
|
@ -16,35 +16,79 @@
|
|||
"id": "blaLMSbkPHhG"
|
||||
},
|
||||
"source": [
|
||||
"# EveryDream2 Colab Edition\n",
|
||||
"<p align=\"center\">\n",
|
||||
" <img src=\"https://github.com/victorchall/EveryDream2trainer/blob/562c4341137d1d9f5bf525e6c56fb4b1eefa2b57/doc/ed_logo_comp.jpg?raw=true\" width=\"600\" height=\"300\">\n",
|
||||
"</p>\n",
|
||||
"\n",
|
||||
"Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
|
||||
"<br>\n",
|
||||
"\n",
|
||||
"And join the discord: https://discord.gg/uheqxU6sXN"
|
||||
"---\n",
|
||||
"\n",
|
||||
"<div align=\"center\">\n",
|
||||
" <font size=\"6\" color=\"yellow\">Colab Edition</font>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"<br>\n",
|
||||
"\n",
|
||||
"Check out the **EveryDream2trainer** documentation and runpod/vastai and local setups here: \n",
|
||||
"\n",
|
||||
"[📚 **Documentation**](https://github.com/victorchall/EveryDream2trainer#docs)\n",
|
||||
"\n",
|
||||
"And join our vibrant community on Discord:\n",
|
||||
"\n",
|
||||
"[💬 **Join the Discord**](https://discord.gg/uheqxU6sXN)\n",
|
||||
"\n",
|
||||
"If you find this tool useful, please consider subscribing to the project on Patreon or making a one-time donation on Ko-fi. Your donations keep this project alive as a free open-source tool with ongoing enhancements.\n",
|
||||
"\n",
|
||||
"<br>\n",
|
||||
"\n",
|
||||
"<p align=\"center\">\n",
|
||||
" <a href=\"https://www.patreon.com/everydream\">\n",
|
||||
" <img src=\"https://github.com/victorchall/EveryDream2trainer/raw/main/.github/patreon-medium-button.png?raw=true\" width=\"200\" height=\"50\">\n",
|
||||
" </a>\n",
|
||||
"</p>\n",
|
||||
"\n",
|
||||
"<br>\n",
|
||||
"\n",
|
||||
"<p align=\"center\">\n",
|
||||
" <a href=\"https://ko-fi.com/everydream\">\n",
|
||||
" <img src=\"https://github.com/victorchall/EveryDream2trainer/raw/main/.github/kofibutton_sm.png?raw=true\" width=\"75\" height=\"75\">\n",
|
||||
" </a>\n",
|
||||
"</p>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hAuBbtSvGpau",
|
||||
"cellView": "form"
|
||||
"cellView": "form",
|
||||
"id": "hAuBbtSvGpau"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@markdown # Setup and Install Dependencies\n",
|
||||
"from IPython.display import clear_output\n",
|
||||
"from IPython.display import clear_output, display, HTML\n",
|
||||
"import subprocess\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"import time\n",
|
||||
"import os \n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"import PIL\n",
|
||||
"\n",
|
||||
"# Defining function for colored text\n",
|
||||
"def colored(r, g, b, text):\n",
|
||||
" return f\"\\033[38;2;{r};{g};{b}m{text} \\033[38;2;255;255;255m\"\n",
|
||||
"\n",
|
||||
"#@markdown Optional connect Gdrive But strongly recommended\n",
|
||||
"#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
|
||||
"#@markdown This will let you put all your training data and checkpoints directly on your drive. \n",
|
||||
"#@markdown Much faster/easier to continue later, less setup time.\n",
|
||||
"\n",
|
||||
"#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
|
||||
"Mount_to_Gdrive = True #@param{type:\"boolean\"} \n",
|
||||
"\n",
|
||||
"# Clone the git repository\n",
|
||||
"print(colored(0, 255, 0, 'Cloning git repository...'))\n",
|
||||
"!git clone https://github.com/victorchall/EveryDream2trainer.git\n",
|
||||
"\n",
|
||||
"if Mount_to_Gdrive:\n",
|
||||
|
@ -55,35 +99,37 @@
|
|||
"\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"\n",
|
||||
"# Download Arial font\n",
|
||||
"print(colored(0, 255, 0, 'Downloading Arial font...'))\n",
|
||||
"!wget -O arial.ttf https://raw.githubusercontent.com/matomo-org/travis-scripts/master/fonts/Arial.ttf\n",
|
||||
"\n",
|
||||
"!cp /content/arial.ttf /usr/share/fonts/truetype/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"packages = [\n",
|
||||
" 'transformers==4.27.1',\n",
|
||||
" 'transformers==4.29.2',\n",
|
||||
" 'diffusers[torch]==0.14.0',\n",
|
||||
" 'pynvml==11.4.1',\n",
|
||||
" 'bitsandbytes==0.37.2',\n",
|
||||
" 'ftfy==6.1.1',\n",
|
||||
" 'aiohttp==3.8.4',\n",
|
||||
" 'compel~=1.1.3',\n",
|
||||
" 'protobuf==3.20.3',\n",
|
||||
" 'wandb==0.13.6',\n",
|
||||
" 'pyre-extensions==0.0.23',\n",
|
||||
" '--no-deps xformers==0.0.19',\n",
|
||||
" 'pytorch-lightning==1.9.2',\n",
|
||||
" 'protobuf==3.20.1',\n",
|
||||
" 'wandb==0.15.3',\n",
|
||||
" 'pyre-extensions==0.0.29',\n",
|
||||
" 'xformers==0.0.20',\n",
|
||||
" 'pytorch-lightning==1.6.5',\n",
|
||||
" 'OmegaConf==2.2.3',\n",
|
||||
" 'tensorboard>=2.11.0',\n",
|
||||
" 'tensorrt'\n",
|
||||
" 'wandb',\n",
|
||||
" 'colorama',\n",
|
||||
" 'keyboard',\n",
|
||||
" 'lion-pytorch'\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for package in tqdm(packages, desc='Installing packages', unit='package'):\n",
|
||||
"print(colored(0, 255, 0, 'Installing packages...'))\n",
|
||||
"for package in tqdm(packages, desc='Installing packages', unit='package', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}'):\n",
|
||||
" if isinstance(package, tuple):\n",
|
||||
" package_name, extra_index_url = package\n",
|
||||
" cmd = f\"pip install -q {package_name} --extra-index-url {extra_index_url}\"\n",
|
||||
" cmd = f\"pip install -I -q {package_name} --extra-index-url {extra_index_url}\"\n",
|
||||
" else:\n",
|
||||
" cmd = f\"pip install -q {package}\"\n",
|
||||
" \n",
|
||||
|
@ -92,28 +138,62 @@
|
|||
"clear_output()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Execute Python script\n",
|
||||
"print(colored(0, 255, 0, 'Executing Python script...'))\n",
|
||||
"!python utils/get_yamls.py\n",
|
||||
"clear_output()\n",
|
||||
"## ty Google for cutting out install time by 50%\n",
|
||||
"print(\"DONE! installing dependencies.\")\n",
|
||||
"GPU = !nvidia-smi\n",
|
||||
"print(\"GPU details:\")\n",
|
||||
"for line in GPU:\n",
|
||||
" print(line)\n",
|
||||
"\n",
|
||||
"print(colored(0, 255, 0, \"DONE! installing dependencies.\"))\n",
|
||||
"\n",
|
||||
"# Import pynvml and get GPU details\n",
|
||||
"import pynvml\n",
|
||||
"\n",
|
||||
"pynvml.nvmlInit()\n",
|
||||
"\n",
|
||||
"handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
|
||||
"\n",
|
||||
"gpu_name = pynvml.nvmlDeviceGetName(handle)\n",
|
||||
"gpu_memory = pynvml.nvmlDeviceGetMemoryInfo(handle).total / 1024**3\n",
|
||||
"cuda_version_number = pynvml.nvmlSystemGetCudaDriverVersion_v2()\n",
|
||||
"cuda_version_major = cuda_version_number // 1000\n",
|
||||
"cuda_version_minor = (cuda_version_number % 1000) // 10\n",
|
||||
"cuda_version = f\"{cuda_version_major}.{cuda_version_minor}\"\n",
|
||||
"\n",
|
||||
"pynvml.nvmlShutdown()\n",
|
||||
"\n",
|
||||
"Python_version = !python --version\n",
|
||||
"print(\"\\nPython version:\")\n",
|
||||
"print(Python_version[0])\n",
|
||||
"import torch\n",
|
||||
"print(\"\\nPyTorch version:\")\n",
|
||||
"print(torch.__version__)\n",
|
||||
"import torchvision\n",
|
||||
"print(\"\\nTorchvision version:\")\n",
|
||||
"print(torchvision.__version__)\n",
|
||||
"import xformers\n",
|
||||
"print(\"\\nXFormers version:\")\n",
|
||||
"print(xformers.__version__)\n",
|
||||
"time.sleep(2)\n"
|
||||
"\n",
|
||||
"display(HTML(f\"\"\"\n",
|
||||
"<table style=\"background-color:transparent;\">\n",
|
||||
" <tr>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">Python version:</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{Python_version[0]}</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">GPU Name:</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{gpu_name}</span></td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">PyTorch version:</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{torch.__version__}</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">GPU Memory (GB):</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{gpu_memory:.2f}</span></td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">Torchvision version:</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{torchvision.__version__}</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">CUDA version:</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{cuda_version}</span></td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">XFormers version:</span></td>\n",
|
||||
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{xformers.__version__}</span></td>\n",
|
||||
" </tr>\n",
|
||||
"</table>\n",
|
||||
"\"\"\"))\n",
|
||||
"\n",
|
||||
"time.sleep(2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -127,43 +207,45 @@
|
|||
"source": [
|
||||
"#@title Get A Base Model\n",
|
||||
"#@markdown Choose SD1.5, Waifu Diffusion 1.3, SD2.1, or 2.1(512) from the dropdown, or paste your own URL in the box\n",
|
||||
"#@markdown * Alternately you can link to a HF repo using NAME/MODEL, this does not save to your Gdrive, if you want to save an hf model use the direct url\n",
|
||||
"#@markdown * Alternately you can link to an HF repo using NAME/MODEL, this does not save to your Gdrive, if you want to save an HF model, use the direct URL\n",
|
||||
"\n",
|
||||
"#@markdown * Link to a set of diffusers on your Gdrive\n",
|
||||
"\n",
|
||||
"#@markdown * Paste a url, atm there is no support for .safetensors\n",
|
||||
"#@markdown * Paste a URL, atm there is no support for .safetensors\n",
|
||||
"\n",
|
||||
"from IPython.display import clear_output\n",
|
||||
"!mkdir input\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"MODEL_LOCATION = \"sd_v1-5+vae.ckpt\" #@param [\"sd_v1-5+vae.ckpt\", \"hakurei/waifu-diffusion-v1-3\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
|
||||
"MODEL_LOCATION = \"panopstor/EveryDream\" #@param [\"sd_v1-5+vae.ckpt\", \"hakurei/waifu-diffusion-v1-3\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
|
||||
"\n",
|
||||
"if MODEL_LOCATION == \"sd_v1-5+vae.ckpt\":\n",
|
||||
" MODEL_LOCATION = \"panopstor/EveryDream\"\n",
|
||||
"Flag = False\n",
|
||||
"\n",
|
||||
"If_Ckpt = False\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"download_path = \"\"\n",
|
||||
"\n",
|
||||
"if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION: #maybe just add a radio button to download this should work for now\n",
|
||||
" print(\"Downloading \")\n",
|
||||
"if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION:\n",
|
||||
" MODEL_URL = MODEL_LOCATION\n",
|
||||
" print(\"Downloading...\")\n",
|
||||
" !wget $MODEL_LOCATION\n",
|
||||
" clear_output()\n",
|
||||
" print(\"DONE!\")\n",
|
||||
" print(\"Download completed!\")\n",
|
||||
" download_path = os.path.join(os.getcwd(), os.path.basename(MODEL_URL))\n",
|
||||
"\n",
|
||||
"else:\n",
|
||||
" save_name = MODEL_LOCATION\n",
|
||||
" save_name = MODEL_LOCATION\n",
|
||||
"\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"#@markdown * If you chose to link to a .ckpt Select the correct model version in the drop down menu for conversion\n",
|
||||
"\n",
|
||||
"inference_yaml = \" \"\n",
|
||||
"\n",
|
||||
"# Check if the downloaded or copied model is a .ckpt file\n",
|
||||
"#@markdown Is the model 1.5 or 2.1 based?\n",
|
||||
"model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
|
||||
"\n",
|
||||
"if download_path.endswith(\".ckpt\") or MODEL_LOCATION.endswith(\".ckpt\"):\n",
|
||||
" Flag = True\n",
|
||||
" model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
|
||||
" If_Ckpt = True\n",
|
||||
" save_path = download_path\n",
|
||||
" if \".ckpt\" in save_name:\n",
|
||||
" save_name = save_name.replace(\".ckpt\", \"\")\n",
|
||||
|
@ -171,6 +253,7 @@
|
|||
" img_size = 512\n",
|
||||
" upscale_attention = False\n",
|
||||
" prediction_type = \"epsilon\"\n",
|
||||
"\n",
|
||||
" if model_type == \"SD1x\":\n",
|
||||
" inference_yaml = \"v1-inference.yaml\"\n",
|
||||
" elif model_type == \"SD2_512_base\":\n",
|
||||
|
@ -182,7 +265,7 @@
|
|||
" inference_yaml = \"v2-inference-v.yaml\"\n",
|
||||
" img_size = 768\n",
|
||||
"\n",
|
||||
" !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
|
||||
" !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
|
||||
" --original_config_file $inference_yaml \\\n",
|
||||
" --image_size $img_size \\\n",
|
||||
" --checkpoint_path $MODEL_LOCATION \\\n",
|
||||
|
@ -190,13 +273,14 @@
|
|||
" --upcast_attn False \\\n",
|
||||
" --dump_path $save_name\n",
|
||||
"\n",
|
||||
" # Set the save path to the GDrive directory if cache_to_gdrive is True\n",
|
||||
"# Set the save path to the GDrive directory if cache_to_gdrive is True\n",
|
||||
"if If_Ckpt:\n",
|
||||
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
|
||||
"\n",
|
||||
"if Flag:\n",
|
||||
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
|
||||
"if inference_yaml != \" \":\n",
|
||||
" print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
|
||||
"print(\"Model \" + save_name + \" will be used!\")\n"
|
||||
" print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
|
||||
"\n",
|
||||
"print(\"Model \" + save_name + \" will be used!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -245,20 +329,27 @@
|
|||
"#@markdown * Name your project so you can find it in your logs\n",
|
||||
"Project_Name = \"My_Project\" #@param{type: 'string'}\n",
|
||||
"\n",
|
||||
"# Load the JSON file\n",
|
||||
"with open('optimizer.json', 'r') as file:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if model_type == 'SD2_512_base' or model_type == 'SD21':\n",
|
||||
" file_path = \"/content/EveryDream2trainer/optimizerSD21.json\"\n",
|
||||
"else:\n",
|
||||
" file_path = \"/content/EveryDream2trainer/optimizer.json\"\n",
|
||||
"\n",
|
||||
"with open(file_path, 'r') as file:\n",
|
||||
" data = json.load(file)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data and not wreck the model. \n",
|
||||
"#@markdown Leave this default unless you are very comfortable with training and know what you are doing.\n",
|
||||
"Learning_Rate = 1e-6 #@param{type: 'number'}\n",
|
||||
"#@markdown * chosing this will allow you to ignore any settings specific to the text encode and will match it with the Unets settings, recommended for beginers.\n",
|
||||
"Match_text_to_Unet = False #@param{type:\"boolean\"}\n",
|
||||
"Text_lr = 0.5e-6 #@param {type:\"number\"}\n",
|
||||
"Text_lr = 5e-7 #@param {type:\"number\"}\n",
|
||||
"#@markdown * A learning rate scheduler can change your learning rate as training progresses.\n",
|
||||
"#@markdown * I recommend sticking with constant until you are comfortable with general training. \n",
|
||||
"Schedule = \"constant\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
|
||||
"Text_lr_scheduler = \"constant\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
|
||||
"Schedule = \"linear\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
|
||||
"Text_lr_scheduler = \"linear\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
|
||||
"#@markdown * warm up steps are useful for validation and cosine lrs\n",
|
||||
"lr_warmup_steps = 0 #@param{type:\"integer\"}\n",
|
||||
"lr_decay_steps = 0 #@param {type:\"number\"} \n",
|
||||
|
@ -280,7 +371,7 @@
|
|||
"data['text_encoder_overrides']['lr_decay_steps'] = Text_lr_decay_steps\n",
|
||||
"\n",
|
||||
"# Save the updated JSON data back to the file\n",
|
||||
"with open('optimizer.json', 'w') as file:\n",
|
||||
"with open(file_path, 'w') as file:\n",
|
||||
" json.dump(data, file, indent=4)\n",
|
||||
"\n",
|
||||
"#@markdown * Resolution to train at (recommend 512). Higher resolution will require lower batch size (below).\n",
|
||||
|
@ -302,11 +393,12 @@
|
|||
"#@markdown * Location on your Gdrive where your training images are.\n",
|
||||
"Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
|
||||
"\n",
|
||||
"model = save_name\n",
|
||||
"if not resume:\n",
|
||||
" model = save_name\n",
|
||||
"\n",
|
||||
"#@markdown * Max Epochs to train for, this defines how many total times all your training data is used. Default of 100 is a good start if you are training ~30-40 images of one subject. If you have 100 images, you can reduce this to 40-50 and so forth.\n",
|
||||
"\n",
|
||||
"Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:5}\n",
|
||||
"Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:1}\n",
|
||||
"\n",
|
||||
"#@markdown * How often to save checkpoints.\n",
|
||||
"Save_every_N_epoch = 20 #@param{type:\"integer\"}\n",
|
||||
|
@ -361,7 +453,7 @@
|
|||
"#@markdown use validation with wandb\n",
|
||||
"\n",
|
||||
"validatation = False #@param{type:\"boolean\"}\n",
|
||||
"\n",
|
||||
"Hide_Warnings = False #@param {type:\"boolean\"}\n",
|
||||
"\n",
|
||||
"extensions = ['.zip', '.7z', '.rar', '.tgz']\n",
|
||||
"uncompressed_dir = 'Training_Data'\n",
|
||||
|
@ -394,6 +486,11 @@
|
|||
"if validatation:\n",
|
||||
" validate = \"--validation_config validation_default.json\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if Hide_Warnings:\n",
|
||||
" import warnings\n",
|
||||
" warnings.filterwarnings(\"ignore\")\n",
|
||||
"\n",
|
||||
"wandb_settings = \"\"\n",
|
||||
"if wandb_token:\n",
|
||||
" !rm /root/.netrc\n",
|
||||
|
@ -446,7 +543,7 @@
|
|||
" --zero_frequency_noise_ratio $zero_frequency_noise\n",
|
||||
"\n",
|
||||
"# Finish the training process\n",
|
||||
"clear_output()\n",
|
||||
"# clear_output()\n",
|
||||
"time.sleep(2)\n",
|
||||
"print(\"Training is complete, select a model to start training again\")\n",
|
||||
"time.sleep(2)\n",
|
||||
|
@ -456,8 +553,7 @@
|
|||
" time.sleep(40)\n",
|
||||
" runtime.unassign()\n",
|
||||
"\n",
|
||||
"os.kill(os.getpid(), 9)\n",
|
||||
"\n"
|
||||
"os.kill(os.getpid(), 9)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -519,15 +615,24 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Optional NoteBook Features, read all the documentation in /content/EveryDream2trainer/doc before proceeding."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fzXLJVC6OCeP"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Optional NoteBook Features, read all the documentation in /content/EveryDream2trainer/doc before proceeding."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"colab": {
|
||||
"background_save": true
|
||||
},
|
||||
"id": "BafdWaYymg0O"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Remove logs for samples when training (optional) run before training\n",
|
||||
"file_path = \"/content/EveryDream2trainer/utils/sample_generator.py\"\n",
|
||||
|
@ -550,170 +655,8 @@
|
|||
"with open(file_path, \"w\") as file:\n",
|
||||
" file.write(content)\n",
|
||||
"\n",
|
||||
"print(\"The specified code block has been deleted.\")\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "BafdWaYymg0O",
|
||||
"cellView": "form"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"#@title train.json Editor { display-mode: \"form\" }\n",
|
||||
"#title json Editor for ED2\n",
|
||||
"\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"data = {\n",
|
||||
" \"disable_textenc_training\": False,\n",
|
||||
" \"disable_xformers\": False,\n",
|
||||
" \"disable_amp\": False,\n",
|
||||
" \"save_optimizer\": False,\n",
|
||||
" \"gradient_checkpointing\": True,\n",
|
||||
" \"wandb\": False,\n",
|
||||
" \"write_schedule\": False,\n",
|
||||
" \"rated_dataset\": False,\n",
|
||||
" \"batch_size\": 10,\n",
|
||||
" \"ckpt_every_n_minutes\": None,\n",
|
||||
" \"clip_grad_norm\": None,\n",
|
||||
" \"clip_skip\": 0,\n",
|
||||
" \"cond_dropout\": 0.04,\n",
|
||||
" \"data_root\": \"X:\\\\my_project_data\\\\project_abc\",\n",
|
||||
" \"flip_p\": 0.0,\n",
|
||||
" \"gpuid\": 0,\n",
|
||||
" \"grad_accum\": 1,\n",
|
||||
" \"logdir\": \"logs\",\n",
|
||||
" \"log_step\": 25,\n",
|
||||
" \"lr\": 1.5e-06,\n",
|
||||
" \"lr_decay_steps\": 0,\n",
|
||||
" \"lr_scheduler\": \"constant\",\n",
|
||||
" \"lr_warmup_steps\": None,\n",
|
||||
" \"max_epochs\": 30,\n",
|
||||
" \"optimizer_config\": \"optimizer.json\",\n",
|
||||
" \"project_name\": \"project_abc\",\n",
|
||||
" \"resolution\": 512,\n",
|
||||
" \"resume_ckpt\": \"sd_v1-5_vae\",\n",
|
||||
" \"run_name\": None,\n",
|
||||
" \"sample_prompts\": \"sample_prompts.txt\",\n",
|
||||
" \"sample_steps\": 300,\n",
|
||||
" \"save_ckpt_dir\": None,\n",
|
||||
" \"save_ckpts_from_n_epochs\": 0,\n",
|
||||
" \"save_every_n_epochs\": 20,\n",
|
||||
" \"seed\": 555,\n",
|
||||
" \"shuffle_tags\": False,\n",
|
||||
" \"validation_config\": \"validation_default.json\",\n",
|
||||
" \"rated_dataset_target_dropout_percent\": 50,\n",
|
||||
" \"zero_frequency_noise_ratio\": 0.02\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
"#@markdown JSON Parameters\n",
|
||||
"findlast = \"\" \n",
|
||||
"Resume_Last_Training_session = False #@param {type:\"boolean\"}\n",
|
||||
"findlast == Resume_Last_Training_session\n",
|
||||
"disable_textenc_training = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"disable_textenc_training\"] = disable_textenc_training\n",
|
||||
"disable_xformers = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"disable_xformers\"] = disable_xformers\n",
|
||||
"gradient_checkpointing = True #@param {type:\"boolean\"}\n",
|
||||
"data[\"gradient_checkpointing\"] = gradient_checkpointing\n",
|
||||
"save_optimizer = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"save_optimizer\"] = save_optimizer \n",
|
||||
"scale_lr = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"scale_lr\"] = scale_lr\n",
|
||||
"shuffle_tags = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"shuffle_tags\"] = shuffle_tags\n",
|
||||
"wandb = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"wandb\"] = wandb\n",
|
||||
"write_schedule = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"write_schedule\"] = write_schedule\n",
|
||||
"rated_dataset = False #@param {type:\"boolean\"}\n",
|
||||
"data[\"rated_dataset\"] = rated_dataset \n",
|
||||
"batch_size = 8 #@param {type:\"integer\"}\n",
|
||||
"data[\"batch_size\"] = batch_size\n",
|
||||
"ckpt_every_n_minutes = None #@param {type:\"raw\"}\n",
|
||||
"data[\"ckpt_every_n_minutes\"] = ckpt_every_n_minutes\n",
|
||||
"clip_grad_norm = None #@param {type:\"raw\"}\n",
|
||||
"data[\"clip_grad_norm\"] = clip_grad_norm\n",
|
||||
"clip_skip = 0 #@param {type:\"integer\"}\n",
|
||||
"data[\"clip_skip\"] = clip_skip\n",
|
||||
"cond_dropout = 0.04 #@param {type:\"number\"}\n",
|
||||
"data[\"cond_dropout\"] = cond_dropout\n",
|
||||
"data_root = \"X:\\\\my_project_data\\\\project_abc\" #@param {type:\"string\"}\n",
|
||||
"data[\"data_root\"] = data_root\n",
|
||||
"flip_p = 0.0 #@param {type:\"number\"}\n",
|
||||
"data[\"flip_p\"] = flip_p\n",
|
||||
"grad_accum = 1 #@param {type:\"integer\"}\n",
|
||||
"data[\"grad_accum\"] = grad_accum\n",
|
||||
"logdir = \"logs\" #@param {type:\"string\"}\n",
|
||||
"data[\"logdir\"] = logdir\n",
|
||||
"log_step = 25 #@param {type:\"integer\"}\n",
|
||||
"data[\"log_step\"] = log_step\n",
|
||||
"lr = 1.5e-06 #@param {type:\"number\"}\n",
|
||||
"data[\"lr\"] = lr\n",
|
||||
"lr_decay_steps = 0 #@param {type:\"integer\"}\n",
|
||||
"data[\"lr_decay_steps\"] = lr_decay_steps\n",
|
||||
"lr_scheduler = \"constant\" #@param {type:\"string\"}\n",
|
||||
"data[\"lr_scheduler\"] = lr_scheduler\n",
|
||||
"lr_warmup_steps = None #@param {type:\"raw\"}\n",
|
||||
"data[\"lr_warmup_steps\"] = lr_warmup_steps\n",
|
||||
"max_epochs = 100 #@param {type:\"integer\"}\n",
|
||||
"data[\"max_epochs\"] = max_epochs\n",
|
||||
"optimizer_config = \"optimizer.json\" #@param {type:\"string\"}\n",
|
||||
"data[\"optimizer_config\"] = optimizer_config\n",
|
||||
"project_name = \"project_abc\" #@param {type:\"string\"}\n",
|
||||
"data[\"project_name\"] = project_name\n",
|
||||
"resolution = 512 #@param {type:\"integer\"}\n",
|
||||
"data[\"resolution\"] = resolution\n",
|
||||
"resume_ckpt = \"sd_v1-5_vae\" #@param {type:\"string\"}\n",
|
||||
"if findlast:\n",
|
||||
" resume_ckpt = \"findlast\"\n",
|
||||
"data[\"resume_ckpt\"] = resume_ckpt\n",
|
||||
"run_name = None #@param {type:\"raw\"}\n",
|
||||
"data[\"run_name\"] = run_name\n",
|
||||
"sample_prompts = \"sample_prompts.txt\" #@param [\"sample_prompts.txt\", \"sample_prompts.json\"]\n",
|
||||
"data[\"sample_prompts\"] = sample_prompts\n",
|
||||
"sample_steps = 300 #@param {type:\"integer\"}\n",
|
||||
"data[\"sample_steps\"] = sample_steps\n",
|
||||
"save_ckpt_dir = None #@param {type:\"raw\"}\n",
|
||||
"data[\"save_ckpt_dir\"] = save_ckpt_dir\n",
|
||||
"save_ckpts_from_n_epochs = 0 #@param {type:\"integer\"}\n",
|
||||
"data[\"save_ckpts_from_n_epochs\"] = save_ckpts_from_n_epochs\n",
|
||||
"save_every_n_epochs = 20 #@param {type:\"integer\"}\n",
|
||||
"data[\"save_every_n_epochs\"] = save_every_n_epochs\n",
|
||||
"seed = 555 #@param {type:\"integer\"}\n",
|
||||
"data[\"seed\"] = seed\n",
|
||||
"validation_config = \"validation_default.json\" #@param {type:\"string\"}\n",
|
||||
"data[\"validation_config\"] = validation_config\n",
|
||||
"rated_dataset_target_dropout_percent = 50 #@param {type:\"integer\"}\n",
|
||||
"data[\"rated_dataset_target_dropout_percent\"] = rated_dataset_target_dropout_percent\n",
|
||||
"zero_frequency_noise_ratio = 0.02 #@param {type:\"number\"}\n",
|
||||
"data[\"zero_frequency_noise_ratio\"] = zero_frequency_noise_ratio\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Display the modified JSON data\n",
|
||||
"print(\"Modified JSON data:\")\n",
|
||||
"print(json.dumps(data, indent=2))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Save the modified JSON data to a file\n",
|
||||
"filename = \"train.json\" #@param {type:\"string\"}\n",
|
||||
"variable_name = \"\" #@param {type:\"string\"}\n",
|
||||
"\n",
|
||||
"with open(filename, 'w') as file:\n",
|
||||
" json.dump(data, file, indent=2)\n",
|
||||
"\n",
|
||||
"print(f\"Modified JSON data saved to '{filename}'.\")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "d20kz8EtosWM"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
"print(\"The specified code block has been deleted.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -725,7 +668,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Alternate startup script\n",
|
||||
"#@markdown * Edit train.json to setup your paramaters\n",
|
||||
"#@markdown * Edit train.json or chain0.json to setup your paramaters\n",
|
||||
"\n",
|
||||
"#@markdown * Edit using a chain length of 0 will use train.json\n",
|
||||
"\n",
|
||||
|
@ -733,8 +676,6 @@
|
|||
"\n",
|
||||
"#@markdown * make sure to check each confguration you will need 1 Json per chain length 3 are provided\n",
|
||||
"\n",
|
||||
"#@markdown * make sure your .Json contain the line Notebook: true\n",
|
||||
"\n",
|
||||
"#@markdown * your locations in the .json can be done in this format /content/drive/MyDrive/ - then the sub folder you wish to use\n",
|
||||
"\n",
|
||||
"%cd /content/EveryDream2trainer\n",
|
||||
|
@ -748,18 +689,26 @@
|
|||
" l -= 1\n",
|
||||
" I =+ 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Need some tools to Manage your large datasets check out https://github.com/victorchall/EveryDream for some usefull tools and captioner"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ls6mX94trxZV"
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"gpuType": "T4",
|
||||
"include_colab_link": true
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
|
@ -774,4 +723,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -248,7 +248,7 @@ def __get_all_aspects():
|
|||
]
|
||||
|
||||
|
||||
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
|
||||
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
|
||||
def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
|
||||
if x <= 1:
|
||||
return farey_aspect_ratio_pair_lt1(x, max_denominator_value)
|
||||
|
|
|
@ -15,15 +15,17 @@ limitations under the License.
|
|||
"""
|
||||
import bisect
|
||||
import logging
|
||||
import os.path
|
||||
from collections import defaultdict
|
||||
import math
|
||||
import copy
|
||||
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from typing import List, Dict
|
||||
|
||||
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
|
||||
import PIL.Image
|
||||
|
||||
from utils.first_fit_decreasing import first_fit_decreasing
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||
|
||||
class DataLoaderMultiAspect():
|
||||
|
@ -34,9 +36,10 @@ class DataLoaderMultiAspect():
|
|||
seed: random seed
|
||||
batch_size: number of images per batch
|
||||
"""
|
||||
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
|
||||
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1):
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
self.grad_accum = grad_accum
|
||||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
|
@ -103,14 +106,18 @@ class DataLoaderMultiAspect():
|
|||
|
||||
buckets = {}
|
||||
batch_size = self.batch_size
|
||||
grad_accum = self.grad_accum
|
||||
|
||||
for image_caption_pair in picked_images:
|
||||
image_caption_pair.runt_size = 0
|
||||
target_wh = image_caption_pair.target_wh
|
||||
|
||||
if (target_wh[0],target_wh[1]) not in buckets:
|
||||
buckets[(target_wh[0],target_wh[1])] = []
|
||||
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
||||
bucket_key = (image_caption_pair.batch_id,
|
||||
image_caption_pair.target_wh[0],
|
||||
image_caption_pair.target_wh[1])
|
||||
if bucket_key not in buckets:
|
||||
buckets[bucket_key] = []
|
||||
buckets[bucket_key].append(image_caption_pair)
|
||||
|
||||
# handle runts by randomly duplicating items
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
|
@ -125,13 +132,19 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
# flatten the buckets
|
||||
items: list[ImageTrainItem] = []
|
||||
for bucket in buckets:
|
||||
items.extend(buckets[bucket])
|
||||
# handle batch_id
|
||||
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
|
||||
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
||||
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
|
||||
batch_size=batch_size,
|
||||
grad_accum=grad_accum)
|
||||
|
||||
effective_batch_size = batch_size * grad_accum
|
||||
items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Picks a random subset of all images
|
||||
|
@ -174,4 +187,53 @@ class DataLoaderMultiAspect():
|
|||
self.ratings_summed: list[float] = []
|
||||
for item in self.prepared_train_data:
|
||||
self.rating_overall_sum += item.caption.rating()
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
|
||||
|
||||
def chunk(l: List, chunk_size) -> List:
|
||||
num_chunks = int(math.ceil(float(len(l)) / chunk_size))
|
||||
return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
|
||||
|
||||
def unchunk(chunked_list: List):
|
||||
return [i for c in chunked_list for i in c]
|
||||
|
||||
def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
|
||||
batch_ids = [k[0] for k in buckets.keys()]
|
||||
items_by_batch_id = {}
|
||||
for batch_id in batch_ids:
|
||||
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
|
||||
return items_by_batch_id
|
||||
|
||||
def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
|
||||
batch_size: int,
|
||||
grad_accum: int) -> List[ImageTrainItem]:
|
||||
# precondition: items_by_batch_id has no incomplete batches
|
||||
assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values()))
|
||||
# ensure we don't mix up aspect ratios by treating each chunk of batch_size images as
|
||||
# a single unit to pass to first_fit_decreasing()
|
||||
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
|
||||
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
|
||||
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
|
||||
batch_size=grad_accum,
|
||||
filler_items=filler_items)
|
||||
|
||||
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
|
||||
return items
|
||||
|
||||
def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
|
||||
"""
|
||||
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
|
||||
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
|
||||
"""
|
||||
|
||||
# chunk by effective batch size
|
||||
chunks = chunk(l, chunk_size)
|
||||
# preserve last chunk as last if it is incomplete
|
||||
last_chunk = None
|
||||
if len(chunks[-1]) < chunk_size:
|
||||
last_chunk = chunks.pop(-1)
|
||||
randomizer.shuffle(chunks)
|
||||
if last_chunk is not None:
|
||||
chunks.append(last_chunk)
|
||||
l = unchunk(chunks)
|
||||
return l
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import os
|
||||
import logging
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from functools import total_ordering
|
||||
from attrs import define, field, Factory
|
||||
from attrs import define, field
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
from utils.fs_helpers import *
|
||||
from typing import Iterable
|
||||
|
@ -50,6 +47,7 @@ class ImageConfig:
|
|||
rating: float = None
|
||||
max_caption_length: int = None
|
||||
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
|
||||
batch_id: str = None
|
||||
|
||||
# Options
|
||||
multiply: float = None
|
||||
|
@ -70,6 +68,7 @@ class ImageConfig:
|
|||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||
flip_p=overlay(other.flip_p, self.flip_p),
|
||||
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
|
||||
batch_id=overlay(other.batch_id, self.batch_id)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -84,6 +83,7 @@ class ImageConfig:
|
|||
cond_dropout=data.get("cond_dropout"),
|
||||
flip_p=data.get("flip_p"),
|
||||
shuffle_tags=data.get("shuffle_tags"),
|
||||
batch_id=data.get("batch_id")
|
||||
)
|
||||
|
||||
# Alternatively parse from dedicated `caption` attribute
|
||||
|
@ -168,6 +168,8 @@ class Dataset:
|
|||
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
|
||||
if 'local.yml' in fileset:
|
||||
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
||||
if 'batch_id.txt' in fileset:
|
||||
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
|
||||
|
||||
result = ImageConfig.fold(cfgs)
|
||||
if 'shuffle_tags.txt' in fileset:
|
||||
|
@ -262,6 +264,7 @@ class Dataset:
|
|||
multiplier=config.multiply or 1.0,
|
||||
cond_dropout=config.cond_dropout,
|
||||
shuffle_tags=config.shuffle_tags,
|
||||
batch_id=config.batch_id
|
||||
)
|
||||
items.append(item)
|
||||
except Exception as e:
|
||||
|
|
|
@ -124,7 +124,7 @@ class ImageTrainItem:
|
|||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||
"""
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
image: PIL.Image,
|
||||
caption: ImageCaption,
|
||||
aspects: list[float],
|
||||
|
@ -133,6 +133,7 @@ class ImageTrainItem:
|
|||
multiplier: float=1.0,
|
||||
cond_dropout=None,
|
||||
shuffle_tags=False,
|
||||
batch_id: str=None
|
||||
):
|
||||
self.caption = caption
|
||||
self.aspects = aspects
|
||||
|
@ -143,6 +144,8 @@ class ImageTrainItem:
|
|||
self.multiplier = multiplier
|
||||
self.cond_dropout = cond_dropout
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.batch_id = batch_id or DEFAULT_BATCH_ID
|
||||
self.target_wh = None
|
||||
|
||||
self.image_size = None
|
||||
if image is None or len(image) == 0:
|
||||
|
@ -351,3 +354,6 @@ class ImageTrainItem:
|
|||
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
DEFAULT_BATCH_ID = "default_batch"
|
||||
|
|
|
@ -151,6 +151,17 @@ Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/z
|
|||
|
||||
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time.
|
||||
|
||||
## Keeping images together (custom batching)
|
||||
|
||||
If you have a subset of your dataset that expresses the same style or concept, training quality may be improved by putting all of these images through the trainer together in the same batch or batches, instead of the default behaviour (which is to shuffle them randomly throughout the entire dataset).
|
||||
|
||||
To control this, put a file called `batch_id.txt` into a folder to give a unique name to the training data in this folder. For example, if you have a bunch of images of horses and you are trying to train them as a single concept, you can assign a unique name such as "my_horses" to these images by putting the word `my_horses` inside `batch_id.txt` in your folder with horse images.
|
||||
|
||||
> Note that because this creates extra aspect ratio buckets, you need to be very careful about correlating the number of images to your training batch size. Aim to have an exact multiple of `batch_size` images at each aspect ratio. For example, if your `batch_size` is 6 and you have images with aspect ratios 4:3, 3:4, and 9:16, you should add or delete images until you have an exact multiple of 6 images (i.e. 6, 12, 28, ...) for each aspect ratio. If you do not do this, the bucketer will duplicate images to fill up each aspect ratio bucket. You'll probably also want to use manual validation to prevent the validator from messing this up, too.
|
||||
|
||||
If you are using `.yaml` files for captioning, you can alternatively add a `batch_id: ` entry to either `local.yaml` or the individual images' `.yaml` files. Note that neither `.yaml` nor `batch_id.txt` files act recursively: they do not apply to subfolders.
|
||||
|
||||
|
||||
# Stuff you probably don't need to mess with, but well here it is:
|
||||
|
||||
|
||||
|
|
16
doc/SETUP.md
16
doc/SETUP.md
|
@ -48,6 +48,18 @@ Double check your python version again after setup by running these two commands
|
|||
|
||||
Again, this should show 3.10.x
|
||||
|
||||
## Docker container
|
||||
## Local docker container
|
||||
|
||||
`docker run -it -p 8888:8888 -p 6006:6006 --gpus all -e JUPYTER_PASSWORD=test1234 -t ghcr.io/victorchall/everydream2trainer:nightly`
|
||||
```sh
|
||||
docker compose up
|
||||
```
|
||||
|
||||
And you can either get a shell via:
|
||||
```sh
|
||||
docker exec -it everydream2trainer-docker-everydream2trainer-1 /bin/bash
|
||||
```
|
||||
|
||||
Or go to your browser and hit `http://localhost:8888`. The web password is
|
||||
`test1234` but you can change that in `docker-compose.yml`.
|
||||
|
||||
Your current source directory will be moutned to the Jupyter notebook.
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
version: '3.8'
|
||||
services:
|
||||
everydream2trainer:
|
||||
image: ghcr.io/victorchall/everydream2trainer:nightly
|
||||
ports:
|
||||
- "127.0.0.1:8888:8888"
|
||||
- "127.0.0.1:6006:6006"
|
||||
environment:
|
||||
- JUPYTER_PASSWORD=test1234
|
||||
volumes:
|
||||
- .:/workspace/EveryDream2trainer
|
||||
- ./.cache:/root/.cache
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
|
@ -87,10 +87,8 @@ ARG CACHEBUST=1
|
|||
RUN git clone https://github.com/victorchall/EveryDream2trainer
|
||||
|
||||
WORKDIR /workspace/EveryDream2trainer
|
||||
RUN python utils/get_yamls.py && \
|
||||
mkdir -p logs && mkdir -p input
|
||||
|
||||
ADD welcome.txt /
|
||||
ADD start.sh /
|
||||
RUN chmod +x /start.sh
|
||||
CMD [ "/start.sh" ]
|
||||
CMD [ "/start.sh" ]
|
||||
|
|
|
@ -2,6 +2,12 @@
|
|||
cat /welcome.txt
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
if [[ ! -f "v2-inference-v.yaml" ]]; then
|
||||
python utils/get_yamls.py
|
||||
fi
|
||||
|
||||
mkdir -p logs input
|
||||
|
||||
# RunPod SSH
|
||||
if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]
|
||||
then
|
||||
|
|
|
@ -303,6 +303,15 @@ class EveryDreamOptimizer():
|
|||
)
|
||||
elif optimizer_name == "adamw":
|
||||
opt_class = torch.optim.AdamW
|
||||
if "dowg" in optimizer_name:
|
||||
# coordinate_dowg, scalar_dowg require no additional parameters. Epsilon is overrideable but is unnecessary in all stable diffusion training situations.
|
||||
import dowg
|
||||
if optimizer_name == "coordinate_dowg":
|
||||
opt_class = dowg.CoordinateDoWG
|
||||
elif optimizer_name == "scalar_dowg":
|
||||
opt_class = dowg.ScalarDoWG
|
||||
else:
|
||||
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are coordinate_dowg and scalar_dowg")
|
||||
elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
|
||||
import dadaptation
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ xformers==0.0.20
|
|||
pytorch-lightning==1.6.5
|
||||
OmegaConf==2.2.3
|
||||
numpy==1.23.5
|
||||
dowg
|
||||
lion-pytorch
|
||||
compel~=1.1.3
|
||||
OmegaConf==2.2.3
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
import unittest
|
||||
|
||||
from utils.first_fit_decreasing import first_fit_decreasing
|
||||
|
||||
class TestFirstFitDecreasing(unittest.TestCase):
|
||||
|
||||
def test_single_basic(self):
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=3)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
input = [[1, 2, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 2, 3])
|
||||
|
||||
def test_multi_basic(self):
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=3)
|
||||
self.assertEqual(output, [1, 1, 1, 2, 2, 1])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
|
||||
|
||||
input = [[1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [2, 2, 1, 1])
|
||||
|
||||
def test_multi_complex(self):
|
||||
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 1, 3, 3, 1, 1, 2, 2, 3])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=3)
|
||||
self.assertEqual(output, [1, 1, 1, 3, 3, 3, 2, 2, 1])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 3, 3, 3, 2, 2])
|
||||
|
||||
input = [[1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [3, 3, 3, 2, 1, 1, 2])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 3, 2, 2, 2, 3, 5, 5, 3])
|
||||
|
||||
def test_filler_bucket(self):
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=2, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=3, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 2, 2, 9, 1, 9])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
|
||||
|
||||
input = [[1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
|
||||
self.assertEqual(output, [2, 2, 9, 9, 1, 1])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
|
||||
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 9, 3, 3, 3, 9, 2, 2, 2, 5, 5])
|
21
train.py
21
train.py
|
@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
|
|||
|
||||
from data.every_dream import EveryDreamBatch, build_torch_dataloader
|
||||
from data.every_dream_validation import EveryDreamValidator
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
|
|||
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
|
||||
warn_bucket_dupe_ratio = 0.5
|
||||
|
||||
ar_buckets = set([tuple(i.target_wh) for i in items])
|
||||
def make_bucket_key(item):
|
||||
return (item.batch_id, int(item.target_wh[0]), int(item.target_wh[1]))
|
||||
|
||||
ar_buckets = set(make_bucket_key(i) for i in items)
|
||||
for ar_bucket in ar_buckets:
|
||||
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
|
||||
count = len([i for i in items if make_bucket_key(i) == ar_bucket])
|
||||
runt_size = batch_size - (count % batch_size)
|
||||
bucket_dupe_ratio = runt_size / count
|
||||
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
|
||||
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
|
||||
aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
|
||||
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
|
||||
batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
|
||||
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
|
||||
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
|
||||
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
|
||||
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
|
||||
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
|
||||
f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
|
||||
|
||||
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
|
||||
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
|
||||
|
@ -548,6 +552,7 @@ def main(args):
|
|||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
batch_size=args.batch_size,
|
||||
grad_accum=args.grad_accum
|
||||
)
|
||||
|
||||
train_batch = EveryDreamBatch(
|
||||
|
@ -785,15 +790,15 @@ def main(args):
|
|||
lr_textenc = ed_optimizer.get_textenc_lr()
|
||||
loss_log_step = []
|
||||
|
||||
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=lr_unet, global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
|
||||
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
sum_img = sum(images_per_sec_log_step)
|
||||
avg = sum_img / len(images_per_sec_log_step)
|
||||
images_per_sec_log_step = []
|
||||
if args.amp:
|
||||
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparameter/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
|
||||
|
||||
logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import copy
|
||||
from typing import List
|
||||
|
||||
def first_fit_decreasing(input_items: List[List], batch_size: int, filler_items: List=[]) -> List:
|
||||
"""
|
||||
Given as input a list of lists, batch the items so that as much as possible the members of each of the original
|
||||
lists end up in the same batch. Pad out too-short batches by taking items from the filler_items list, if available.
|
||||
|
||||
@return flattened list of all items in input_items and filler_items, arranged such that, as much as possible, items
|
||||
that are in the same input list end up in the same batch.
|
||||
"""
|
||||
|
||||
def sort_by_length(items: List[List]) -> List[List]:
|
||||
return sorted(items, key=lambda x: len(x))
|
||||
|
||||
remaining = input_items
|
||||
output = []
|
||||
while remaining:
|
||||
remaining = sort_by_length(remaining)
|
||||
longest = remaining.pop()
|
||||
if len(longest) == 0:
|
||||
continue
|
||||
if len(longest) >= batch_size:
|
||||
output.append(longest[0:batch_size])
|
||||
del longest[0:batch_size]
|
||||
if len(longest)>0:
|
||||
remaining.append(longest)
|
||||
else:
|
||||
# need to build this chunk by combining multiple
|
||||
combined = longest
|
||||
while True:
|
||||
fill_length = batch_size - len(combined)
|
||||
if fill_length == 0:
|
||||
break
|
||||
|
||||
if len(remaining) == 0 and len(filler_items) == 0:
|
||||
break
|
||||
|
||||
from_filler_bucket = filler_items[0:fill_length]
|
||||
if len(from_filler_bucket) > 0:
|
||||
del filler_items[0:fill_length]
|
||||
combined.extend(from_filler_bucket)
|
||||
continue
|
||||
|
||||
filler = next((r for r in remaining if len(r) <= fill_length), None)
|
||||
if filler is not None:
|
||||
remaining.remove(filler)
|
||||
combined.extend(filler)
|
||||
else:
|
||||
# steal from the next longest
|
||||
next_longest = remaining.pop()
|
||||
combined.extend(next_longest[0:fill_length])
|
||||
del next_longest[0:fill_length]
|
||||
if len(next_longest) > 0:
|
||||
remaining.append(next_longest)
|
||||
output.append(combined)
|
||||
|
||||
output.append(filler_items)
|
||||
return [i for o in output for i in o]
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue