diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index 747265f..8d20303 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -1,5 +1,9 @@
+# Install bitsandbytes:
+# `nvcc --version` to get CUDA version.
+# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
 # Example Usage:
-# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
+# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
+# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
 
 import argparse
 import socket
@@ -168,7 +172,7 @@ class ImageStore:
     # get image by index
     def get_image(self, index: int) -> Image.Image:
         return Image.open(self.image_files[index])
-    
+
     # gets caption by removing the extension from the filename and replacing it with .txt
     def get_caption(self, index: int) -> str:
         filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
@@ -216,7 +220,7 @@ class AspectBucket:
         possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
         possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                 if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
-        
+
         buckets_by_ratio = {}
 
         # group the buckets by their aspect ratios
@@ -263,7 +267,7 @@ class AspectBucket:
 
         for b in buckets:
             self.bucket_data[b] = []
-        
+
     def get_batch_count(self):
         return sum(len(b) // self.batch_size for b in self.bucket_data.values())
 
@@ -320,7 +324,7 @@ class AspectBucket:
                 total_generated_by_bucket[b] += self.batch_size
                 bucket_pos[b] = i
                 yield [idx for idx in batch]
-    
+
     def fill_buckets(self):
         entries = self.store.entries_iterator()
         total_dropped = 0
@@ -339,18 +343,18 @@ class AspectBucket:
             total_dropped += to_drop
 
         self.total_dropped = total_dropped
-    
+
     def _process_entry(self, entry: Image.Image, index: int) -> bool:
         aspect = entry.width / entry.height
-        
+
         if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
             return False
-        
+
         best_bucket = self._bucket_interp(aspect)
         if best_bucket is None:
             return False
-        
+
         bucket = self.buckets[round(float(best_bucket))]
         self.bucket_data[bucket].append(index)
 
@@ -365,13 +369,13 @@ class AspectBucketSampler(torch.utils.data.Sampler):
         self.bucket = bucket
         self.num_replicas = num_replicas
         self.rank = rank
-    
+
     def __iter__(self):
         # subsample the bucket to only include the elements that are assigned to this rank
         indices = self.bucket.get_batch_iterator()
         indices = list(indices)[self.rank::self.num_replicas]
         return iter(indices)
-    
+
     def __len__(self):
         return self.bucket.get_batch_count() // self.num_replicas
 
@@ -386,7 +390,7 @@ class AspectDataset(torch.utils.data.Dataset):
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize([0.5], [0.5]),
         ])
-    
+
     def __len__(self):
         return len(self.store)
 
@@ -490,7 +494,7 @@ def main():
         run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
 
     device = torch.device('cuda')
-    
+
     print("DEVICE:", device)
 
     # setup fp16 stuff
@@ -511,7 +515,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    
+
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
@@ -521,7 +525,7 @@ def main():
             optimizer_cls = torch.optim.AdamW
     else:
         optimizer_cls = torch.optim.AdamW
-    
+
     optimizer = optimizer_cls(
         unet.parameters(),
         lr=args.lr,
@@ -535,7 +539,6 @@ def main():
         beta_end=0.012,
         beta_schedule='scaled_linear',
         num_train_timesteps=1000,
-        tensor_format='pt'
     )
 
     # load dataset
@@ -575,7 +578,7 @@ def main():
     # create ema
     if args.use_ema:
         ema_unet = EMAModel(unet.parameters())
-    
+
     print(get_gpu_ram())
 
     num_steps_per_epoch = len(train_dataloader)
@@ -632,7 +635,7 @@ def main():
             # Predict the noise residual and compute loss
             with torch.autocast('cuda', enabled=args.fp16):
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-            
+
             loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
             # Backprop and all reduce
@@ -645,7 +648,7 @@ def main():
             # Update EMA
             if args.use_ema:
                 ema_unet.step(unet.parameters())
-            
+
             # perf
             b_end = time.perf_counter()
             seconds_per_step = b_end - b_start
@@ -656,7 +659,7 @@ def main():
 
             # All reduce loss
             torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-            
+
             if rank == 0:
                 progress_bar.update(1)
                 global_step += 1
@@ -673,7 +676,7 @@ def main():
 
             if global_step % args.save_steps == 0:
                 save_checkpoint()
-            
+
             if global_step % args.image_log_steps == 0:
                 if rank == 0:
                     # get prompt from random batch
@@ -702,7 +705,7 @@ def main():
                     del pipeline
                     gc.collect()
                 torch.distributed.barrier()
-    
+
     if rank == 0:
         save_checkpoint()
 
diff --git a/requirements.txt b/requirements.txt
index 82c04b2..3f9c324 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,8 @@ streamlit>=0.73.1
 einops==0.3.0
 torch-fidelity==0.3.0
 transformers==4.19.2
-torchmetrics==0.6.0
+diffusers==0.7.1
+torchmetrics==0.7.0
 kornia==0.6
 gradio
 git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
@@ -19,4 +20,5 @@ git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
 webdataset
 wandb
-fairscale
\ No newline at end of file
+fairscale
+pynvml==11.4.1
\ No newline at end of file
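
Note on the patch: the `tensor_format='pt'` argument is dropped from the DDIMScheduler call because it is no longer accepted by the diffusers release pinned in requirements.txt (`diffusers==0.7.1`). For readers setting up `--use_8bit_adam`, the sketch below shows the optimizer-fallback pattern visible in the context lines as a minimal standalone example; the `model` placeholder and learning rate are illustrative, not taken from this patch.

    import torch

    # Prefer bitsandbytes' 8-bit AdamW when it imports and initializes cleanly;
    # it stores optimizer state in 8 bits, which substantially cuts optimizer memory.
    try:
        import bitsandbytes as bnb
        optimizer_cls = bnb.optim.AdamW8bit
    except Exception:
        # bitsandbytes is only supported on certain CUDA setups (and may fail
        # at import time, not just with ImportError), so fall back to the
        # regular fp32 AdamW, as the trainer does.
        optimizer_cls = torch.optim.AdamW

    model = torch.nn.Linear(4, 4)  # stands in for the UNet in this sketch
    optimizer = optimizer_cls(model.parameters(), lr=1e-5)  # placeholder lr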