stable-diffusion-paperspace/other/Huggingface Textual Inversi...

557 lines
22 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0cf778b6-9d7f-45ed-8838-93df28213e7a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"input_images_dir = '/notebooks/textual inversion/source/' # upload your images here\n",
"\n",
"# Hugging Face access token needed to download the CompVis/stable-diffusion-v1-4 weights\n",
"# (you must accept the model license on the Hub first). It is read from the\n",
"# HF_AUTH_TOKEN environment variable instead of being typed into the notebook, so the\n",
"# secret is never saved inside the .ipynb file. When the variable is unset we pass True,\n",
"# which makes `from_pretrained` use the token stored by `huggingface-cli login`.\n",
"hf_auth_token = os.environ.get('HF_AUTH_TOKEN') or True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9a306d5-a088-4fc1-800c-1e1ef57cd75c",
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p \"{input_images_dir}\"\n",
"# Install the required libs.\n",
"# `%pip` (unlike `!pip`) installs into the environment of the running kernel, and one\n",
"# upgrade call replaces the earlier install-then-upgrade of the same packages.\n",
"%pip install --upgrade diffusers transformers ftfy scipy accelerate\n",
"%pip install \"ipywidgets>=7,<8\"\n",
"# NOTE(review): later cells use older diffusers APIs (`diffusers.hub_utils`,\n",
"# `DDPMScheduler(tensor_format=...)`); if the newest diffusers rejects them, pin the\n",
"# diffusers version this notebook was written for - TODO confirm which one."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48dcdb96-e9a6-4141-8442-5eba0280c31a",
"metadata": {},
"outputs": [],
"source": [
"# Import required libraries\n",
"import argparse\n",
"import itertools\n",
"import math\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torch.utils.checkpoint\n",
"from torch.utils.data import Dataset\n",
"\n",
"import PIL\n",
"from accelerate import Accelerator\n",
"from accelerate.logging import get_logger\n",
"from accelerate.utils import set_seed\n",
"from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n",
"from diffusers.hub_utils import init_git_repo, push_to_hub\n",
"from diffusers.optimization import get_scheduler\n",
"from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n",
"from PIL import Image\n",
"from torchvision import transforms\n",
"from tqdm.auto import tqdm\n",
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
"\n",
"def image_grid(imgs, rows, cols):\n",
"    \"\"\"Paste ``rows * cols`` PIL images into one RGB grid image and return it.\n",
"\n",
"    Images are placed row-major (left to right, then top to bottom). Every cell gets\n",
"    the size of ``imgs[0]``, so the images are expected to share one size.\n",
"    Raises AssertionError when ``len(imgs) != rows * cols``.\n",
"    \"\"\"\n",
"    assert len(imgs) == rows * cols, f\"expected {rows * cols} images, got {len(imgs)}\"\n",
"\n",
"    w, h = imgs[0].size\n",
"    grid = Image.new('RGB', size=(cols * w, rows * h))\n",
"    for i, img in enumerate(imgs):\n",
"        # divmod yields the (row, column) of the i-th image in row-major order\n",
"        row, col = divmod(i, cols)\n",
"        grid.paste(img, box=(col * w, row * h))\n",
"    return grid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "690c3700-0386-44cf-b8d2-1c9a8b85e296",
"metadata": {},
"outputs": [],
"source": [
"# Settings for your newly created concept\n",
"\n",
"# What is it that you are teaching?\n",
"# \"object\" teaches the model a new object to be used in prompts\n",
"# \"style\" teaches the model a new style one can use\n",
"what_to_teach = \"style\" # [\"object\", \"style\"]\n",
"# The dataset treats any value other than \"style\" as an object, so a typo here would\n",
"# silently train with the wrong prompt templates; reject it up front instead.\n",
"if what_to_teach not in (\"object\", \"style\"):\n",
"    raise ValueError(f\"what_to_teach must be 'object' or 'style', got {what_to_teach!r}\")\n",
"\n",
"# `placeholder_token` is the token you are going to use to represent your new concept\n",
"# (so when you prompt the model, you will say \"A `<my-placeholder-token>` in an amusement park\")\n",
"# We use angle brackets to differentiate a token from other words/tokens, to avoid collision.\n",
"placeholder_token = \"<banana>\" #@param {type:\"string\"}\n",
"\n",
"# `initializer_token` is a word that summarises what your new concept is; its embedding is\n",
"# the starting point for training, and it must encode to exactly one tokenizer token.\n",
"initializer_token = \"art\" #@param {type:\"string\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e866762e-d3e8-4977-99ad-d742114f2fe6",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import shutil\n",
"from io import BytesIO\n",
"\n",
"def absoluteFilePaths(directory):\n",
"    \"\"\"Return the absolute paths of all files below ``directory`` (recursively), sorted.\n",
"\n",
"    Sorting makes the numbering of the copied images deterministic across runs.\n",
"    \"\"\"\n",
"    return sorted(\n",
"        os.path.abspath(os.path.join(dirpath, f))\n",
"        for dirpath, _, filenames in os.walk(directory)\n",
"        for f in filenames\n",
"    )\n",
"\n",
"def download_image(url):\n",
"    \"\"\"Open the local image file at ``url`` and return an RGB copy of it.\n",
"\n",
"    Despite the name, ``url`` is a local filesystem path. Returns None (and prints\n",
"    which file was skipped) when the file cannot be read as an image, e.g. a stray\n",
"    non-image file in the upload folder, so the caller can drop it.\n",
"    \"\"\"\n",
"    try:\n",
"        # The context manager closes the file; convert() returns an independent copy.\n",
"        with Image.open(url) as img:\n",
"            return img.convert(\"RGB\")\n",
"    except OSError as err:  # PIL.UnidentifiedImageError is a subclass of OSError\n",
"        print(f\"Skipping {url}: {err}\")\n",
"        return None\n",
"\n",
"save_path = '/notebooks/textual inversion/input/'\n",
"# Start from an empty folder so images left over from a previous run are not reused.\n",
"shutil.rmtree(save_path, ignore_errors=True)\n",
"os.makedirs(save_path, exist_ok=True)\n",
"\n",
"images = [img for img in (download_image(url) for url in absoluteFilePaths(input_images_dir)) if img is not None]\n",
"if not images:\n",
"    raise ValueError(f\"No readable images found in {input_images_dir!r}; upload your training images there.\")\n",
"for i, image in enumerate(images):\n",
"    image.save(os.path.join(save_path, f\"{i}.jpeg\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a882b1c8-97e3-4487-8300-083aa359ad1d",
"metadata": {},
"outputs": [],
"source": [
"#@title Setup the prompt templates for training\n",
"# Prompt templates for \"object\" concepts; \"{}\" is filled with the placeholder token.\n",
"imagenet_templates_small = [\n",
"    \"a photo of a {}\",\n",
"    \"a rendering of a {}\",\n",
"    \"a cropped photo of the {}\",\n",
"    \"the photo of a {}\",\n",
"    \"a photo of a clean {}\",\n",
"    \"a photo of a dirty {}\",\n",
"    \"a dark photo of the {}\",\n",
"    \"a photo of my {}\",\n",
"    \"a photo of the cool {}\",\n",
"    \"a close-up photo of a {}\",\n",
"    \"a bright photo of the {}\",\n",
"    \"a cropped photo of a {}\",\n",
"    \"a photo of the {}\",\n",
"    \"a good photo of the {}\",\n",
"    \"a photo of one {}\",\n",
"    \"a close-up photo of the {}\",\n",
"    \"a rendition of the {}\",\n",
"    \"a photo of the clean {}\",\n",
"    \"a rendition of a {}\",\n",
"    \"a photo of a nice {}\",\n",
"    \"a good photo of a {}\",\n",
"    \"a photo of the nice {}\",\n",
"    \"a photo of the small {}\",\n",
"    \"a photo of the weird {}\",\n",
"    \"a photo of the large {}\",\n",
"    \"a photo of a cool {}\",\n",
"    \"a photo of a small {}\",\n",
"]\n",
"\n",
"# Prompt templates for \"style\" concepts.\n",
"# NOTE(review): \"a cropped painting in the style of {}\" and \"a close-up painting in the\n",
"# style of {}\" each appear twice, which doubles their chance in random.choice - confirm intended.\n",
"imagenet_style_templates_small = [\n",
"    \"a painting in the style of {}\",\n",
"    \"a rendering in the style of {}\",\n",
"    \"a cropped painting in the style of {}\",\n",
"    \"the painting in the style of {}\",\n",
"    \"a clean painting in the style of {}\",\n",
"    \"a dirty painting in the style of {}\",\n",
"    \"a dark painting in the style of {}\",\n",
"    \"a picture in the style of {}\",\n",
"    \"a cool painting in the style of {}\",\n",
"    \"a close-up painting in the style of {}\",\n",
"    \"a bright painting in the style of {}\",\n",
"    \"a cropped painting in the style of {}\",\n",
"    \"a good painting in the style of {}\",\n",
"    \"a close-up painting in the style of {}\",\n",
"    \"a rendition in the style of {}\",\n",
"    \"a nice painting in the style of {}\",\n",
"    \"a small painting in the style of {}\",\n",
"    \"a weird painting in the style of {}\",\n",
"    \"a large painting in the style of {}\",\n",
"]\n",
"#@title Setup the dataset\n",
"# Hugging Face Hub id of the base model; the tokenizer and all weights are loaded from it.\n",
"pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'\n",
"class TextualInversionDataset(Dataset):\n",
"    \"\"\"Image/prompt pairs for textual-inversion training.\n",
"\n",
"    Each item is a dict with:\n",
"      * \"input_ids\": a random prompt template filled with ``placeholder_token``,\n",
"        tokenized and padded/truncated to ``tokenizer.model_max_length``;\n",
"      * \"pixel_values\": the image as a float32 channels-first tensor scaled to\n",
"        [-1, 1], resized to ``size`` x ``size`` (optionally center-cropped to a square\n",
"        first) and horizontally flipped with probability ``flip_p``.\n",
"\n",
"    With ``set=\"train\"`` the image list is repeated ``repeats`` times, so\n",
"    ``len(dataset) == num_images * repeats``; otherwise ``len == num_images``.\n",
"    Raises ValueError if ``data_root`` contains no files.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        data_root,\n",
"        tokenizer,\n",
"        learnable_property=\"object\",  # [object, style]\n",
"        size=512,\n",
"        repeats=100,\n",
"        interpolation=\"bicubic\",\n",
"        flip_p=0.5,\n",
"        set=\"train\",\n",
"        placeholder_token=\"*\",\n",
"        center_crop=False,\n",
"    ):\n",
"        self.data_root = data_root\n",
"        self.tokenizer = tokenizer\n",
"        self.learnable_property = learnable_property\n",
"        self.size = size\n",
"        self.placeholder_token = placeholder_token\n",
"        self.center_crop = center_crop\n",
"        self.flip_p = flip_p\n",
"\n",
"        # Sorted so the index -> image mapping is stable across runs and platforms.\n",
"        self.image_paths = [\n",
"            os.path.join(self.data_root, file_path) for file_path in sorted(os.listdir(self.data_root))\n",
"        ]\n",
"\n",
"        self.num_images = len(self.image_paths)\n",
"        if self.num_images == 0:\n",
"            # Fail here with a clear message instead of later inside the DataLoader.\n",
"            raise ValueError(f\"No training images found in {data_root!r}\")\n",
"        self._length = self.num_images\n",
"\n",
"        if set == \"train\":\n",
"            self._length = self.num_images * repeats\n",
"\n",
"        # The whole dict is built before the lookup, so every entry must exist in the\n",
"        # installed Pillow. Pillow 10 removed Image.LINEAR (an alias of BILINEAR), which\n",
"        # made this line fail for every interpolation choice; map \"linear\" to BILINEAR.\n",
"        self.interpolation = {\n",
"            \"linear\": PIL.Image.BILINEAR,\n",
"            \"bilinear\": PIL.Image.BILINEAR,\n",
"            \"bicubic\": PIL.Image.BICUBIC,\n",
"            \"lanczos\": PIL.Image.LANCZOS,\n",
"        }[interpolation]\n",
"\n",
"        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n",
"        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n",
"\n",
"    def __len__(self):\n",
"        return self._length\n",
"\n",
"    def __getitem__(self, i):\n",
"        example = {}\n",
"        # Open via a context manager so the file handle is released; convert() returns a\n",
"        # detached RGB image (a copy even when the file is already RGB).\n",
"        with Image.open(self.image_paths[i % self.num_images]) as raw_image:\n",
"            image = raw_image.convert(\"RGB\")\n",
"\n",
"        text = random.choice(self.templates).format(self.placeholder_token)\n",
"\n",
"        example[\"input_ids\"] = self.tokenizer(\n",
"            text,\n",
"            padding=\"max_length\",\n",
"            truncation=True,\n",
"            max_length=self.tokenizer.model_max_length,\n",
"            return_tensors=\"pt\",\n",
"        ).input_ids[0]\n",
"\n",
"        # default to score-sde preprocessing\n",
"        img = np.array(image).astype(np.uint8)\n",
"\n",
"        if self.center_crop:\n",
"            # Keep the largest centered square before resizing.\n",
"            h, w = img.shape[0], img.shape[1]\n",
"            crop = min(h, w)\n",
"            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n",
"\n",
"        image = Image.fromarray(img)\n",
"        image = image.resize((self.size, self.size), resample=self.interpolation)\n",
"\n",
"        image = self.flip_transform(image)\n",
"        image = np.array(image).astype(np.uint8)\n",
"        # Map uint8 [0, 255] to float32 [-1, 1].\n",
"        image = (image / 127.5 - 1.0).astype(np.float32)\n",
"\n",
"        # (H, W, C) array from the RGB image -> (C, H, W) tensor.\n",
"        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n",
"        return example\n",
"\n",
"#@title Load the tokenizer and add the placeholder token as an additional token.\n",
"#@markdown Please read and if you agree accept the LICENSE [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an error\n",
"tokenizer = CLIPTokenizer.from_pretrained(\n",
"    pretrained_model_name_or_path,\n",
"    subfolder=\"tokenizer\",\n",
"    use_auth_token=hf_auth_token,\n",
")\n",
"\n",
"# Add the placeholder token to the tokenizer; a count of 0 means it was already present.\n",
"num_added_tokens = tokenizer.add_tokens(placeholder_token)\n",
"if num_added_tokens == 0:\n",
"    raise ValueError(\n",
"        f\"The tokenizer already contains the token {placeholder_token}. Please pass a different\"\n",
"        \" `placeholder_token` that is not already in the tokenizer.\"\n",
"    )\n",
"\n",
"#@title Get token ids for our placeholder and initializer token. This code block will complain if initializer string is not a single token\n",
"# Convert the initializer_token, placeholder_token to ids\n",
"token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)\n",
"# The initializer must encode to exactly one token: several tokens cannot seed a single\n",
"# embedding, and zero tokens (e.g. an empty string) would crash on token_ids[0] below.\n",
"if len(token_ids) != 1:\n",
"    raise ValueError(\n",
"        f\"The initializer token must be a single token, but {initializer_token!r} \"\n",
"        f\"encodes to {len(token_ids)} tokens.\"\n",
"    )\n",
"\n",
"initializer_token_id = token_ids[0]\n",
"placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b8daf19-a6fa-4977-b176-d7d5200c81a1",
"metadata": {},
"outputs": [],
"source": [
"#@title Load the Stable Diffusion model\n",
"# Load the pretrained text encoder, VAE and UNet. All three live in subfolders of the\n",
"# same Hub repository and are downloaded with the same auth token.\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
"    pretrained_model_name_or_path,\n",
"    use_auth_token=hf_auth_token,\n",
"    subfolder=\"text_encoder\",\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
"    pretrained_model_name_or_path,\n",
"    use_auth_token=hf_auth_token,\n",
"    subfolder=\"vae\",\n",
")\n",
"unet = UNet2DConditionModel.from_pretrained(\n",
"    pretrained_model_name_or_path,\n",
"    use_auth_token=hf_auth_token,\n",
"    subfolder=\"unet\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "134c7b49-e8e8-4d10-9c96-018dd8123a2a",
"metadata": {},
"outputs": [],
"source": [
"# Grow the embedding table to cover the newly added placeholder token, then start the\n",
"# placeholder's embedding as a copy of the initializer token's embedding.\n",
"text_encoder.resize_token_embeddings(len(tokenizer))\n",
"token_embeds = text_encoder.get_input_embeddings().weight.data\n",
"token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]\n",
"\n",
"def freeze_params(params):\n",
"    \"\"\"Turn off gradient tracking for every tensor in the iterable ``params``.\"\"\"\n",
"    for tensor in params:\n",
"        tensor.requires_grad = False\n",
"\n",
"# The VAE and UNet are frozen completely.\n",
"freeze_params(itertools.chain(vae.parameters(), unet.parameters()))\n",
"# Inside the text encoder, freeze everything except the token embedding table.\n",
"params_to_freeze = itertools.chain(\n",
"    text_encoder.text_model.encoder.parameters(),\n",
"    text_encoder.text_model.final_layer_norm.parameters(),\n",
"    text_encoder.text_model.embeddings.position_embedding.parameters(),\n",
")\n",
"freeze_params(params_to_freeze)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ceb1fff-25c3-4c9a-ad5d-45fc9ab8189f",
"metadata": {},
"outputs": [],
"source": [
"# Training data: the images copied to `save_path`, each repeated 100 times per epoch.\n",
"train_dataset = TextualInversionDataset(\n",
"    data_root=save_path,\n",
"    tokenizer=tokenizer,\n",
"    learnable_property=what_to_teach,  # \"object\" or \"style\", chosen in the settings cell\n",
"    placeholder_token=placeholder_token,\n",
"    size=512,\n",
"    repeats=100,\n",
"    center_crop=False,\n",
"    set=\"train\",\n",
")\n",
"\n",
"def create_dataloader(train_batch_size=1):\n",
"    \"\"\"Return a shuffling DataLoader over the global ``train_dataset``.\"\"\"\n",
"    return torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)\n",
"\n",
"# Noise schedule used to add noise to the latents during training.\n",
"# NOTE(review): `tensor_format` is an old diffusers argument that newer releases may\n",
"# reject - confirm against the installed diffusers version.\n",
"noise_scheduler = DDPMScheduler(\n",
"    num_train_timesteps=1000,\n",
"    beta_start=0.00085,\n",
"    beta_end=0.012,\n",
"    beta_schedule=\"scaled_linear\",\n",
"    tensor_format=\"pt\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a511581b-a610-4f86-be49-a94529a456dc",
"metadata": {},
"outputs": [],
"source": [
"# Training hyperparameters. If you are not happy with your results, you can tune the\n",
"# `learning_rate` and the `max_train_steps`.\n",
"hyperparameters = dict(\n",
"    learning_rate=5e-04,\n",
"    scale_lr=True,  # multiply learning_rate by accumulation steps * batch size * processes\n",
"    max_train_steps=8500,\n",
"    train_batch_size=1,\n",
"    gradient_accumulation_steps=4,\n",
"    seed=42,\n",
"    output_dir=\"sd-concept-output\",  # receives the saved pipeline and learned_embeds.bin\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5c9f8e1-c2da-4985-9f8c-fa5811d7ad20",
"metadata": {},
"outputs": [],
"source": [
"def training_function(text_encoder, vae, unet):\n",
"    \"\"\"Learn an embedding for the placeholder token and save the resulting concept.\n",
"\n",
"    Only the text encoder's input-embedding table is given to the optimizer. Before each\n",
"    optimizer step the gradients of every row except ``placeholder_token_id`` are zeroed,\n",
"    and after it those rows are restored from a snapshot, so only the placeholder's\n",
"    embedding changes. The VAE and UNet are moved to the accelerator's device and kept\n",
"    in eval mode.\n",
"\n",
"    Settings come from the global ``hyperparameters`` dict. The function also reads the\n",
"    notebook globals ``create_dataloader``, ``train_dataset``, ``noise_scheduler``,\n",
"    ``tokenizer``, ``placeholder_token`` and ``placeholder_token_id``. On the main\n",
"    process it saves a ``StableDiffusionPipeline`` and ``learned_embeds.bin``\n",
"    (``{placeholder_token: embedding}``) to ``hyperparameters[\"output_dir\"]``.\n",
"    \"\"\"\n",
"    logger = get_logger(__name__)\n",
"\n",
"    train_batch_size = hyperparameters[\"train_batch_size\"]\n",
"    gradient_accumulation_steps = hyperparameters[\"gradient_accumulation_steps\"]\n",
"    learning_rate = hyperparameters[\"learning_rate\"]\n",
"    max_train_steps = hyperparameters[\"max_train_steps\"]\n",
"    output_dir = hyperparameters[\"output_dir\"]\n",
"\n",
"    accelerator = Accelerator(\n",
"        gradient_accumulation_steps=gradient_accumulation_steps,\n",
"    )\n",
"    # The \"seed\" hyperparameter was previously never applied, so runs were not reproducible.\n",
"    set_seed(hyperparameters[\"seed\"])\n",
"\n",
"    train_dataloader = create_dataloader(train_batch_size)\n",
"\n",
"    if hyperparameters[\"scale_lr\"]:\n",
"        # Scale the learning rate with the effective (total) batch size.\n",
"        learning_rate = (\n",
"            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes\n",
"        )\n",
"\n",
"    # Initialize the optimizer\n",
"    optimizer = torch.optim.AdamW(\n",
"        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n",
"        lr=learning_rate,\n",
"    )\n",
"\n",
"    text_encoder, optimizer, train_dataloader = accelerator.prepare(\n",
"        text_encoder, optimizer, train_dataloader\n",
"    )\n",
"\n",
"    # Move vae and unet to device\n",
"    vae.to(accelerator.device)\n",
"    unet.to(accelerator.device)\n",
"\n",
"    # Keep vae and unet in eval mode as we don't train these\n",
"    vae.eval()\n",
"    unet.eval()\n",
"\n",
"    # Snapshot of the embedding table before training. AdamW's decoupled weight decay\n",
"    # shrinks every row of the table even when its gradient is zero, so the rows of all\n",
"    # other tokens are restored from this copy after each optimizer step.\n",
"    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n",
"    # Boolean mask over the vocabulary: True for every row except the placeholder's.\n",
"    # It never changes during training, so it is built once here instead of every step.\n",
"    index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id\n",
"\n",
"    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
"    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)\n",
"    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
"\n",
"    # Train!\n",
"    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps\n",
"\n",
"    logger.info(\"***** Running training *****\")\n",
"    logger.info(f\" Num examples = {len(train_dataset)}\")\n",
"    logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n",
"    logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
"    logger.info(f\" Gradient Accumulation steps = {gradient_accumulation_steps}\")\n",
"    logger.info(f\" Total optimization steps = {max_train_steps}\")\n",
"    # Only show the progress bar once on each machine.\n",
"    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)\n",
"    progress_bar.set_description(\"Steps\")\n",
"    global_step = 0\n",
"\n",
"    for _epoch in range(num_train_epochs):\n",
"        text_encoder.train()\n",
"        for batch in train_dataloader:\n",
"            with accelerator.accumulate(text_encoder):\n",
"                # Convert images to latent space\n",
"                latents = vae.encode(batch[\"pixel_values\"]).latent_dist.sample().detach()\n",
"                # Fixed latent scale factor (presumably the SD v1 VAE scaling) - keep in sync with inference.\n",
"                latents = latents * 0.18215\n",
"\n",
"                # Sample noise that we'll add to the latents, directly on the latents' device\n",
"                noise = torch.randn_like(latents)\n",
"                bsz = latents.shape[0]\n",
"                # Sample a random timestep for each image\n",
"                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()\n",
"\n",
"                # Add noise to the latents according to the noise magnitude at each timestep\n",
"                # (this is the forward diffusion process)\n",
"                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n",
"\n",
"                # Get the text embedding for conditioning\n",
"                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n",
"\n",
"                # Predict the noise residual\n",
"                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n",
"\n",
"                loss = F.mse_loss(noise_pred, noise, reduction=\"none\").mean([1, 2, 3]).mean()\n",
"                accelerator.backward(loss)\n",
"\n",
"                # Zero out the gradients for all token embeddings except the newly added\n",
"                # embedding for the concept, as we only want to optimize the concept embedding.\n",
"                # unwrap_model returns the underlying model with or without a DDP wrapper.\n",
"                embeddings = accelerator.unwrap_model(text_encoder).get_input_embeddings()\n",
"                embeddings.weight.grad.data[index_no_updates, :] = 0\n",
"\n",
"                optimizer.step()\n",
"                optimizer.zero_grad()\n",
"\n",
"                # Undo the weight decay applied to the rows that must not change.\n",
"                with torch.no_grad():\n",
"                    embeddings.weight[index_no_updates] = orig_embeds_params[index_no_updates]\n",
"\n",
"            # Checks if the accelerator has performed an optimization step behind the scenes\n",
"            if accelerator.sync_gradients:\n",
"                progress_bar.update(1)\n",
"                global_step += 1\n",
"\n",
"            logs = {\"loss\": loss.detach().item()}\n",
"            progress_bar.set_postfix(**logs)\n",
"\n",
"            if global_step >= max_train_steps:\n",
"                break\n",
"        # Also leave the epoch loop once the step budget is used up.\n",
"        if global_step >= max_train_steps:\n",
"            break\n",
"\n",
"    # Make sure every process has finished training before the main process saves.\n",
"    accelerator.wait_for_everyone()\n",
"\n",
"    # Create the pipeline using the trained modules and save it.\n",
"    if accelerator.is_main_process:\n",
"        trained_text_encoder = accelerator.unwrap_model(text_encoder)\n",
"        pipeline = StableDiffusionPipeline(\n",
"            text_encoder=trained_text_encoder,\n",
"            vae=vae,\n",
"            unet=unet,\n",
"            tokenizer=tokenizer,\n",
"            scheduler=PNDMScheduler(\n",
"                beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", skip_prk_steps=True\n",
"            ),\n",
"            safety_checker=StableDiffusionSafetyChecker.from_pretrained(\"CompVis/stable-diffusion-safety-checker\"),\n",
"            feature_extractor=CLIPFeatureExtractor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n",
"        )\n",
"        pipeline.save_pretrained(output_dir)\n",
"        # Also save the newly trained embeddings\n",
"        learned_embeds = trained_text_encoder.get_input_embeddings().weight[placeholder_token_id]\n",
"        learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}\n",
"        torch.save(learned_embeds_dict, os.path.join(output_dir, \"learned_embeds.bin\"))\n",
"\n",
"import accelerate\n",
"accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet), num_processes=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fda9c1eb-8a2d-4da9-a98c-8c6e717eb215",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}