Add bitsandbytes instructions

This commit is contained in:
Astralite Heart 2022-11-05 10:38:35 +00:00
parent bb153c1ec0
commit 74ef98667c
1 changed files with 31 additions and 27 deletions

View File

@ -1,5 +1,9 @@
# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
@ -146,16 +150,16 @@ class ImageStore:
def __len__(self) -> int:
return len(self.image_files)
# iterator returns images as PIL images and their index in the store
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
for f in range(len(self)):
yield Image.open(self.image_files[f]), f
# get image by index
def get_image(self, index: int) -> Image.Image:
return Image.open(self.image_files[index])
# gets caption by removing the extension from the filename and replacing it with .txt
def get_caption(self, index: int) -> str:
filename = self.image_files[index].split('.')[0] + '.txt'
@ -177,7 +181,7 @@ class AspectBucket:
bucket_side_increment: int = 64,
max_image_area: int = 512 * 768,
max_ratio: float = 2):
self.requested_bucket_count = num_buckets
self.bucket_length_min = bucket_side_min
self.bucket_length_max = bucket_side_max
@ -190,7 +194,7 @@ class AspectBucket:
self.max_ratio = float('inf')
else:
self.max_ratio = max_ratio
self.store = store
self.buckets = []
self._bucket_ratios = []
@ -198,12 +202,12 @@ class AspectBucket:
self.bucket_data: Dict[tuple, List[int]] = dict()
self.init_buckets()
self.fill_buckets()
def init_buckets(self):
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
buckets_by_ratio = {}
# group the buckets by their aspect ratios
@ -250,10 +254,10 @@ class AspectBucket:
for b in buckets:
self.bucket_data[b] = []
def get_batch_count(self):
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
"""
Generator that provides batches where the images in a batch fall on the same bucket
@ -304,7 +308,7 @@ class AspectBucket:
total_generated_by_bucket[b] += self.batch_size
bucket_pos[b] = i
yield [idx for idx in batch]
def fill_buckets(self):
entries = self.store.entries_iterator()
total_dropped = 0
@ -323,18 +327,18 @@ class AspectBucket:
total_dropped += to_drop
self.total_dropped = total_dropped
def _process_entry(self, entry: Image.Image, index: int) -> bool:
aspect = entry.width / entry.height
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
return False
best_bucket = self._bucket_interp(aspect)
if best_bucket is None:
return False
bucket = self.buckets[round(float(best_bucket))]
self.bucket_data[bucket].append(index)
@ -349,13 +353,13 @@ class AspectBucketSampler(torch.utils.data.Sampler):
self.bucket = bucket
self.num_replicas = num_replicas
self.rank = rank
def __iter__(self):
# subsample the bucket to only include the elements that are assigned to this rank
indices = self.bucket.get_batch_iterator()
indices = list(indices)[self.rank::self.num_replicas]
return iter(indices)
def __len__(self):
return self.bucket.get_batch_count() // self.num_replicas
@ -370,7 +374,7 @@ class AspectDataset(torch.utils.data.Dataset):
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5], [0.5]),
])
def __len__(self):
return len(self.store)
@ -474,7 +478,7 @@ def main():
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
device = torch.device('cuda')
print("DEVICE:", device)
# setup fp16 stuff
@ -495,7 +499,7 @@ def main():
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
try:
import bitsandbytes as bnb
@ -505,7 +509,7 @@ def main():
optimizer_cls = torch.optim.AdamW
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
unet.parameters(),
lr=args.lr,
@ -555,7 +559,7 @@ def main():
# create ema
if args.use_ema:
ema_unet = EMAModel(unet.parameters())
print(get_gpu_ram())
num_steps_per_epoch = len(train_dataloader)
@ -612,7 +616,7 @@ def main():
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Backprop and all reduce
@ -625,7 +629,7 @@ def main():
# Update EMA
if args.use_ema:
ema_unet.step(unet.parameters())
# perf
b_end = time.perf_counter()
seconds_per_step = b_end - b_start
@ -636,7 +640,7 @@ def main():
# All reduce loss
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
if rank == 0:
progress_bar.update(1)
global_step += 1
@ -653,7 +657,7 @@ def main():
if global_step % args.save_steps == 0:
save_checkpoint()
if global_step % args.image_log_steps == 0:
if rank == 0:
# get prompt from random batch
@ -682,7 +686,7 @@ def main():
del pipeline
gc.collect()
torch.distributed.barrier()
if rank == 0:
save_checkpoint()