diff --git a/diffusers_trainer.py b/diffusers_trainer.py index 444941a..8631ccf 100644 --- a/diffusers_trainer.py +++ b/diffusers_trainer.py @@ -1,5 +1,9 @@ +# Install bitsandbytes: +# `nvcc --version` to get CUDA version. +# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA. # Example Usage: -# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True import argparse import socket @@ -146,16 +150,16 @@ class ImageStore: def __len__(self) -> int: return len(self.image_files) - + # iterator returns images as PIL images and their index in the store def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]: for f in range(len(self)): yield Image.open(self.image_files[f]), f - + # get image by index def get_image(self, index: int) -> Image.Image: return Image.open(self.image_files[index]) - + # gets caption by removing the extension from the filename and replacing it with .txt def get_caption(self, index: int) -> str: filename = self.image_files[index].split('.')[0] + '.txt' @@ -177,7 +181,7 @@ class AspectBucket: bucket_side_increment: int = 64, max_image_area: int = 512 * 768, max_ratio: float = 2): - + self.requested_bucket_count = num_buckets self.bucket_length_min = bucket_side_min self.bucket_length_max = bucket_side_max @@ -190,7 +194,7 @@ class AspectBucket: self.max_ratio = float('inf') else: self.max_ratio = max_ratio - + self.store = store self.buckets = [] self._bucket_ratios = [] @@ -198,12 +202,12 @@ class AspectBucket: self.bucket_data: Dict[tuple, List[int]] = dict() self.init_buckets() self.fill_buckets() - + def init_buckets(self): possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment)) possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths) if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio) - + buckets_by_ratio = {} # group the buckets by their aspect ratios @@ -250,10 +254,10 @@ class AspectBucket: for b in buckets: self.bucket_data[b] = [] - + def get_batch_count(self): return sum(len(b) // self.batch_size for b in self.bucket_data.values()) - + def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]: """ Generator that provides batches where the images in a batch fall on the same bucket @@ -304,7 +308,7 @@ class AspectBucket: total_generated_by_bucket[b] += self.batch_size bucket_pos[b] = i yield [idx for idx in batch] - + def fill_buckets(self): entries = self.store.entries_iterator() total_dropped = 0 @@ -323,18 +327,18 @@ class AspectBucket: total_dropped += to_drop self.total_dropped = total_dropped - + def _process_entry(self, entry: Image.Image, index: int) -> bool: aspect = entry.width / entry.height - + if aspect > self.max_ratio or (1 / aspect) > self.max_ratio: return False - + best_bucket = self._bucket_interp(aspect) if best_bucket is None: return False - + bucket = self.buckets[round(float(best_bucket))] self.bucket_data[bucket].append(index) @@ -349,13 +353,13 @@ class AspectBucketSampler(torch.utils.data.Sampler): self.bucket = bucket self.num_replicas = num_replicas self.rank = rank - + def __iter__(self): # subsample the bucket to only include the elements that are assigned to this rank indices = self.bucket.get_batch_iterator() indices = list(indices)[self.rank::self.num_replicas] return iter(indices) - + def __len__(self): return self.bucket.get_batch_count() // self.num_replicas @@ -370,7 +374,7 @@ class AspectDataset(torch.utils.data.Dataset): torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5]), ]) - + def __len__(self): return len(self.store) @@ -474,7 +478,7 @@ def main(): run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb') device = torch.device('cuda') - + print("DEVICE:", device) # setup fp16 stuff @@ -495,7 +499,7 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - + if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. try: import bitsandbytes as bnb @@ -505,7 +509,7 @@ def main(): optimizer_cls = torch.optim.AdamW else: optimizer_cls = torch.optim.AdamW - + optimizer = optimizer_cls( unet.parameters(), lr=args.lr, @@ -555,7 +559,7 @@ def main(): # create ema if args.use_ema: ema_unet = EMAModel(unet.parameters()) - + print(get_gpu_ram()) num_steps_per_epoch = len(train_dataloader) @@ -612,7 +616,7 @@ def main(): # Predict the noise residual and compute loss with torch.autocast('cuda', enabled=args.fp16): noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - + loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean") # Backprop and all reduce @@ -625,7 +629,7 @@ def main(): # Update EMA if args.use_ema: ema_unet.step(unet.parameters()) - + # perf b_end = time.perf_counter() seconds_per_step = b_end - b_start @@ -636,7 +640,7 @@ def main(): # All reduce loss torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM) - + if rank == 0: progress_bar.update(1) global_step += 1 @@ -653,7 +657,7 @@ def main(): if global_step % args.save_steps == 0: save_checkpoint() - + if global_step % args.image_log_steps == 0: if rank == 0: # get prompt from random batch @@ -682,7 +686,7 @@ def main(): del pipeline gc.collect() torch.distributed.barrier() - + if rank == 0: save_checkpoint()