diff --git a/diffusers_trainer.py b/diffusers_trainer.py
new file mode 100644
index 0000000..b780341
--- /dev/null
+++ b/diffusers_trainer.py
@@ -0,0 +1,636 @@
+import argparse
+import socket
+import torch
+import torchvision
+import transformers
+import diffusers
+import os
+import glob
+import random
+import tqdm
+import resource
+import psutil
+import pynvml
+import wandb
+import gc
+import itertools
+import numpy as np
+
+try:
+    pynvml.nvmlInit()
+except pynvml.nvml.NVMLError_LibraryNotFound:
+    pynvml = None
+
+from typing import Dict, Generator, Iterable, List, Tuple
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.optimization import get_scheduler
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from PIL import Image
+
+from scipy.interpolate import interp1d
+
+def bool_t(x: str) -> bool:
+    # argparse's type=bool treats any non-empty string (including 'False') as True,
+    # so boolean flags are parsed from the usual truthy strings instead.
+    return str(x).lower() in ('true', '1', 'yes', 'y', 't')
+
+# defaults should be good for everyone
+# TODO: add custom VAE support. should be simple with diffusers
+parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
+parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
+parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
+parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
+parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
+parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
+parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
+parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
+parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
+parser.add_argument('--use_ema', type=bool_t, default=False, help='Use EMA for finetuning')
+parser.add_argument('--ucg', type=float, default=0.1, help='Probability of dropping out the text condition per sample. Ranges from 0.0 to 1.0 where 1.0 means the text condition is always dropped.') # 10% dropout probability
+parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=bool_t, default=False, help='Enable gradient checkpointing')
+parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=bool_t, default=False, help='Use 8-bit Adam optimizer')
+parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
+parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
+parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
+parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
+parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproducibility purposes.')
+parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
+parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
+parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
+parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default=True, help='Shuffle dataset')
+parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
+parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
+parser.add_argument('--fp16', dest='fp16', type=bool_t, default=False, help='Train in mixed precision')
+parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
+parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
+args = parser.parse_args()
+
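+# Example invocation (illustrative only -- the model ID and dataset path are
+# placeholders, not defaults shipped with this script):
+#
+#   python diffusers_trainer.py \
+#       --model CompVis/stable-diffusion-v1-4 \
+#       --run_name my-finetune \
+#       --dataset /path/to/dataset \
+#       --epochs 10 \
+#       --batch_size 1 \
+#       --fp16 True \
+#       --use_8bit_adam True
+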
+os.makedirs(args.output_path, exist_ok=True)
+
+# remove hf_token from args so sneaky people don't steal it from the wandb logs
+sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
+run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
+
+# Inform the user of the host and various library versions -- useful for debugging issues.
+print("RUN_NAME:", args.run_name)
+print("HOST:", socket.gethostname())
+print("CUDA:", torch.version.cuda)
+print("TORCH:", torch.__version__)
+print("TRANSFORMERS:", transformers.__version__)
+print("DIFFUSERS:", diffusers.__version__)
+print("MODEL:", args.model)
+print("FP16:", args.fp16)
+print("RESOLUTION:", args.resolution)
+
+def get_gpu_ram() -> str:
+    """
+    Returns memory usage statistics for the CPU, GPU, and Torch.
+
+    :return: a human-readable summary string
+    """
+    gpu_str = ""
+    torch_str = ""
+    try:
+        cudadev = torch.cuda.current_device()
+        nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
+        gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
+        gpu_total = int(gpu_info.total / 1E6)
+        gpu_free = int(gpu_info.free / 1E6)
+        gpu_used = int(gpu_info.used / 1E6)
+        gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
+                  f"T: {gpu_total:,}mb) "
+        torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
+        torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
+        torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
+        torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
+        torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
+                    f"{torch_reserved_max:,}mb, " \
+                    f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
+    except (AssertionError, AttributeError):
+        # no CUDA device available, or pynvml could not be initialized
+        pass
+    cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
+                     resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
+    cpu_vmem = psutil.virtual_memory()
+    cpu_free = int(cpu_vmem.free / 1E6)
+    return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
+           f"{gpu_str}" \
+           f"{torch_str}"
+
+def _sort_by_ratio(bucket: tuple) -> float:
+    return bucket[0] / bucket[1]
+
+def _sort_by_area(bucket: tuple) -> float:
+    return bucket[0] * bucket[1]
+
+class ImageStore:
+    def __init__(self, data_dir: str) -> None:
+        self.data_dir = data_dir
+
+        self.image_files = []
+        for ext in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
+            self.image_files.extend(glob.glob(f'{data_dir}/*.{ext}'))
+
+    def __len__(self) -> int:
+        return len(self.image_files)
+
+    # iterator returns images as PIL images and their index in the store
+    def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
+        for f in range(len(self)):
+            yield Image.open(self.image_files[f]), f
+
+    # get image by index, converted to RGB so RGBA/greyscale/palette images don't break the VAE
+    def get_image(self, index: int) -> Image.Image:
+        return Image.open(self.image_files[index]).convert('RGB')
+
+    # gets caption by removing the extension from the filename and replacing it with .txt
+    def get_caption(self, index: int) -> str:
+        filename = os.path.splitext(self.image_files[index])[0] + '.txt'
+        with open(filename, 'r') as f:
+            return f.read()
+
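+# Expected dataset layout (illustrative): every image sits next to a caption file
+# that shares its basename, e.g.
+#
+#   /path/to/dataset/
+#       0001.jpg    0001.txt
+#       0002.png    0002.txt
+#
+# get_caption() opens "<basename>.txt", so a missing caption raises FileNotFoundError.
+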
+class AspectBucket:
+    def __init__(self, store: ImageStore,
+                 num_buckets: int,
+                 batch_size: int,
+                 bucket_side_min: int = 256,
+                 bucket_side_max: int = 768,
+                 bucket_side_increment: int = 64,
+                 max_image_area: int = 512 * 768,
+                 max_ratio: float = 2):
+
+        self.requested_bucket_count = num_buckets
+        self.bucket_length_min = bucket_side_min
+        self.bucket_length_max = bucket_side_max
+        self.bucket_increment = bucket_side_increment
+        self.max_image_area = max_image_area
+        self.batch_size = batch_size
+        self.total_dropped = 0
+
+        if max_ratio <= 0:
+            self.max_ratio = float('inf')
+        else:
+            self.max_ratio = max_ratio
+
+        self.store = store
+        self.buckets = []
+        self._bucket_ratios = []
+        self._bucket_interp = None
+        self.bucket_data: Dict[tuple, List[int]] = dict()
+        self.init_buckets()
+        self.fill_buckets()
+
+    def init_buckets(self):
+        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
+        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
+                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
+
+        buckets_by_ratio = {}
+
+        # group the buckets by their aspect ratios
+        for bucket in possible_buckets:
+            w, h = bucket
+            # use precision to avoid spooky floats messing up your day
+            ratio = '{:.4e}'.format(w / h)
+
+            if ratio not in buckets_by_ratio:
+                group = set()
+                buckets_by_ratio[ratio] = group
+            else:
+                group = buckets_by_ratio[ratio]
+
+            group.add(bucket)
+
+        # now we take the list of buckets we generated and pick the largest by area for each (the first sorted)
+        # then we put all of those in a list, sorted by the aspect ratio
+        # the square bucket (LxL) will be the first
+        unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
+                                       for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)
+
+        # how many buckets to create for each side of the distribution
+        bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))
+
+        # we know that the requested_bucket_count must be an odd number, so the indices we calculate
+        # will include the square bucket and some linearly spaced buckets along the distribution
+        indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}
+
+        # make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
+        # here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
+        buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
+                          *(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)
+
+        self.buckets = buckets
+
+        # cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
+        # the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
+        # and the output is the bucket index in the self.buckets array
+        # to find the best fit we can just round that number to get the index
+        self._bucket_ratios = [w / h for w, h in buckets]
+        self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
+                                       fill_value=None)
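+
+        # Worked example of the lookup above (illustrative, hypothetical ratios):
+        # with stored ratios [0.5, 1.0, 2.0] mapped to indices [0, 1, 2],
+        # interp1d([0.5, 1.0, 2.0], [0, 1, 2])(1.4) returns ~1.4, and round(1.4) -> 1,
+        # i.e. the bucket whose ratio (1.0) is nearest to 1.4 along the interpolation.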
+
+        for b in buckets:
+            self.bucket_data[b] = []
+
+    def get_batch_count(self):
+        return sum(len(b) // self.batch_size for b in self.bucket_data.values())
+
+    def get_batch_iterator(self) -> Generator[List[int], None, None]:
+        """
+        Generator that provides batches where the images in a batch fall on the same bucket
+
+        Each element generated will be a list of batch_size indices into the dataset:
+            [image1, image2, ..., image{batch_size}]
+        :return:
+        """
+        max_bucket_len = max(len(b) for b in self.bucket_data.values())
+        index_schedule = list(range(max_bucket_len))
+        random.shuffle(index_schedule)
+
+        bucket_len_table = {
+            b: len(self.bucket_data[b]) for b in self.buckets
+        }
+
+        bucket_schedule = []
+        for i, b in enumerate(self.buckets):
+            bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))
+
+        random.shuffle(bucket_schedule)
+
+        bucket_pos = {
+            b: 0 for b in self.buckets
+        }
+
+        total_generated_by_bucket = {
+            b: 0 for b in self.buckets
+        }
+
+        for bucket_index in bucket_schedule:
+            b = self.buckets[bucket_index]
+            i = bucket_pos[b]
+            bucket_len = bucket_len_table[b]
+
+            batch = []
+            while len(batch) != self.batch_size:
+                # advance in the schedule until we find an index that is contained in the bucket
+                k = index_schedule[i]
+                if k < bucket_len:
+                    entry = self.bucket_data[b][k]
+                    batch.append(entry)
+
+                i += 1
+
+            total_generated_by_bucket[b] += self.batch_size
+            bucket_pos[b] = i
+            yield [idx for idx in batch]
+
+    def fill_buckets(self):
+        entries = self.store.entries_iterator()
+        total_dropped = 0
+
+        for entry, index in tqdm.tqdm(entries, total=len(self.store)):
+            if not self._process_entry(entry, index):
+                total_dropped += 1
+
+        for b, values in self.bucket_data.items():
+            # shuffle the entries for extra randomness and to make sure dropped elements are also random
+            random.shuffle(values)
+
+            # make sure the buckets have an exact number of elements for the batch
+            to_drop = len(values) % self.batch_size
+            self.bucket_data[b] = list(values[:len(values) - to_drop])
+            total_dropped += to_drop
+
+        self.total_dropped = total_dropped
+
+    def _process_entry(self, entry: Image.Image, index: int) -> bool:
+        aspect = entry.width / entry.height
+
+        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
+            return False
+
+        best_bucket = self._bucket_interp(aspect)
+
+        if best_bucket is None:
+            return False
+
+        bucket = self.buckets[round(float(best_bucket))]
+
+        self.bucket_data[bucket].append(index)
+
+        return True
+
+class AspectBucketSampler(torch.utils.data.Sampler):
+    def __init__(self, bucket: AspectBucket):
+        super().__init__(None)
+        self.bucket = bucket
+
+    def __iter__(self):
+        yield from self.bucket.get_batch_iterator()
+
+    def __len__(self):
+        return self.bucket.get_batch_count()
+
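+# Quick sanity check of the bucketing (illustrative standalone snippet, not called
+# anywhere in this file; the dataset path is a placeholder):
+#
+#   store = ImageStore('/path/to/dataset')
+#   bucket = AspectBucket(store, num_buckets=16, batch_size=1)
+#   for (w, h), entries in bucket.bucket_data.items():
+#       print(f'{w}x{h}: {len(entries)} images')
+#   print('dropped:', bucket.total_dropped, 'batches per epoch:', bucket.get_batch_count())
+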
+class AspectDataset(torch.utils.data.Dataset):
+    def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
+        self.store = store
+        self.tokenizer = tokenizer
+        self.ucg = ucg
+
+        self.transforms = torchvision.transforms.Compose([
+            torchvision.transforms.RandomHorizontalFlip(p=0.5),
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize([0.5], [0.5]),
+        ])
+
+    def __len__(self):
+        return len(self.store)
+
+    def __getitem__(self, item: int):
+        return_dict = {'pixel_values': None, 'input_ids': None}
+
+        image_file = self.store.get_image(item)
+        return_dict['pixel_values'] = self.transforms(image_file)
+        # drop the caption with probability ucg so the model also learns the unconditional distribution
+        if random.random() > self.ucg:
+            caption_file = self.store.get_caption(item)
+        else:
+            caption_file = ''
+        return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
+
+        return return_dict
+
+    def collate_fn(self, examples):
+        pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+        input_ids = [example['input_ids'] for example in examples if example is not None]
+        padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
+        return {
+            'pixel_values': pixel_values,
+            'input_ids': padded_tokens.input_ids,
+            'attention_mask': padded_tokens.attention_mask,
+        }
+
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
+class EMAModel:
+    """
+    Exponential Moving Average of models weights
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    def get_decay(self, optimization_step):
+        """
+        Compute the warmed-up decay factor for the exponential moving average.
+        """
+        value = (1 + optimization_step) / (10 + optimization_step)
+        return min(self.decay, value)
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+        # recompute the warmed-up decay each step; self.decay keeps the configured
+        # maximum so the average anneals towards it instead of being overwritten
+        one_minus_decay = 1 - self.get_decay(self.optimization_step)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+            dtype: like `dtype` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+
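+# Minimal usage sketch for EMAModel (illustrative; this mirrors how main() below
+# wires it up when --use_ema is enabled):
+#
+#   ema_unet = EMAModel(unet.parameters())
+#   ...                                   # one optimizer step on the live unet
+#   ema_unet.step(unet.parameters())      # shadow weights track the live weights
+#   ema_unet.copy_to(unet.parameters())   # swap in the averaged weights before saving
+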
+def main():
+    # get device. TODO: support multi-gpu
+    device = 'cpu'
+    if torch.cuda.is_available():
+        device = 'cuda'
+
+    print("DEVICE:", device)
+
+    # setup fp16 stuff
+    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+    # Set seed
+    torch.manual_seed(args.seed)
+    print('RANDOM SEED:', args.seed)
+
+    tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
+    text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
+    vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
+    unet = UNet2DConditionModel.from_pretrained(args.model, subfolder='unet', use_auth_token=args.hf_token)
+
+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
+        try:
+            import bitsandbytes as bnb
+            optimizer_cls = bnb.optim.AdamW8bit
+        except ImportError:
+            print('bitsandbytes not supported, using regular Adam optimizer')
+            optimizer_cls = torch.optim.AdamW
+    else:
+        optimizer_cls = torch.optim.AdamW
+
+    optimizer = optimizer_cls(
+        unet.parameters(),
+        lr=args.lr,
+        betas=(args.adam_beta1, args.adam_beta2),
+        eps=args.adam_epsilon,
+        weight_decay=args.adam_weight_decay,
+    )
+
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule='scaled_linear',
+        num_train_timesteps=1000,
+        tensor_format='pt'
+    )
+
+    # load dataset
+    store = ImageStore(args.dataset)
+    dataset = AspectDataset(store, tokenizer, ucg=args.ucg)
+    bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
+    sampler = AspectBucketSampler(bucket)
+
+    print(f'STORE_LEN: {len(store)}')
+
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=sampler,
+        num_workers=4,
+        collate_fn=dataset.collate_fn
+    )
+
+    lr_scheduler = get_scheduler(
+        'constant',
+        optimizer=optimizer
+    )
+
+    weight_dtype = torch.float16 if args.fp16 else torch.float32
+
+    # move models to device
+    vae = vae.to(device, dtype=weight_dtype)
+    unet = unet.to(device, dtype=torch.float32)
+    text_encoder = text_encoder.to(device, dtype=weight_dtype)
+
+    # create ema
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())
+
+    print(get_gpu_ram())
+
+    num_steps_per_epoch = len(train_dataloader)
+    progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False)
+    global_step = 0
+
+    def save_checkpoint():
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+        pipeline = StableDiffusionPipeline(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=PNDMScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+            ),
+            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        )
+        pipeline.save_pretrained(args.output_path)
+
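+    # The directory written by save_checkpoint() is a regular diffusers pipeline, so it
+    # can be reloaded for inference like this (illustrative sketch, assuming the default
+    # --output_path of ./output):
+    #
+    #   pipe = StableDiffusionPipeline.from_pretrained('./output').to('cuda')
+    #   image = pipe('some prompt').images[0]
+    #   image.save('sample.png')
+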
+    # train!
+    for epoch in range(args.epochs):
+        unet.train()
+        train_loss = 0.0
+        for step, batch in enumerate(train_dataloader):
+            # Convert images to latent space
+            latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
+            latents = latents * 0.18215
+
+            # Sample noise
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+            # Sample a random timestep for each image
+            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+            # Get the text embedding for conditioning
+            encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
+
+            # Predict the noise residual and compute loss
+            with torch.autocast('cuda', enabled=args.fp16):
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+            # Backprop
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+            # Update EMA
+            if args.use_ema:
+                ema_unet.step(unet.parameters())
+
+            progress_bar.update(1)
+            global_step += 1
+            logs = {
+                "loss": loss.detach().item(),
+                "lr": lr_scheduler.get_last_lr()[0],
+                "epoch": epoch
+            }
+            progress_bar.set_postfix(logs)
+            run.log(logs)
+
+            if global_step % args.save_steps == 0:
+                save_checkpoint()
+
+            if global_step % args.image_log_steps == 0:
+                # get prompt from random batch
+                prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+                pipeline = StableDiffusionPipeline(
+                    text_encoder=text_encoder,
+                    vae=vae,
+                    unet=unet,
+                    tokenizer=tokenizer,
+                    scheduler=PNDMScheduler(
+                        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                    ),
+                    safety_checker=None, # disable the safety checker to save memory
+                    feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+                ).to(device)
+                # inference
+                images = []
+                with torch.no_grad():
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        for _ in range(args.image_log_amount):
+                            images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
+                # log images under single caption
+                run.log({'images': images})
+
+                # cleanup so we don't run out of memory
+                del pipeline
+                gc.collect()
+
+    save_checkpoint()
+
+    print(get_gpu_ram())
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()
+
+"""
+import numpy as np
+# save a sample
+img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
+img = ((img + 1.0) * 127.5).astype(np.uint8)
+img = Image.fromarray(img)
+img.save('sample.png')
+break
+"""
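+
+# To continue finetuning from a previous run, --model can also point at the directory
+# written by save_checkpoint(), since it uses the standard diffusers layout (illustrative):
+#
+#   python diffusers_trainer.py --model ./output --run_name my-finetune-2 --dataset /path/to/dataset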