Merge pull request #28 from AstraliteHeart/main

Misc cleanups from pony-diffusion test run
This commit is contained in:
Anthony Mercurio 2022-11-05 08:55:48 -07:00 committed by GitHub
commit 2321e22fc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 24 deletions

View File

@ -1,5 +1,9 @@
# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
@ -168,7 +172,7 @@ class ImageStore:
# get image by index
def get_image(self, index: int) -> Image.Image:
return Image.open(self.image_files[index])
# gets caption by removing the extension from the filename and replacing it with .txt
def get_caption(self, index: int) -> str:
filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
@ -216,7 +220,7 @@ class AspectBucket:
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
buckets_by_ratio = {}
# group the buckets by their aspect ratios
@ -263,7 +267,7 @@ class AspectBucket:
for b in buckets:
self.bucket_data[b] = []
def get_batch_count(self):
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
@ -320,7 +324,7 @@ class AspectBucket:
total_generated_by_bucket[b] += self.batch_size
bucket_pos[b] = i
yield [idx for idx in batch]
def fill_buckets(self):
entries = self.store.entries_iterator()
total_dropped = 0
@ -339,18 +343,18 @@ class AspectBucket:
total_dropped += to_drop
self.total_dropped = total_dropped
def _process_entry(self, entry: Image.Image, index: int) -> bool:
aspect = entry.width / entry.height
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
return False
best_bucket = self._bucket_interp(aspect)
if best_bucket is None:
return False
bucket = self.buckets[round(float(best_bucket))]
self.bucket_data[bucket].append(index)
@ -365,13 +369,13 @@ class AspectBucketSampler(torch.utils.data.Sampler):
self.bucket = bucket
self.num_replicas = num_replicas
self.rank = rank
def __iter__(self):
# subsample the bucket to only include the elements that are assigned to this rank
indices = self.bucket.get_batch_iterator()
indices = list(indices)[self.rank::self.num_replicas]
return iter(indices)
def __len__(self):
return self.bucket.get_batch_count() // self.num_replicas
@ -386,7 +390,7 @@ class AspectDataset(torch.utils.data.Dataset):
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5], [0.5]),
])
def __len__(self):
return len(self.store)
@ -490,7 +494,7 @@ def main():
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
device = torch.device('cuda')
print("DEVICE:", device)
# setup fp16 stuff
@ -511,7 +515,7 @@ def main():
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
try:
import bitsandbytes as bnb
@ -521,7 +525,7 @@ def main():
optimizer_cls = torch.optim.AdamW
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
unet.parameters(),
lr=args.lr,
@ -535,7 +539,6 @@ def main():
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000,
tensor_format='pt'
)
# load dataset
@ -575,7 +578,7 @@ def main():
# create ema
if args.use_ema:
ema_unet = EMAModel(unet.parameters())
print(get_gpu_ram())
num_steps_per_epoch = len(train_dataloader)
@ -632,7 +635,7 @@ def main():
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Backprop and all reduce
@ -645,7 +648,7 @@ def main():
# Update EMA
if args.use_ema:
ema_unet.step(unet.parameters())
# perf
b_end = time.perf_counter()
seconds_per_step = b_end - b_start
@ -656,7 +659,7 @@ def main():
# All reduce loss
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
if rank == 0:
progress_bar.update(1)
global_step += 1
@ -673,7 +676,7 @@ def main():
if global_step % args.save_steps == 0:
save_checkpoint()
if global_step % args.image_log_steps == 0:
if rank == 0:
# get prompt from random batch
@ -702,7 +705,7 @@ def main():
del pipeline
gc.collect()
torch.distributed.barrier()
if rank == 0:
save_checkpoint()

View File

@ -11,7 +11,8 @@ streamlit>=0.73.1
einops==0.3.0
torch-fidelity==0.3.0
transformers==4.19.2
torchmetrics==0.6.0
diffusers==0.7.1
torchmetrics==0.7.0
kornia==0.6
gradio
git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
@ -19,4 +20,5 @@ git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
webdataset
wandb
fairscale
fairscale
pynvml==11.4.1