Add bitsandbytes instructions
This commit is contained in:
parent
bb153c1ec0
commit
74ef98667c
|
@ -1,5 +1,9 @@
|
|||
# Install bitsandbytes:
|
||||
# `nvcc --version` to get CUDA version.
|
||||
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
|
||||
# Example Usage:
|
||||
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||
|
||||
import argparse
|
||||
import socket
|
||||
|
@ -146,16 +150,16 @@ class ImageStore:
|
|||
|
||||
def __len__(self) -> int:
|
||||
return len(self.image_files)
|
||||
|
||||
|
||||
# iterator returns images as PIL images and their index in the store
|
||||
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
||||
for f in range(len(self)):
|
||||
yield Image.open(self.image_files[f]), f
|
||||
|
||||
|
||||
# get image by index
|
||||
def get_image(self, index: int) -> Image.Image:
|
||||
return Image.open(self.image_files[index])
|
||||
|
||||
|
||||
# gets caption by removing the extension from the filename and replacing it with .txt
|
||||
def get_caption(self, index: int) -> str:
|
||||
filename = self.image_files[index].split('.')[0] + '.txt'
|
||||
|
@ -177,7 +181,7 @@ class AspectBucket:
|
|||
bucket_side_increment: int = 64,
|
||||
max_image_area: int = 512 * 768,
|
||||
max_ratio: float = 2):
|
||||
|
||||
|
||||
self.requested_bucket_count = num_buckets
|
||||
self.bucket_length_min = bucket_side_min
|
||||
self.bucket_length_max = bucket_side_max
|
||||
|
@ -190,7 +194,7 @@ class AspectBucket:
|
|||
self.max_ratio = float('inf')
|
||||
else:
|
||||
self.max_ratio = max_ratio
|
||||
|
||||
|
||||
self.store = store
|
||||
self.buckets = []
|
||||
self._bucket_ratios = []
|
||||
|
@ -198,12 +202,12 @@ class AspectBucket:
|
|||
self.bucket_data: Dict[tuple, List[int]] = dict()
|
||||
self.init_buckets()
|
||||
self.fill_buckets()
|
||||
|
||||
|
||||
def init_buckets(self):
|
||||
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
|
||||
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
|
||||
if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
|
||||
|
||||
|
||||
buckets_by_ratio = {}
|
||||
|
||||
# group the buckets by their aspect ratios
|
||||
|
@ -250,10 +254,10 @@ class AspectBucket:
|
|||
|
||||
for b in buckets:
|
||||
self.bucket_data[b] = []
|
||||
|
||||
|
||||
def get_batch_count(self):
|
||||
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
|
||||
|
||||
|
||||
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
||||
"""
|
||||
Generator that provides batches where the images in a batch fall on the same bucket
|
||||
|
@ -304,7 +308,7 @@ class AspectBucket:
|
|||
total_generated_by_bucket[b] += self.batch_size
|
||||
bucket_pos[b] = i
|
||||
yield [idx for idx in batch]
|
||||
|
||||
|
||||
def fill_buckets(self):
|
||||
entries = self.store.entries_iterator()
|
||||
total_dropped = 0
|
||||
|
@ -323,18 +327,18 @@ class AspectBucket:
|
|||
total_dropped += to_drop
|
||||
|
||||
self.total_dropped = total_dropped
|
||||
|
||||
|
||||
def _process_entry(self, entry: Image.Image, index: int) -> bool:
|
||||
aspect = entry.width / entry.height
|
||||
|
||||
|
||||
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
|
||||
return False
|
||||
|
||||
|
||||
best_bucket = self._bucket_interp(aspect)
|
||||
|
||||
if best_bucket is None:
|
||||
return False
|
||||
|
||||
|
||||
bucket = self.buckets[round(float(best_bucket))]
|
||||
|
||||
self.bucket_data[bucket].append(index)
|
||||
|
@ -349,13 +353,13 @@ class AspectBucketSampler(torch.utils.data.Sampler):
|
|||
self.bucket = bucket
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
# subsample the bucket to only include the elements that are assigned to this rank
|
||||
indices = self.bucket.get_batch_iterator()
|
||||
indices = list(indices)[self.rank::self.num_replicas]
|
||||
return iter(indices)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.bucket.get_batch_count() // self.num_replicas
|
||||
|
||||
|
@ -370,7 +374,7 @@ class AspectDataset(torch.utils.data.Dataset):
|
|||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.5], [0.5]),
|
||||
])
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.store)
|
||||
|
||||
|
@ -474,7 +478,7 @@ def main():
|
|||
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
|
||||
print("DEVICE:", device)
|
||||
|
||||
# setup fp16 stuff
|
||||
|
@ -495,7 +499,7 @@ def main():
|
|||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
|
||||
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
@ -505,7 +509,7 @@ def main():
|
|||
optimizer_cls = torch.optim.AdamW
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
|
||||
optimizer = optimizer_cls(
|
||||
unet.parameters(),
|
||||
lr=args.lr,
|
||||
|
@ -555,7 +559,7 @@ def main():
|
|||
# create ema
|
||||
if args.use_ema:
|
||||
ema_unet = EMAModel(unet.parameters())
|
||||
|
||||
|
||||
print(get_gpu_ram())
|
||||
|
||||
num_steps_per_epoch = len(train_dataloader)
|
||||
|
@ -612,7 +616,7 @@ def main():
|
|||
# Predict the noise residual and compute loss
|
||||
with torch.autocast('cuda', enabled=args.fp16):
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Backprop and all reduce
|
||||
|
@ -625,7 +629,7 @@ def main():
|
|||
# Update EMA
|
||||
if args.use_ema:
|
||||
ema_unet.step(unet.parameters())
|
||||
|
||||
|
||||
# perf
|
||||
b_end = time.perf_counter()
|
||||
seconds_per_step = b_end - b_start
|
||||
|
@ -636,7 +640,7 @@ def main():
|
|||
|
||||
# All reduce loss
|
||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
@ -653,7 +657,7 @@ def main():
|
|||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_checkpoint()
|
||||
|
||||
|
||||
if global_step % args.image_log_steps == 0:
|
||||
if rank == 0:
|
||||
# get prompt from random batch
|
||||
|
@ -682,7 +686,7 @@ def main():
|
|||
del pipeline
|
||||
gc.collect()
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
if rank == 0:
|
||||
save_checkpoint()
|
||||
|
||||
|
|
Loading…
Reference in New Issue