Add some widgets for inference and HF upload

And tidy up the markdown a bit
This commit is contained in:
Augusto de la Torre 2023-02-09 16:02:32 +01:00
parent e98aa90288
commit e621ba4b6a
1 changed files with 184 additions and 62 deletions

View File

@ -13,19 +13,28 @@
"\n",
"If you are confused by the wall of text, join the discord here: [EveryDream Discord](https://discord.gg/uheqxU6sXN)\n",
"\n",
"### Requirements\n",
"Select the `RunPod Stable Diffusion v2.1` template. The `RunPod PyTorch` template does not work due to an old version of Python. \n",
"### Usage\n",
"\n",
"#### Storage\n",
"1. Prepare your training data before you begin (see below)\n",
"2. Spin the `RunPod Stable Diffusion v2.1` template. The `RunPod PyTorch` template does not work due to an old version of Python. \n",
"3. Open this notebook with `File > Open from URL...` pointing to `https://raw.githubusercontent.com/victorchall/EveryDream2trainer/main/Train_Runpod.ipynb`\n",
"4. Run each cell below once, noting any instructions above the cell (the first step requires a pod restart)\n",
"5. Figure out how you want to tweak the process next\n",
"6. Rinse, Repeat\n",
"\n",
"#### A note on storage\n",
"Remember, on RunPod time is more expensive than storage. \n",
"\n",
"Which is good, because running a lot of experiments can generate a lot of data. Not having the right save points to recover quickly from inevitable mistakes will cost you a lot of time.\n",
"\n",
"When in doubt, give yourself ~125GB of Runpod **Volume** storage.\n",
"\n",
"#### Preparation\n",
"#### Preparing your training data\n",
"You will want to have your data prepared before starting, and have a rough training plan in mind. Don't waste rental fees if you're not fully prepared to start training. \n",
"\n",
"**If this is your first time trying a full fine-tune, start small!** \n",
"Pick a single concept and 30-100 images, and see what happens. Training a small dataset like this is fast, and will give you a feel for how quickly your model (over-)trains depending on your training schedule.\n",
"\n",
"Your files should be captioned before you start with either the caption as the filename or in text files for each image alongside the image files. See [DATA.md](https://github.com/victorchall/EveryDream2trainer/blob/main/doc/DATA.md) for more details. Tools are available to automatically caption your files."
]
},
@ -88,7 +97,13 @@
"\n",
"If you have many training files, or nested folders of training data, create a zip archive of your training data, upload this file to the input folder, then right click on the zip file and select \"Extract Archive\".\n",
"\n",
"**While your training data is uploading, feel free to install the dependencies below**"
"## Optional - Configure sample prompts\n",
"You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt.\n",
"\n",
"Keep in mind a longer list of prompts will take longer to generate. You may also want to adjust you sample_steps in the training notebook to a different value to get samples left often. This is probably a good idea when training a larger dataset that you know will take longer to train, where more frequent samples will not help you.\n",
"\n",
"While your training data is uploading, go ahead to install the dependencies below\n",
"----"
]
},
{
@ -108,7 +123,10 @@
"cell_type": "code",
"execution_count": null,
"id": "9649a02c-fb2b-44f1-842d-d1662fa5c7cd",
"metadata": {},
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"!python -m pip install --upgrade pip\n",
@ -132,18 +150,6 @@
"print(\"DONE\")"
]
},
{
"cell_type": "markdown",
"id": "36105dbc-5a33-431b-b88e-b87d479d1ed7",
"metadata": {},
"source": [
"\n",
"# Optional - Configure sample prompts\n",
"You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt.\n",
"\n",
"Keep in mind a longer list of prompts will take longer to generate. You may also want to adjust you sample_steps in the training notebook to a different value to get samples left often. This is probably a good idea when training a larger dataset that you know will take longer to train, where more frequent samples will not help you."
]
},
{
"cell_type": "markdown",
"id": "c230d91a",
@ -181,7 +187,10 @@
"cell_type": "code",
"execution_count": null,
"id": "86b66fe4-c2ca-46fa-813c-8fe390813add",
"metadata": {},
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%cd /workspace/EveryDream2trainer\n",
@ -228,7 +237,10 @@
"cell_type": "code",
"execution_count": null,
"id": "6f73fb86-ebef-41e2-9382-4aa11be84be6",
"metadata": {},
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%cd /workspace/EveryDream2trainer\n",
@ -271,29 +283,11 @@
"metadata": {},
"source": [
"# HuggingFace upload\n",
"Use the cell below to upload your checkpoint to your personal HuggingFace account if you want instead of manually downloading. You should already be authorized to Huggingface by token if you used the download/token cells above.\n",
"\n",
"Make sure to fill in the three fields at the top. This will only upload one file at a time, so you will need to run the cell below for each file you want to upload.\n",
"Use the cell below to upload one or more checkpoints to your personal HuggingFace account, if you want, instead of manually downloading. You should already be authorized to Huggingface by token if you used the download/token cells above.\n",
"\n",
"* You can get your account name from your [HuggingFace account page](https://huggingface.co/settings/account). Look for your \"username\" field and paste it below.\n",
"\n",
"* You only need to setup a repository one time. You can create it here: [Create New HF Model](https://huggingface.co/new) Make sure you write down the repo name you make for future use. You can reuse it later.\n",
"\n",
"* You need to type the name of the ckpts one at a time in the cell below, find them in the left file drawer of your Runpod/Vast/Colab."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d5e8bd2-3e9f-4196-ad7a-ae6dcfc84b92",
"metadata": {},
"outputs": [],
"source": [
"#list ckpts in root that are ready for download\n",
"import glob\n",
"\n",
"for f in glob.glob(\"*.ckpt\"):\n",
" print(f)"
"* You only need to setup a repository one time. You can create it here: [Create New HF Model](https://huggingface.co/new) Make sure you write down the repo name you make for future use. You can reuse it later."
]
},
{
@ -303,32 +297,160 @@
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli lfs-enable-largefiles\n",
"# fill in these three fields:\n",
"hfusername = \"MyHfUser\"\n",
"reponame = \"MyHfRepo\"\n",
"ckpt_name = \"sd1_mymodel_000-ep100-gs01000.ckpt\"\n",
"\n",
"\n",
"target_name = ckpt_name.replace('-','').replace('=','')\n",
"import glob\n",
"import os\n",
"os.rename(ckpt_name,target_name)\n",
"repo_id=f\"{hfusername}/{reponame}\"\n",
"print(f\"uploading to HF: {repo_id}/{ckpt_name}\")\n",
"print(\"this make take a while...\")\n",
"\n",
"from huggingface_hub import HfApi\n",
"from ipywidgets import *\n",
"\n",
"all_ckpts = [f for f in glob.glob(\"*.ckpt\")]\n",
" \n",
"ckpt_picker = SelectMultiple(options=all_ckpts, layout=Layout(width=\"600px\")) \n",
"hfuser = Text(placeholder='Your HF user name')\n",
"hfrepo = Text(placeholder='Your HF repo name')\n",
"\n",
"api = HfApi()\n",
"response = api.upload_file(\n",
" path_or_fileobj=target_name,\n",
" path_in_repo=target_name,\n",
" repo_id=repo_id,\n",
" repo_type=None,\n",
" create_pr=1,\n",
"upload_btn = Button(description='Upload', layout=full_width)\n",
"out = Output()\n",
"\n",
"def upload_ckpts(_):\n",
" repo_id=f\"{hfuser.value}/{hfrepo.value}\"\n",
" with out:\n",
" for ckpt in ckpt_picker.value:\n",
" print(f\"Uploading to HF: huggingface.co/{repo_id}/{ckpt}\")\n",
" response = api.upload_file(\n",
" path_or_fileobj=ckpt,\n",
" path_in_repo=ckpt,\n",
" repo_id=repo_id,\n",
" repo_type=None,\n",
" create_pr=1,\n",
" )\n",
" display(response)\n",
" print(\"DONE\")\n",
" print(\"Go to your repo and accept the PRs this created to see your files\")\n",
"\n",
"upload_btn.on_click(upload_ckpts)\n",
"box = VBox([ckpt_picker, HBox([hfuser, hfrepo]), upload_btn, out])\n",
"\n",
"display(box)"
]
},
{
"cell_type": "markdown",
"id": "c1a00d16-9b84-492f-8e6a-defe71e82b43",
"metadata": {},
"source": [
"# Test inference on your checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efb1a8cd-6a04-44e5-a770-c23ee247ce82",
"metadata": {},
"outputs": [],
"source": [
"%cd /workspace/EveryDream2trainer\n",
"from ipywidgets import *\n",
"from IPython.display import display, clear_output\n",
"import os\n",
"import gc\n",
"import random\n",
"import torch\n",
"import inspect\n",
"\n",
"from torch import autocast\n",
"from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler\n",
"from transformers import CLIPTextModel, CLIPTokenizer\n",
"\n",
"\n",
"checkpoints_ts = []\n",
"for root, dirs, files in os.walk(\".\"):\n",
" for file in files:\n",
" if os.path.basename(file) == \"model_index.json\":\n",
" ts = os.path.getmtime(os.path.join(root,file))\n",
" ckpt = root\n",
" checkpoints_ts.append((ts, root))\n",
"\n",
"checkpoints = [ckpt for (_, ckpt) in sorted(checkpoints_ts, reverse=True)]\n",
"full_width = Layout(width='600px')\n",
"half_width = Layout(width='300px')\n",
"\n",
"checkpoint = Dropdown(options=checkpoints, description='Checkpoint:', layout=full_width)\n",
"prompt = Textarea(value='a photo of ', description='Prompt:', layout=full_width)\n",
"height = IntSlider(value=512, min=256, max=768, step=32, description='Height:', layout=half_width)\n",
"width = IntSlider(value=512, min=256, max=768, step=32, description='Width:', layout=half_width)\n",
"cfg = FloatSlider(value=7.0, min=0.0, max=14.0, step=0.2, description='CFG Scale:', layout=half_width)\n",
"steps = IntSlider(value=30, min=10, max=100, description='Steps:', layout=half_width)\n",
"seed = IntText(value=-1, description='Seed:', layout=half_width)\n",
"generate_btn = Button(description='Generate', layout=full_width)\n",
"out = Output()\n",
"\n",
"def generate(_):\n",
" with out:\n",
" clear_output()\n",
" display(f\"Loading model {checkpoint.value}\")\n",
" actual_seed = seed.value if seed.value != -1 else random.randint(0, 2**30)\n",
"\n",
" text_encoder = CLIPTextModel.from_pretrained(checkpoint.value, subfolder=\"text_encoder\")\n",
" vae = AutoencoderKL.from_pretrained(checkpoint.value, subfolder=\"vae\")\n",
" unet = UNet2DConditionModel.from_pretrained(checkpoint.value, subfolder=\"unet\")\n",
" tokenizer = CLIPTokenizer.from_pretrained(checkpoint.value, subfolder=\"tokenizer\", use_fast=False)\n",
" scheduler = DDIMScheduler.from_pretrained(checkpoint.value, subfolder=\"scheduler\")\n",
" text_encoder.eval()\n",
" vae.eval()\n",
" unet.eval()\n",
"\n",
" text_encoder.to(\"cuda\")\n",
" vae.to(\"cuda\")\n",
" unet.to(\"cuda\")\n",
"\n",
" pipe = StableDiffusionPipeline(\n",
" vae=vae,\n",
" text_encoder=text_encoder,\n",
" tokenizer=tokenizer,\n",
" unet=unet,\n",
" scheduler=scheduler,\n",
" safety_checker=None, # save vram\n",
" requires_safety_checker=None, # avoid nag\n",
" feature_extractor=None, # must be none of no safety checker\n",
" )\n",
"\n",
" pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)\n",
" \n",
" print(inspect.cleandoc(f\"\"\"\n",
" Prompt: {prompt.value}\n",
" Resolution: {width.value}x{height.value}\n",
" CFG: {cfg.value}\n",
" Steps: {steps.value}\n",
" Seed: {actual_seed}\n",
" \"\"\"))\n",
" with autocast(\"cuda\"):\n",
" image = pipe(prompt.value, \n",
" generator=torch.Generator(\"cuda\").manual_seed(actual_seed),\n",
" num_inference_steps=steps.value, \n",
" guidance_scale=cfg.value,\n",
" width=width.value,\n",
" height=height.value\n",
" ).images[0]\n",
" del pipe\n",
" gc.collect()\n",
" with torch.cuda.device(\"cuda\"):\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.ipc_collect()\n",
" display(image)\n",
" \n",
"generate_btn.on_click(generate)\n",
"box = VBox(\n",
" children=[\n",
" checkpoint, prompt, \n",
" HBox([VBox([width, height]), VBox([steps, cfg])]), \n",
" seed, \n",
" generate_btn, \n",
" out]\n",
")\n",
"print(response)\n",
"print(finish_msg)\n",
"print(\"go to your repo and accept the PR this created to see your file\")"
"\n",
"\n",
"display(box)"
]
}
],