Add bitsandbytes instructions
This commit is contained in:
parent
bb153c1ec0
commit
74ef98667c
|
@ -1,5 +1,9 @@
|
||||||
|
# Install bitsandbytes:
|
||||||
|
# `nvcc --version` to get CUDA version.
|
||||||
|
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
|
||||||
# Example Usage:
|
# Example Usage:
|
||||||
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
|
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import socket
|
import socket
|
||||||
|
@ -146,16 +150,16 @@ class ImageStore:
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.image_files)
|
return len(self.image_files)
|
||||||
|
|
||||||
# iterator returns images as PIL images and their index in the store
|
# iterator returns images as PIL images and their index in the store
|
||||||
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
||||||
for f in range(len(self)):
|
for f in range(len(self)):
|
||||||
yield Image.open(self.image_files[f]), f
|
yield Image.open(self.image_files[f]), f
|
||||||
|
|
||||||
# get image by index
|
# get image by index
|
||||||
def get_image(self, index: int) -> Image.Image:
|
def get_image(self, index: int) -> Image.Image:
|
||||||
return Image.open(self.image_files[index])
|
return Image.open(self.image_files[index])
|
||||||
|
|
||||||
# gets caption by removing the extension from the filename and replacing it with .txt
|
# gets caption by removing the extension from the filename and replacing it with .txt
|
||||||
def get_caption(self, index: int) -> str:
|
def get_caption(self, index: int) -> str:
|
||||||
filename = self.image_files[index].split('.')[0] + '.txt'
|
filename = self.image_files[index].split('.')[0] + '.txt'
|
||||||
|
@ -177,7 +181,7 @@ class AspectBucket:
|
||||||
bucket_side_increment: int = 64,
|
bucket_side_increment: int = 64,
|
||||||
max_image_area: int = 512 * 768,
|
max_image_area: int = 512 * 768,
|
||||||
max_ratio: float = 2):
|
max_ratio: float = 2):
|
||||||
|
|
||||||
self.requested_bucket_count = num_buckets
|
self.requested_bucket_count = num_buckets
|
||||||
self.bucket_length_min = bucket_side_min
|
self.bucket_length_min = bucket_side_min
|
||||||
self.bucket_length_max = bucket_side_max
|
self.bucket_length_max = bucket_side_max
|
||||||
|
@ -190,7 +194,7 @@ class AspectBucket:
|
||||||
self.max_ratio = float('inf')
|
self.max_ratio = float('inf')
|
||||||
else:
|
else:
|
||||||
self.max_ratio = max_ratio
|
self.max_ratio = max_ratio
|
||||||
|
|
||||||
self.store = store
|
self.store = store
|
||||||
self.buckets = []
|
self.buckets = []
|
||||||
self._bucket_ratios = []
|
self._bucket_ratios = []
|
||||||
|
@ -198,12 +202,12 @@ class AspectBucket:
|
||||||
self.bucket_data: Dict[tuple, List[int]] = dict()
|
self.bucket_data: Dict[tuple, List[int]] = dict()
|
||||||
self.init_buckets()
|
self.init_buckets()
|
||||||
self.fill_buckets()
|
self.fill_buckets()
|
||||||
|
|
||||||
def init_buckets(self):
|
def init_buckets(self):
|
||||||
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
|
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
|
||||||
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
|
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
|
||||||
if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
|
if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
|
||||||
|
|
||||||
buckets_by_ratio = {}
|
buckets_by_ratio = {}
|
||||||
|
|
||||||
# group the buckets by their aspect ratios
|
# group the buckets by their aspect ratios
|
||||||
|
@ -250,10 +254,10 @@ class AspectBucket:
|
||||||
|
|
||||||
for b in buckets:
|
for b in buckets:
|
||||||
self.bucket_data[b] = []
|
self.bucket_data[b] = []
|
||||||
|
|
||||||
def get_batch_count(self):
|
def get_batch_count(self):
|
||||||
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
|
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
|
||||||
|
|
||||||
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
||||||
"""
|
"""
|
||||||
Generator that provides batches where the images in a batch fall on the same bucket
|
Generator that provides batches where the images in a batch fall on the same bucket
|
||||||
|
@ -304,7 +308,7 @@ class AspectBucket:
|
||||||
total_generated_by_bucket[b] += self.batch_size
|
total_generated_by_bucket[b] += self.batch_size
|
||||||
bucket_pos[b] = i
|
bucket_pos[b] = i
|
||||||
yield [idx for idx in batch]
|
yield [idx for idx in batch]
|
||||||
|
|
||||||
def fill_buckets(self):
|
def fill_buckets(self):
|
||||||
entries = self.store.entries_iterator()
|
entries = self.store.entries_iterator()
|
||||||
total_dropped = 0
|
total_dropped = 0
|
||||||
|
@ -323,18 +327,18 @@ class AspectBucket:
|
||||||
total_dropped += to_drop
|
total_dropped += to_drop
|
||||||
|
|
||||||
self.total_dropped = total_dropped
|
self.total_dropped = total_dropped
|
||||||
|
|
||||||
def _process_entry(self, entry: Image.Image, index: int) -> bool:
|
def _process_entry(self, entry: Image.Image, index: int) -> bool:
|
||||||
aspect = entry.width / entry.height
|
aspect = entry.width / entry.height
|
||||||
|
|
||||||
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
|
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
best_bucket = self._bucket_interp(aspect)
|
best_bucket = self._bucket_interp(aspect)
|
||||||
|
|
||||||
if best_bucket is None:
|
if best_bucket is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
bucket = self.buckets[round(float(best_bucket))]
|
bucket = self.buckets[round(float(best_bucket))]
|
||||||
|
|
||||||
self.bucket_data[bucket].append(index)
|
self.bucket_data[bucket].append(index)
|
||||||
|
@ -349,13 +353,13 @@ class AspectBucketSampler(torch.utils.data.Sampler):
|
||||||
self.bucket = bucket
|
self.bucket = bucket
|
||||||
self.num_replicas = num_replicas
|
self.num_replicas = num_replicas
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# subsample the bucket to only include the elements that are assigned to this rank
|
# subsample the bucket to only include the elements that are assigned to this rank
|
||||||
indices = self.bucket.get_batch_iterator()
|
indices = self.bucket.get_batch_iterator()
|
||||||
indices = list(indices)[self.rank::self.num_replicas]
|
indices = list(indices)[self.rank::self.num_replicas]
|
||||||
return iter(indices)
|
return iter(indices)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.bucket.get_batch_count() // self.num_replicas
|
return self.bucket.get_batch_count() // self.num_replicas
|
||||||
|
|
||||||
|
@ -370,7 +374,7 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
torchvision.transforms.ToTensor(),
|
torchvision.transforms.ToTensor(),
|
||||||
torchvision.transforms.Normalize([0.5], [0.5]),
|
torchvision.transforms.Normalize([0.5], [0.5]),
|
||||||
])
|
])
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.store)
|
return len(self.store)
|
||||||
|
|
||||||
|
@ -474,7 +478,7 @@ def main():
|
||||||
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
||||||
|
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda')
|
||||||
|
|
||||||
print("DEVICE:", device)
|
print("DEVICE:", device)
|
||||||
|
|
||||||
# setup fp16 stuff
|
# setup fp16 stuff
|
||||||
|
@ -495,7 +499,7 @@ def main():
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
||||||
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||||
try:
|
try:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
@ -505,7 +509,7 @@ def main():
|
||||||
optimizer_cls = torch.optim.AdamW
|
optimizer_cls = torch.optim.AdamW
|
||||||
else:
|
else:
|
||||||
optimizer_cls = torch.optim.AdamW
|
optimizer_cls = torch.optim.AdamW
|
||||||
|
|
||||||
optimizer = optimizer_cls(
|
optimizer = optimizer_cls(
|
||||||
unet.parameters(),
|
unet.parameters(),
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
|
@ -555,7 +559,7 @@ def main():
|
||||||
# create ema
|
# create ema
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
ema_unet = EMAModel(unet.parameters())
|
ema_unet = EMAModel(unet.parameters())
|
||||||
|
|
||||||
print(get_gpu_ram())
|
print(get_gpu_ram())
|
||||||
|
|
||||||
num_steps_per_epoch = len(train_dataloader)
|
num_steps_per_epoch = len(train_dataloader)
|
||||||
|
@ -612,7 +616,7 @@ def main():
|
||||||
# Predict the noise residual and compute loss
|
# Predict the noise residual and compute loss
|
||||||
with torch.autocast('cuda', enabled=args.fp16):
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||||
|
|
||||||
# Backprop and all reduce
|
# Backprop and all reduce
|
||||||
|
@ -625,7 +629,7 @@ def main():
|
||||||
# Update EMA
|
# Update EMA
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
ema_unet.step(unet.parameters())
|
ema_unet.step(unet.parameters())
|
||||||
|
|
||||||
# perf
|
# perf
|
||||||
b_end = time.perf_counter()
|
b_end = time.perf_counter()
|
||||||
seconds_per_step = b_end - b_start
|
seconds_per_step = b_end - b_start
|
||||||
|
@ -636,7 +640,7 @@ def main():
|
||||||
|
|
||||||
# All reduce loss
|
# All reduce loss
|
||||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
@ -653,7 +657,7 @@ def main():
|
||||||
|
|
||||||
if global_step % args.save_steps == 0:
|
if global_step % args.save_steps == 0:
|
||||||
save_checkpoint()
|
save_checkpoint()
|
||||||
|
|
||||||
if global_step % args.image_log_steps == 0:
|
if global_step % args.image_log_steps == 0:
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
# get prompt from random batch
|
# get prompt from random batch
|
||||||
|
@ -682,7 +686,7 @@ def main():
|
||||||
del pipeline
|
del pipeline
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
save_checkpoint()
|
save_checkpoint()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue