EveryDream2trainer/Train_JupyterLab.ipynb

2023-01-12 16:28:11 -07:00
{
"cells": [
{
"cell_type": "markdown",
"id": "2c831b5b-3025-4177-bef5-25aaec89573a",
"metadata": {},
"source": [
"## Every Dream v2 Jupyter Notebook\n",
"\n",
"### [General Instructions](https://github.com/victorchall/EveryDream2trainer/blob/main/README.md)\n",
"\n",
"### Have you installed EveryDream2?\n",
"[Windows Setup](/doc/SETUP.md)\n",
"\n",
"[Runpod Installer](/installers/Runpod.ipynb)\n",
"\n",
"### What's your plan?\n",
"You will want to have your data prepared before starting, and have a rough training plan in mind. \n",
"\n",
"**Make sure your images are captioned!**\n",
"\n",
"By default, the names of your image files are used as their captions. If you want to get fancy, there are [more sophisticated techniques](https://github.com/victorchall/EveryDream2trainer/blob/main/doc/DATA.md).\n",
"\n",
"**If this is your first time trying a full fine-tune, start small!** \n",
"\n",
"Pick a single concept and 30-100 images, and see what happens. \n",
"\n",
"Training on a small dataset like this is fast, and it will give you a feel for how quickly your model (over-)trains depending on your training schedule, captioning schema, and knob twiddling. This notebook provides some sensible defaults, but there are more questions than answers about how best to fine-tune anything.\n",
"\n",
"**_When_ you have questions...**\n",
"\n",
"Come visit us at [EveryDream Discord](https://discord.gg/uheqxU6sXN)"
]
},
{
"cell_type": "markdown",
"id": "7c73894e-3b5e-4268-9f83-ed89bd4569f2",
"metadata": {},
"source": [
"### (Optional) Weights and Biases login. \n",
"\n",
"Paste your token here if you have an account so you can use it to track your training progress. If you don't have an account, you can create one for free at https://wandb.ai/site. Weights and Biases is a free online logging utility, and logs will use the project name you set in the training cell below.\n",
"\n",
"Your key is on this page: https://wandb.ai/settings under \"Danger Zone\" \"API Keys\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdbaf48c-f1e2-458d-b1ee-707f3b71bf61",
"metadata": {},
"outputs": [],
"source": [
"from ipywidgets import *\n",
"\n",
"wandb_token = Password(placeholder=\"Optional Weights & Biases auth token\")\n",
"out = Output()\n",
"def wandb_login(_):\n",
"    with out:\n",
"        if wandb_token.value:\n",
"            !wandb login {wandb_token.value}\n",
"\n",
"wandb_btn = Button(description=\"W&B Login\")\n",
"wandb_btn.on_click(wandb_login)\n",
"print()\n",
"display(VBox([wandb_token, wandb_btn, out]))"
]
},
{
"cell_type": "markdown",
"id": "3d9b0db8-c2b1-4f0a-b835-b6b2ef527019",
"metadata": {},
"source": [
"### HuggingFace Login\n",
"Run the cell below and paste your token into the prompt. You can get your token from your [HuggingFace account page](https://huggingface.co/settings/tokens).\n",
"\n",
"The token will not show on the screen, just press enter after you paste it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "138b7776-8783-4e1d-920d-cf358809b802",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login, hf_hub_download\n",
"import os\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"id": "7a96f2af-8c93-4460-aa9e-2ff795fb06ea",
"metadata": {},
"source": [
"#### Then run the following cell to download the base checkpoint (may take a minute)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86b66fe4-c2ca-46fa-813c-8fe390813add",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"repo=\"panopstor/EveryDream\"\n",
"ckpt_file=\"sd_v1-5_vae.ckpt\"\n",
"\n",
"print(f\"Downloading {ckpt_file} from {repo}\")\n",
"downloaded_model_path = hf_hub_download(repo, ckpt_file, cache_dir=\"/workspace/hfcache\")\n",
"ckpt_name = os.path.splitext(os.path.basename(downloaded_model_path))[0]\n",
"print(f\"Downloaded {ckpt_name} to {downloaded_model_path}\")\n",
"\n",
"if not os.path.exists(f\"ckpt_cache/{ckpt_name}\"):\n",
"    print(f\"Converting {ckpt_name} to Diffusers format\")\n",
"    %run utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
"    --original_config_file v1-inference.yaml \\\n",
"    --image_size 512 \\\n",
"    --checkpoint_path \"{downloaded_model_path}\" \\\n",
"    --prediction_type epsilon \\\n",
"    --upcast_attn False \\\n",
"    --dump_path \"ckpt_cache/{ckpt_name}\"\n",
"\n",
"print(\"DONE\")"
]
},
{
"cell_type": "markdown",
"id": "f15fcd56-0418-4be1-a5c3-38aa679b1aaf",
"metadata": {},
"source": [
"# Start Training\n",
"Naming your project will help you track what the heck you're doing when you're floating in checkpoint files later.\n",
"\n",
"You may wish to add \"sd1\" or \"sd2v\" or similar to the name to remember what the base model was. You will also have to tell your inference app which base you used, as it's difficult for programs to know automatically which inference YAML to use. For instance, the Automatic1111 webui requires you to copy the v2 inference YAML and rename it to match your checkpoint name so it knows how to load the file; otherwise it assumes the checkpoint is SD 1.x compatible. Something to keep in mind if you start training on SD2.1.\n",
"\n",
"`max_epochs`, `sample_steps`, and `save_every_n_epochs` should be tuned to your dataset. I like to generate one or two sets of samples per save, and aim for 5 (give or take 2) saved checkpoints.\n",
"\n",
"The next cell runs training. This will take a while, depending on your number of images, repeats, and `max_epochs`.\n",
"\n",
"You can watch for test images in the logs folder."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"wandb=\"\"\n",
"if wandb_token and wandb_token.value:\n",
"    wandb=\"--wandb\"\n",
"\n",
"%run train.py --config train.json {wandb} \\\n",
"--resume_ckpt \"{ckpt_name}\" \\\n",
"--project_name \"sd1_mymodel\" \\\n",
"--data_root \"input\" \\\n",
"--max_epochs 200 \\\n",
"--sample_steps 150 \\\n",
"--save_every_n_epochs 35 \\\n",
"--lr 1.5e-6 \\\n",
"--lr_scheduler constant"
]
},
{
"cell_type": "markdown",
"id": "ed464c6b-1a8d-48e4-9787-265e8acaac43",
"metadata": {},
"source": [
"### Optionally you can chain trainings together using multiple configurations combined with `resume_ckpt: findlast`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "492350d4-9b2f-4d2a-9641-1f723125b296",
"metadata": {},
"outputs": [],
"source": [
"%run train.py --config chain0.json --project_name \"sd1_chain_a\" --data_root \"input\" --resume_ckpt \"{ckpt_name}\"\n",
"%run train.py --config chain1.json --project_name \"sd1_chain_b\" --data_root \"input\" --resume_ckpt findlast\n",
"%run train.py --config chain2.json --project_name \"sd1_chain_c\" --data_root \"input\" --resume_ckpt findlast"
]
},
{
"cell_type": "markdown",
"id": "f24eee3d-f5df-45f3-9acc-ee0206cfe6b1",
"metadata": {},
"source": [
"# HuggingFace upload\n",
"Use the cell below to upload one or more checkpoints to your personal HuggingFace account instead of downloading them manually. You should already be authenticated with HuggingFace if you used the token and download cells above.\n",
"\n",
"* You can get your account name from your [HuggingFace account page](https://huggingface.co/settings/account). Look for your \"username\" field and paste it below.\n",
"\n",
"* You only need to set up a repository one time. You can create it here: [Create New HF Model](https://huggingface.co/new). Make sure you write down the repo name for future use; you can reuse it later."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9df5e1a-3c68-41c0-a4ed-ea0abcd19858",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"from huggingface_hub import HfApi\n",
"from ipywidgets import *\n",
"\n",
"all_ckpts = glob.glob(\"*.ckpt\")\n",
"\n",
"ckpt_picker = SelectMultiple(options=all_ckpts, layout=Layout(width=\"600px\"))\n",
"hfuser = Text(placeholder='Your HF user name')\n",
"hfrepo = Text(placeholder='Your HF repo name')\n",
"\n",
"api = HfApi()\n",
"upload_btn = Button(description='Upload')\n",
"out = Output()\n",
"\n",
"def upload_ckpts(_):\n",
"    repo_id=f\"{hfuser.value}/{hfrepo.value}\"\n",
"    with out:\n",
"        for ckpt in ckpt_picker.value:\n",
"            print(f\"Uploading to HF: huggingface.co/{repo_id}/{ckpt}\")\n",
"            response = api.upload_file(\n",
"                path_or_fileobj=ckpt,\n",
"                path_in_repo=ckpt,\n",
"                repo_id=repo_id,\n",
"                repo_type=None,\n",
"                create_pr=True,\n",
"            )\n",
"            display(response)\n",
"        print(\"DONE\")\n",
"        print(\"Go to your repo and accept the PRs this created to see your files\")\n",
"\n",
"upload_btn.on_click(upload_ckpts)\n",
"box = VBox([ckpt_picker, HBox([hfuser, hfrepo]), upload_btn, out])\n",
"\n",
"display(box)"
]
},
{
"cell_type": "markdown",
"id": "c1a00d16-9b84-492f-8e6a-defe71e82b43",
"metadata": {},
"source": [
"# Test inference on your checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efb1a8cd-6a04-44e5-a770-c23ee247ce82",
"metadata": {},
"outputs": [],
"source": [
"from ipywidgets import *\n",
"from IPython.display import display, clear_output\n",
"import os\n",
"import gc\n",
"import random\n",
"import torch\n",
"import inspect\n",
"\n",
"from torch import autocast\n",
"from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler\n",
"from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"\n",
"checkpoints_ts = []\n",
"for root, dirs, files in os.walk(\".\"):\n",
"    for file in files:\n",
"        if os.path.basename(file) == \"model_index.json\":\n",
"            ts = os.path.getmtime(os.path.join(root,file))\n",
"            ckpt = root\n",
"            checkpoints_ts.append((ts, root))\n",
"\n",
"checkpoints = [ckpt for (_, ckpt) in sorted(checkpoints_ts, reverse=True)]\n",
"full_width = Layout(width='600px')\n",
"half_width = Layout(width='300px')\n",
"\n",
"checkpoint = Dropdown(options=checkpoints, description='Checkpoint:', layout=full_width)\n",
"prompt = Textarea(value='a photo of ', description='Prompt:', layout=full_width)\n",
"height = IntSlider(value=512, min=256, max=768, step=32, description='Height:', layout=half_width)\n",
"width = IntSlider(value=512, min=256, max=768, step=32, description='Width:', layout=half_width)\n",
"cfg = FloatSlider(value=7.0, min=0.0, max=14.0, step=0.2, description='CFG Scale:', layout=half_width)\n",
"steps = IntSlider(value=30, min=10, max=100, description='Steps:', layout=half_width)\n",
"seed = IntText(value=-1, description='Seed:', layout=half_width)\n",
"generate_btn = Button(description='Generate', layout=full_width)\n",
"out = Output()\n",
"\n",
"def generate(_):\n",
"    with out:\n",
"        clear_output()\n",
"        display(f\"Loading model {checkpoint.value}\")\n",
"        actual_seed = seed.value if seed.value != -1 else random.randint(0, 2**30)\n",
"\n",
"        text_encoder = CLIPTextModel.from_pretrained(checkpoint.value, subfolder=\"text_encoder\")\n",
"        vae = AutoencoderKL.from_pretrained(checkpoint.value, subfolder=\"vae\")\n",
"        unet = UNet2DConditionModel.from_pretrained(checkpoint.value, subfolder=\"unet\")\n",
"        tokenizer = CLIPTokenizer.from_pretrained(checkpoint.value, subfolder=\"tokenizer\", use_fast=False)\n",
"        scheduler = DDIMScheduler.from_pretrained(checkpoint.value, subfolder=\"scheduler\")\n",
"        text_encoder.eval()\n",
"        vae.eval()\n",
"        unet.eval()\n",
"\n",
"        text_encoder.to(\"cuda\")\n",
"        vae.to(\"cuda\")\n",
"        unet.to(\"cuda\")\n",
"\n",
"        pipe = StableDiffusionPipeline(\n",
"            vae=vae,\n",
"            text_encoder=text_encoder,\n",
"            tokenizer=tokenizer,\n",
"            unet=unet,\n",
"            scheduler=scheduler,\n",
"            safety_checker=None, # save vram\n",
"            requires_safety_checker=None, # avoid nag\n",
"            feature_extractor=None, # must be None if no safety checker\n",
"        )\n",
"\n",
"        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)\n",
"\n",
"        print(inspect.cleandoc(f\"\"\"\n",
"            Prompt: {prompt.value}\n",
"            Resolution: {width.value}x{height.value}\n",
"            CFG: {cfg.value}\n",
"            Steps: {steps.value}\n",
"            Seed: {actual_seed}\n",
"        \"\"\"))\n",
"        with autocast(\"cuda\"):\n",
"            image = pipe(prompt.value,\n",
"                generator=torch.Generator(\"cuda\").manual_seed(actual_seed),\n",
"                num_inference_steps=steps.value,\n",
"                guidance_scale=cfg.value,\n",
"                width=width.value,\n",
"                height=height.value\n",
"            ).images[0]\n",
"        del pipe\n",
"        gc.collect()\n",
"        with torch.cuda.device(\"cuda\"):\n",
"            torch.cuda.empty_cache()\n",
"            torch.cuda.ipc_collect()\n",
"        display(image)\n",
"\n",
"generate_btn.on_click(generate)\n",
"box = VBox(\n",
"    children=[\n",
"        checkpoint, prompt,\n",
"        HBox([VBox([width, height]), VBox([steps, cfg])]),\n",
"        seed,\n",
"        generate_btn,\n",
"        out]\n",
")\n",
"\n",
"\n",
"display(box)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
},
"vscode": {
"interpreter": {
"hash": "2e677f113ff5b533036843965d6e18980b635d0aedc1c5cebd058006c5afc92a"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
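
The "Start Training" cell in the notebook suggests tuning `max_epochs` and `save_every_n_epochs` so that you end up with about 5 saved checkpoints. That guidance comes down to simple arithmetic, sketched below. This is not part of EveryDream2trainer: `plan_checkpoints` is a hypothetical helper, and it assumes a checkpoint is written whenever the number of completed epochs is a multiple of `save_every_n_epochs`. The trainer's real behavior may differ, for example by also saving at the end of training.

```python
import math
from typing import List, Tuple


def plan_checkpoints(max_epochs: int, target_saves: int = 5) -> Tuple[int, List[int]]:
    """Suggest a save interval that yields at most `target_saves` checkpoints.

    Hypothetical helper for illustration only; not part of EveryDream2trainer.
    """
    if max_epochs < 1 or target_saves < 1:
        raise ValueError("max_epochs and target_saves must be positive")
    # Round up so the number of saves never exceeds the target.
    interval = math.ceil(max_epochs / target_saves)
    # ASSUMPTION: a checkpoint is written after every `interval`-th epoch.
    save_epochs = list(range(interval, max_epochs + 1, interval))
    return interval, save_epochs


interval, save_epochs = plan_checkpoints(max_epochs=200, target_saves=5)
print(f"--save_every_n_epochs {interval} -> saves after epochs {save_epochs}")
```

Under the same assumption, the notebook's own defaults (`--max_epochs 200`, `--save_every_n_epochs 35`) would save after epochs 35, 70, 105, 140 and 175, which is also five checkpoints.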