Merge pull request #28 from AstraliteHeart/main
Misc cleanups from pony-diffusion test run
commit 2321e22fc1
@@ -1,5 +1,9 @@
# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True

import argparse
import socket
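The torchrun launcher exports RANK, LOCAL_RANK, and WORLD_SIZE into each worker's environment; a distributed trainer like this one typically bootstraps from them roughly as follows (a minimal sketch, not the script's exact code):

import os
import torch

# torchrun sets these for every worker process it spawns
rank = int(os.environ.get('RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))

# pin this worker to its GPU and join the process group
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend='nccl', rank=rank, world_size=world_size)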
@@ -168,7 +172,7 @@ class ImageStore:
    # get image by index
    def get_image(self, index: int) -> Image.Image:
        return Image.open(self.image_files[index])

    # gets the caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, index: int) -> str:
        filename = re.sub(r'\.[^/.]+$', '', self.image_files[index]) + '.txt'
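The regex strips the image extension, so `images/0001.jpg` maps to `images/0001.txt`; the same lookup with pathlib, for illustration:

from pathlib import Path

# 'images/0001.jpg' -> 'images/0001.txt'
caption_file = Path('images/0001.jpg').with_suffix('.txt')
print(caption_file)  # images/0001.txt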
@@ -216,7 +220,7 @@ class AspectBucket:
        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)

        buckets_by_ratio = {}

        # group the buckets by their aspect ratios
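To see what the comprehension yields, here is a standalone run with made-up parameters (the trainer's real values come from its CLI flags, e.g. --bucket_side_min):

import itertools

# hypothetical bucket parameters, chosen only for illustration
bucket_length_min, bucket_length_max, bucket_increment = 512, 768, 256
max_image_area, max_ratio = 512 * 768, 2.0

lengths = list(range(bucket_length_min, bucket_length_max + 1, bucket_increment))
buckets = [(w, h) for w, h in itertools.product(lengths, lengths)
           if w >= h and w * h <= max_image_area and w / h <= max_ratio]
print(buckets)  # [(512, 512), (768, 512)] -- square or landscape shapes only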
@@ -263,7 +267,7 @@ class AspectBucket:

        for b in buckets:
            self.bucket_data[b] = []

    def get_batch_count(self):
        return sum(len(b) // self.batch_size for b in self.bucket_data.values())
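The floor division means each bucket contributes only full batches; with made-up sizes:

# made-up example: buckets holding 23, 10, and 3 images with batch_size = 4
batch_size = 4
bucket_sizes = [23, 10, 3]
print(sum(n // batch_size for n in bucket_sizes))  # 5 + 2 + 0 = 7 batches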
@@ -320,7 +324,7 @@ class AspectBucket:
            total_generated_by_bucket[b] += self.batch_size
            bucket_pos[b] = i
            yield [idx for idx in batch]

    def fill_buckets(self):
        entries = self.store.entries_iterator()
        total_dropped = 0
@@ -339,18 +343,18 @@ class AspectBucket:
            total_dropped += to_drop

        self.total_dropped = total_dropped

    def _process_entry(self, entry: Image.Image, index: int) -> bool:
        aspect = entry.width / entry.height

        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
            return False

        best_bucket = self._bucket_interp(aspect)

        if best_bucket is None:
            return False

        bucket = self.buckets[round(float(best_bucket))]

        self.bucket_data[bucket].append(index)
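The symmetric check rejects images that are too elongated in either orientation; a quick illustration assuming max_ratio = 2.0:

max_ratio = 2.0  # illustrative value, not necessarily the trainer's default
for w, h in [(1024, 512), (1536, 512), (512, 1536)]:
    aspect = w / h
    kept = not (aspect > max_ratio or (1 / aspect) > max_ratio)
    print((w, h), kept)  # (1024, 512) True; the other two are rejected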
@@ -365,13 +369,13 @@ class AspectBucketSampler(torch.utils.data.Sampler):
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # subsample the bucket to only include the elements that are assigned to this rank
        indices = self.bucket.get_batch_iterator()
        indices = list(indices)[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.bucket.get_batch_count() // self.num_replicas
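The stride slice deals batches out round-robin, so no two ranks ever see the same batch; on a toy list of 8 batches with num_replicas = 2:

batches = list(range(8))   # pretend batch indices from the bucket iterator
print(batches[0::2])       # rank 0 -> [0, 2, 4, 6]
print(batches[1::2])       # rank 1 -> [1, 3, 5, 7]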
@@ -386,7 +390,7 @@ class AspectDataset(torch.utils.data.Dataset):
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.store)
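Normalize([0.5], [0.5]) computes (x - 0.5) / 0.5, rescaling ToTensor's [0, 1] pixels to the [-1, 1] range the Stable Diffusion VAE expects:

import torch
import torchvision

t = torchvision.transforms.Normalize([0.5], [0.5])
x = torch.tensor([[[0.0, 0.5, 1.0]]])  # one channel of ToTensor output
print(t(x))  # tensor([[[-1., 0., 1.]]])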
@@ -490,7 +494,7 @@ def main():
    run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')

    device = torch.device('cuda')

    print("DEVICE:", device)

    # setup fp16 stuff
@@ -511,7 +515,7 @@ def main():

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.use_8bit_adam: # bitsandbytes is only supported on certain CUDA setups, so default to regular AdamW if it fails.
        try:
            import bitsandbytes as bnb
@@ -521,7 +525,7 @@ def main():
            optimizer_cls = torch.optim.AdamW
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.lr,
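Condensed, the selection logic amounts to the following (bnb.optim.AdamW8bit is the bitsandbytes 8-bit optimizer; treat this as a sketch of the pattern rather than the script's exact branches):

import torch

try:
    import bitsandbytes as bnb
    optimizer_cls = bnb.optim.AdamW8bit   # 8-bit optimizer state, much lower VRAM
except (ImportError, RuntimeError):
    optimizer_cls = torch.optim.AdamW     # safe fallback on unsupported CUDA setups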
@@ -535,7 +539,6 @@ def main():
        beta_end=0.012,
        beta_schedule='scaled_linear',
        num_train_timesteps=1000,
        tensor_format='pt'
    )

    # load dataset
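Given the -7/+6 line count and the diffusers bump to 0.7.1 in requirements.txt, the line removed here is most plausibly `tensor_format='pt'`, a kwarg that newer diffusers releases dropped. The post-change construction would look roughly like this (beta_start is an assumed value, since the hunk starts at beta_end):

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(
    beta_start=0.00085,          # assumption: not shown in this hunk
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000,    # no tensor_format kwarg in diffusers >= 0.7
)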
@@ -575,7 +578,7 @@ def main():
    # create ema
    if args.use_ema:
        ema_unet = EMAModel(unet.parameters())

    print(get_gpu_ram())

    num_steps_per_epoch = len(train_dataloader)
@@ -632,7 +635,7 @@ def main():
            # Predict the noise residual and compute loss
            with torch.autocast('cuda', enabled=args.fp16):
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

            # Backprop and all reduce
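For context, `noisy_latents` comes from the scheduler's forward-diffusion step, and the MSE above trains the UNet to recover the injected noise. A standalone sketch with toy shapes (not the trainer's code):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(2, 4, 64, 64)        # stand-in for VAE latents
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (2,))
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
# the UNet's target is `noise`, hence mse_loss(noise_pred, noise) above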
@@ -645,7 +648,7 @@ def main():
            # Update EMA
            if args.use_ema:
                ema_unet.step(unet.parameters())

            # perf
            b_end = time.perf_counter()
            seconds_per_step = b_end - b_start
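Each EMA step nudges a shadow copy of the weights toward the live weights; conceptually, per parameter (the decay value is illustrative, not the script's):

import torch

decay = 0.9999                    # illustrative decay
param = torch.tensor([1.0])       # live weight after an optimizer step
ema_param = torch.tensor([0.0])   # shadow weight
ema_param = decay * ema_param + (1 - decay) * param
print(ema_param)  # tensor([1.0000e-04]) -- drifts slowly toward the live weight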
|
@ -656,7 +659,7 @@ def main():
|
|||
|
||||
# All reduce loss
|
||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
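After ReduceOp.SUM every rank holds the summed loss, so logging code usually divides by the world size to report a cross-rank mean; a fragment of that pattern (assumes an initialized process group):

# sketch only: requires torch.distributed.init_process_group to have run
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss = loss / torch.distributed.get_world_size()   # sum -> mean for logging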
|
@ -673,7 +676,7 @@ def main():
|
|||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_checkpoint()
|
||||
|
||||
|
||||
if global_step % args.image_log_steps == 0:
|
||||
if rank == 0:
|
||||
# get prompt from random batch
|
||||
|
@@ -702,7 +705,7 @@ def main():
                        del pipeline
                        gc.collect()
                    torch.distributed.barrier()

    if rank == 0:
        save_checkpoint()
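`del` plus gc.collect() drops the Python references to the inference pipeline; to also return cached GPU blocks to the allocator, the usual companion call is torch.cuda.empty_cache() (a sketch, assuming `pipeline` as in the hunk above):

import gc
import torch

del pipeline               # assumes the sampling pipeline built for image logging
gc.collect()
torch.cuda.empty_cache()   # release cached CUDA memory back for training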
@@ -11,7 +11,8 @@ streamlit>=0.73.1
 einops==0.3.0
 torch-fidelity==0.3.0
 transformers==4.19.2
-torchmetrics==0.6.0
+diffusers==0.7.1
+torchmetrics==0.7.0
 kornia==0.6
 gradio
 git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
|
@ -19,4 +20,5 @@ git+https://github.com/openai/CLIP.git@main#egg=clip
|
|||
git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
|
||||
webdataset
|
||||
wandb
|
||||
fairscale
|
||||
fairscale
|
||||
pynvml==11.4.1
|