{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cf778b6-9d7f-45ed-8838-93df28213e7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_images_dir = '/notebooks/textual inversion/source/'  # upload your training images here\n",
    "hf_auth_token = ''  # paste your Hugging Face access token here (needed to download the model weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a306d5-a088-4fc1-800c-1e1ef57cd75c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p \"{input_images_dir}\"\n",
    "# Install the required libs\n",
    "!pip install diffusers transformers ftfy\n",
    "!pip install \"ipywidgets>=7,<8\"\n",
    "!pip install --upgrade diffusers transformers scipy\n",
    "!pip install accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48dcdb96-e9a6-4141-8442-5eba0280c31a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import argparse\n",
    "import itertools\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "import PIL\n",
    "from accelerate import Accelerator\n",
    "from accelerate.logging import get_logger\n",
    "from accelerate.utils import set_seed\n",
    "from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n",
    "from diffusers.optimization import get_scheduler\n",
    "from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
    "\n",
    "\n",
    "def image_grid(imgs, rows, cols):\n",
    "    # Paste `rows * cols` PIL images into a single grid image for previewing\n",
    "    assert len(imgs) == rows * cols\n",
    "\n",
    "    w, h = imgs[0].size\n",
    "    grid = Image.new('RGB', size=(cols * w, rows * h))\n",
    "\n",
    "    for i, img in enumerate(imgs):\n",
    "        grid.paste(img, box=(i % cols * w, i // cols * h))\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "690c3700-0386-44cf-b8d2-1c9a8b85e296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for your newly created concept\n",
    "\n",
    "# What is it that you are teaching?\n",
    "# \"object\" teaches the model a new object to reproduce\n",
    "# \"style\" teaches the model a new artistic style\n",
    "what_to_teach = \"style\"  # [\"object\", \"style\"]\n",
    "\n",
    "# `placeholder_token` is the token you will use to represent your new concept\n",
    "# (so when you prompt the model, you will say \"A `<my-placeholder-token>` in an amusement park\").\n",
    "# We use angle brackets to differentiate the token from other words/tokens, to avoid collisions.\n",
    "placeholder_token = \"<banana>\"  #@param {type:\"string\"}\n",
    "\n",
    "# `initializer_token` is a word that summarises what your new concept is, used as a starting point\n",
    "initializer_token = \"art\"  #@param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e866762e-d3e8-4977-99ad-d742114f2fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def absoluteFilePaths(directory):\n",
    "    # Collect the absolute paths of every file below `directory`\n",
    "    x = []\n",
    "    for dirpath, _, filenames in os.walk(directory):\n",
    "        for f in filenames:\n",
    "            x.append(os.path.abspath(os.path.join(dirpath, f)))\n",
    "    return x\n",
    "\n",
    "\n",
    "def load_image(path):\n",
    "    # Open a local image file and convert it to RGB\n",
    "    return Image.open(path).convert(\"RGB\")\n",
    "\n",
    "\n",
    "save_path = '/notebooks/textual inversion/input/'\n",
    "!rm -rf \"{save_path}\"\n",
    "!mkdir -p \"{save_path}\"\n",
    "\n",
    "images = [load_image(path) for path in absoluteFilePaths(input_images_dir)]\n",
    "for i, image in enumerate(images):\n",
    "    image.save(f\"{save_path}/{i}.jpeg\")"
   ]
  },
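  {
   "cell_type": "code",
   "execution_count": null,
   "id": "preview-training-images",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: preview the copied training images in a single grid.\n",
    "# A minimal sketch that only uses the `image_grid` helper and the `images` list defined above;\n",
    "# it assumes at least one image was uploaded to `input_images_dir`.\n",
    "if images:\n",
    "    display(image_grid(images, rows=1, cols=len(images)))\n",
    "else:\n",
    "    print(f\"No images found in {input_images_dir} - upload some before continuing.\")"
   ]
  },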
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a882b1c8-97e3-4487-8300-083aa359ad1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Setup the prompt templates for training\n",
    "imagenet_templates_small = [\n",
    "    \"a photo of a {}\",\n",
    "    \"a rendering of a {}\",\n",
    "    \"a cropped photo of the {}\",\n",
    "    \"the photo of a {}\",\n",
    "    \"a photo of a clean {}\",\n",
    "    \"a photo of a dirty {}\",\n",
    "    \"a dark photo of the {}\",\n",
    "    \"a photo of my {}\",\n",
    "    \"a photo of the cool {}\",\n",
    "    \"a close-up photo of a {}\",\n",
    "    \"a bright photo of the {}\",\n",
    "    \"a cropped photo of a {}\",\n",
    "    \"a photo of the {}\",\n",
    "    \"a good photo of the {}\",\n",
    "    \"a photo of one {}\",\n",
    "    \"a close-up photo of the {}\",\n",
    "    \"a rendition of the {}\",\n",
    "    \"a photo of the clean {}\",\n",
    "    \"a rendition of a {}\",\n",
    "    \"a photo of a nice {}\",\n",
    "    \"a good photo of a {}\",\n",
    "    \"a photo of the nice {}\",\n",
    "    \"a photo of the small {}\",\n",
    "    \"a photo of the weird {}\",\n",
    "    \"a photo of the large {}\",\n",
    "    \"a photo of a cool {}\",\n",
    "    \"a photo of a small {}\",\n",
    "]\n",
    "\n",
    "imagenet_style_templates_small = [\n",
    "    \"a painting in the style of {}\",\n",
    "    \"a rendering in the style of {}\",\n",
    "    \"a cropped painting in the style of {}\",\n",
    "    \"the painting in the style of {}\",\n",
    "    \"a clean painting in the style of {}\",\n",
    "    \"a dirty painting in the style of {}\",\n",
    "    \"a dark painting in the style of {}\",\n",
    "    \"a picture in the style of {}\",\n",
    "    \"a cool painting in the style of {}\",\n",
    "    \"a close-up painting in the style of {}\",\n",
    "    \"a bright painting in the style of {}\",\n",
    "    \"a cropped painting in the style of {}\",\n",
    "    \"a good painting in the style of {}\",\n",
    "    \"a close-up painting in the style of {}\",\n",
    "    \"a rendition in the style of {}\",\n",
    "    \"a nice painting in the style of {}\",\n",
    "    \"a small painting in the style of {}\",\n",
    "    \"a weird painting in the style of {}\",\n",
    "    \"a large painting in the style of {}\",\n",
    "]\n",
    "\n",
    "#@title Setup the dataset\n",
    "pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'\n",
    "\n",
    "\n",
    "class TextualInversionDataset(Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        data_root,\n",
    "        tokenizer,\n",
    "        learnable_property=\"object\",  # [object, style]\n",
    "        size=512,\n",
    "        repeats=100,\n",
    "        interpolation=\"bicubic\",\n",
    "        flip_p=0.5,\n",
    "        set=\"train\",\n",
    "        placeholder_token=\"*\",\n",
    "        center_crop=False,\n",
    "    ):\n",
    "        self.data_root = data_root\n",
    "        self.tokenizer = tokenizer\n",
    "        self.learnable_property = learnable_property\n",
    "        self.size = size\n",
    "        self.placeholder_token = placeholder_token\n",
    "        self.center_crop = center_crop\n",
    "        self.flip_p = flip_p\n",
    "\n",
    "        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n",
    "\n",
    "        self.num_images = len(self.image_paths)\n",
    "        self._length = self.num_images\n",
    "\n",
    "        if set == \"train\":\n",
    "            self._length = self.num_images * repeats\n",
    "\n",
    "        self.interpolation = {\n",
    "            \"linear\": PIL.Image.LINEAR,\n",
    "            \"bilinear\": PIL.Image.BILINEAR,\n",
    "            \"bicubic\": PIL.Image.BICUBIC,\n",
    "            \"lanczos\": PIL.Image.LANCZOS,\n",
    "        }[interpolation]\n",
    "\n",
    "        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n",
    "        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self._length\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        example = {}\n",
    "        image = Image.open(self.image_paths[i % self.num_images])\n",
    "\n",
    "        if not image.mode == \"RGB\":\n",
    "            image = image.convert(\"RGB\")\n",
    "\n",
    "        placeholder_string = self.placeholder_token\n",
    "        text = random.choice(self.templates).format(placeholder_string)\n",
    "\n",
    "        example[\"input_ids\"] = self.tokenizer(\n",
    "            text,\n",
    "            padding=\"max_length\",\n",
    "            truncation=True,\n",
    "            max_length=self.tokenizer.model_max_length,\n",
    "            return_tensors=\"pt\",\n",
    "        ).input_ids[0]\n",
    "\n",
    "        # default to score-sde preprocessing\n",
    "        img = np.array(image).astype(np.uint8)\n",
    "\n",
    "        if self.center_crop:\n",
    "            crop = min(img.shape[0], img.shape[1])\n",
    "            h, w = img.shape[0], img.shape[1]\n",
    "            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n",
    "\n",
    "        image = Image.fromarray(img)\n",
    "        image = image.resize((self.size, self.size), resample=self.interpolation)\n",
    "\n",
    "        image = self.flip_transform(image)\n",
    "        image = np.array(image).astype(np.uint8)\n",
    "        image = (image / 127.5 - 1.0).astype(np.float32)\n",
    "\n",
    "        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n",
    "        return example\n",
    "\n",
    "\n",
    "#@title Load the tokenizer and add the placeholder token as an additional token.\n",
    "#@markdown Please read and accept the LICENSE [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an access error.\n",
    "tokenizer = CLIPTokenizer.from_pretrained(\n",
    "    pretrained_model_name_or_path,\n",
    "    subfolder=\"tokenizer\",\n",
    "    use_auth_token=hf_auth_token,\n",
    ")\n",
    "\n",
    "# Add the placeholder token to the tokenizer\n",
    "num_added_tokens = tokenizer.add_tokens(placeholder_token)\n",
    "if num_added_tokens == 0:\n",
    "    raise ValueError(\n",
    "        f\"The tokenizer already contains the token {placeholder_token}. Please pass a different\"\n",
    "        \" `placeholder_token` that is not already in the tokenizer.\"\n",
    "    )\n",
    "\n",
    "#@title Get the token ids for our placeholder and initializer tokens. This cell raises an error if the initializer string is not a single token.\n",
    "# Convert the initializer_token and placeholder_token to ids\n",
    "token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)\n",
    "# Check that initializer_token is a single token and not a sequence of tokens\n",
    "if len(token_ids) > 1:\n",
    "    raise ValueError(\"The initializer token must be a single token.\")\n",
    "\n",
    "initializer_token_id = token_ids[0]\n",
    "placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)"
   ]
  },
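  {
   "cell_type": "code",
   "execution_count": null,
   "id": "check-token-ids",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: confirm which ids the tokenizer assigned to the placeholder and\n",
    "# initializer tokens. A minimal sketch that only uses the variables defined above.\n",
    "print(f\"placeholder token {placeholder_token!r} -> id {placeholder_token_id}\")\n",
    "print(f\"initializer token {initializer_token!r} -> id {initializer_token_id}\")"
   ]
  },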
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8daf19-a6fa-4977-b176-d7d5200c81a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Load the Stable Diffusion model\n",
    "# Load models and create wrapper for stable diffusion\n",
    "text_encoder = CLIPTextModel.from_pretrained(\n",
    "    pretrained_model_name_or_path, subfolder=\"text_encoder\", use_auth_token=hf_auth_token\n",
    ")\n",
    "vae = AutoencoderKL.from_pretrained(\n",
    "    pretrained_model_name_or_path, subfolder=\"vae\", use_auth_token=hf_auth_token\n",
    ")\n",
    "unet = UNet2DConditionModel.from_pretrained(\n",
    "    pretrained_model_name_or_path, subfolder=\"unet\", use_auth_token=hf_auth_token\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "134c7b49-e8e8-4d10-9c96-018dd8123a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize the token embeddings since we added the placeholder token to the tokenizer\n",
    "text_encoder.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "# Initialise the new placeholder embedding with the embedding of the initializer token\n",
    "token_embeds = text_encoder.get_input_embeddings().weight.data\n",
    "token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]\n",
    "\n",
    "\n",
    "def freeze_params(params):\n",
    "    for param in params:\n",
    "        param.requires_grad = False\n",
    "\n",
    "\n",
    "# Freeze vae and unet\n",
    "freeze_params(vae.parameters())\n",
    "freeze_params(unet.parameters())\n",
    "# Freeze all parameters except for the token embeddings in the text encoder\n",
    "params_to_freeze = itertools.chain(\n",
    "    text_encoder.text_model.encoder.parameters(),\n",
    "    text_encoder.text_model.final_layer_norm.parameters(),\n",
    "    text_encoder.text_model.embeddings.position_embedding.parameters(),\n",
    ")\n",
    "freeze_params(params_to_freeze)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ceb1fff-25c3-4c9a-ad5d-45fc9ab8189f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = TextualInversionDataset(\n",
    "    data_root=save_path,\n",
    "    tokenizer=tokenizer,\n",
    "    size=512,\n",
    "    placeholder_token=placeholder_token,\n",
    "    repeats=100,\n",
    "    learnable_property=what_to_teach,  # option selected above: \"object\" or \"style\"\n",
    "    center_crop=False,\n",
    "    set=\"train\",\n",
    ")\n",
    "\n",
    "\n",
    "def create_dataloader(train_batch_size=1):\n",
    "    return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)\n",
    "\n",
    "\n",
    "noise_scheduler = DDPMScheduler(\n",
    "    beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000, tensor_format=\"pt\"\n",
    ")"
   ]
  },
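  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inspect-one-batch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: pull one batch from the dataloader and confirm the tensor shapes.\n",
    "# A minimal sketch; `pixel_values` should be (batch, 3, 512, 512) and `input_ids` should be\n",
    "# (batch, tokenizer.model_max_length), i.e. (batch, 77) for the CLIP tokenizer.\n",
    "batch = next(iter(create_dataloader(train_batch_size=1)))\n",
    "print(\"pixel_values:\", batch[\"pixel_values\"].shape)\n",
    "print(\"input_ids:\", batch[\"input_ids\"].shape)"
   ]
  },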
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a511581b-a610-4f86-be49-a94529a456dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you are not happy with your results, you can tune the `learning_rate` and the `max_train_steps`\n",
    "hyperparameters = {\n",
    "    \"learning_rate\": 5e-04,\n",
    "    \"scale_lr\": True,\n",
    "    \"max_train_steps\": 8500,\n",
    "    \"train_batch_size\": 1,\n",
    "    \"gradient_accumulation_steps\": 4,\n",
    "    \"seed\": 42,\n",
    "    \"output_dir\": \"sd-concept-output\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c9f8e1-c2da-4985-9f8c-fa5811d7ad20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_function(text_encoder, vae, unet):\n",
    "    logger = get_logger(__name__)\n",
    "\n",
    "    train_batch_size = hyperparameters[\"train_batch_size\"]\n",
    "    gradient_accumulation_steps = hyperparameters[\"gradient_accumulation_steps\"]\n",
    "    learning_rate = hyperparameters[\"learning_rate\"]\n",
    "    max_train_steps = hyperparameters[\"max_train_steps\"]\n",
    "    output_dir = hyperparameters[\"output_dir\"]\n",
    "\n",
    "    # Seed everything for reproducibility, using the \"seed\" entry defined above\n",
    "    set_seed(hyperparameters[\"seed\"])\n",
    "\n",
    "    accelerator = Accelerator(\n",
    "        gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "    )\n",
    "\n",
    "    train_dataloader = create_dataloader(train_batch_size)\n",
    "\n",
    "    if hyperparameters[\"scale_lr\"]:\n",
    "        learning_rate = (\n",
    "            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes\n",
    "        )\n",
    "\n",
    "    # Initialize the optimizer\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n",
    "        lr=learning_rate,\n",
    "    )\n",
    "\n",
    "    text_encoder, optimizer, train_dataloader = accelerator.prepare(\n",
    "        text_encoder, optimizer, train_dataloader\n",
    "    )\n",
    "\n",
    "    # Move vae and unet to device\n",
    "    vae.to(accelerator.device)\n",
    "    unet.to(accelerator.device)\n",
    "\n",
    "    # Keep vae and unet in eval mode as we don't train them\n",
    "    vae.eval()\n",
    "    unet.eval()\n",
    "\n",
    "    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
    "    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)\n",
    "    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
    "\n",
    "    # Train!\n",
    "    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps\n",
    "\n",
    "    logger.info(\"***** Running training *****\")\n",
    "    logger.info(f\"  Num examples = {len(train_dataset)}\")\n",
    "    logger.info(f\"  Instantaneous batch size per device = {train_batch_size}\")\n",
    "    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
    "    logger.info(f\"  Gradient Accumulation steps = {gradient_accumulation_steps}\")\n",
    "    logger.info(f\"  Total optimization steps = {max_train_steps}\")\n",
    "    # Only show the progress bar once on each machine.\n",
    "    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)\n",
    "    progress_bar.set_description(\"Steps\")\n",
    "    global_step = 0\n",
    "\n",
    "    for epoch in range(num_train_epochs):\n",
    "        text_encoder.train()\n",
    "        for step, batch in enumerate(train_dataloader):\n",
    "            with accelerator.accumulate(text_encoder):\n",
    "                # Convert images to latent space\n",
    "                latents = vae.encode(batch[\"pixel_values\"]).latent_dist.sample().detach()\n",
    "                latents = latents * 0.18215\n",
    "\n",
    "                # Sample noise that we'll add to the latents\n",
    "                noise = torch.randn(latents.shape).to(latents.device)\n",
    "                bsz = latents.shape[0]\n",
    "                # Sample a random timestep for each image\n",
    "                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()\n",
    "\n",
    "                # Add noise to the latents according to the noise magnitude at each timestep\n",
    "                # (this is the forward diffusion process)\n",
    "                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n",
    "\n",
    "                # Get the text embedding for conditioning\n",
    "                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n",
    "\n",
    "                # Predict the noise residual\n",
    "                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n",
    "\n",
    "                loss = F.mse_loss(noise_pred, noise, reduction=\"none\").mean([1, 2, 3]).mean()\n",
    "                accelerator.backward(loss)\n",
    "\n",
    "                # Zero out the gradients for all token embeddings except the newly added\n",
    "                # embeddings for the concept, as we only want to optimize the concept embeddings\n",
    "                if accelerator.num_processes > 1:\n",
    "                    grads = text_encoder.module.get_input_embeddings().weight.grad\n",
    "                else:\n",
    "                    grads = text_encoder.get_input_embeddings().weight.grad\n",
    "                # Get the index for tokens that we want to zero the grads for\n",
    "                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id\n",
    "                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)\n",
    "\n",
    "                optimizer.step()\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "            # Checks if the accelerator has performed an optimization step behind the scenes\n",
    "            if accelerator.sync_gradients:\n",
    "                progress_bar.update(1)\n",
    "                global_step += 1\n",
    "\n",
    "            logs = {\"loss\": loss.detach().item()}\n",
    "            progress_bar.set_postfix(**logs)\n",
    "\n",
    "            if global_step >= max_train_steps:\n",
    "                break\n",
    "\n",
    "        accelerator.wait_for_everyone()\n",
    "\n",
    "    # Create the pipeline using the trained modules and save it.\n",
    "    if accelerator.is_main_process:\n",
    "        pipeline = StableDiffusionPipeline(\n",
    "            text_encoder=accelerator.unwrap_model(text_encoder),\n",
    "            vae=vae,\n",
    "            unet=unet,\n",
    "            tokenizer=tokenizer,\n",
    "            scheduler=PNDMScheduler(\n",
    "                beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", skip_prk_steps=True\n",
    "            ),\n",
    "            safety_checker=StableDiffusionSafetyChecker.from_pretrained(\"CompVis/stable-diffusion-safety-checker\"),\n",
    "            feature_extractor=CLIPFeatureExtractor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n",
    "        )\n",
    "        pipeline.save_pretrained(output_dir)\n",
    "        # Also save the newly trained embeddings\n",
    "        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]\n",
    "        learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}\n",
    "        torch.save(learned_embeds_dict, os.path.join(output_dir, \"learned_embeds.bin\"))\n",
    "\n",
    "\n",
    "import accelerate\n",
    "accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet), num_processes=1)"
   ]
  },
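  {
   "cell_type": "code",
   "execution_count": null,
   "id": "run-inference-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try out the newly learned concept. A minimal inference sketch, assuming the training cell above\n",
    "# finished and saved the pipeline to `output_dir` (\"sd-concept-output\") and that a CUDA GPU is\n",
    "# available. The prompt is only an example; any prompt containing `placeholder_token` works.\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\n",
    "    hyperparameters[\"output_dir\"],\n",
    "    torch_dtype=torch.float16,\n",
    ").to(\"cuda\")\n",
    "\n",
    "prompt = f\"a painting of a dog in the style of {placeholder_token}\"\n",
    "samples = pipe([prompt] * 4, num_inference_steps=50, guidance_scale=7.5).images\n",
    "image_grid(samples, rows=1, cols=4)"
   ]
  },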
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda9c1eb-8a2d-4da9-a98c-8c6e717eb215",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}