Implement distributed training

Anthony Mercurio, 2022-11-01 13:28:12 -07:00 (committed by GitHub)
parent ef25ccac36
commit 1a94c70736
1 changed file with 127 additions and 77 deletions

@@ -1,3 +1,6 @@
+# Example Usage:
+# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
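# Aside (not part of this commit): --nproc_per_node tells torchrun how many
# worker processes to spawn on this machine, typically one per GPU; each worker
# receives its own RANK, LOCAL_RANK and WORLD_SIZE environment variables, which
# the setup()/get_rank()/get_world_size() helpers added below consume.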
 import argparse
+import socket
 import torch
@@ -13,6 +16,7 @@ import psutil
 import pynvml
 import wandb
 import gc
+import time
 import itertools
 import numpy as np
@@ -31,6 +35,8 @@ from PIL import Image
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d

+torch.backends.cuda.matmul.allow_tf32 = True
+
 # defaults should be good for everyone
 # TODO: add custom VAE support. should be simple with diffusers
 parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
@@ -62,11 +68,21 @@ parser.add_argument('--image_log_steps', type=int, default=100, help='Number of
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 args = parser.parse_args()
-os.makedirs(args.output_path, exist_ok=True)
+
+def setup():
+    torch.distributed.init_process_group("nccl", init_method="env://")
+
-# remove hf_token from args so sneaky people don't steal it from the wandb logs
-sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
-run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
+def cleanup():
+    torch.distributed.destroy_process_group()
+
+def get_rank() -> int:
+    if not torch.distributed.is_initialized():
+        return 0
+    return torch.distributed.get_rank()
+
+def get_world_size() -> int:
+    if not torch.distributed.is_initialized():
+        return 1
+    return torch.distributed.get_world_size()

 # Inform the user of the host and various versions -- useful for debugging issues.
 print("RUN_NAME:", args.run_name)
@@ -322,18 +338,25 @@
             self.bucket_data[bucket].append(index)
         del entry
         return True

 class AspectBucketSampler(torch.utils.data.Sampler):
-    def __init__(self, bucket: AspectBucket):
+    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
         super().__init__(None)
         self.bucket = bucket
+        self.num_replicas = num_replicas
+        self.rank = rank

     def __iter__(self):
-        yield from self.bucket.get_batch_iterator()
+        # subsample the bucket to only include the elements that are assigned to this rank
+        indices = self.bucket.get_batch_iterator()
+        indices = list(indices)[self.rank::self.num_replicas]
+        return iter(indices)

     def __len__(self):
-        return self.bucket.get_batch_count()
+        return self.bucket.get_batch_count() // self.num_replicas

 class AspectDataset(torch.utils.data.Dataset):
     def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
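# Aside (not part of this commit): the [rank::num_replicas] slice above shards
# batches round-robin, so each rank sees a disjoint subset of the data. A toy,
# runnable illustration with 10 batches across 2 ranks:
batches = list(range(10))
for r in range(2):              # world_size = 2
    print(r, batches[r::2])     # rank 0 -> [0, 2, 4, 6, 8]; rank 1 -> [1, 3, 5, 7, 9]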
@@ -435,10 +458,21 @@
     ]

 def main():
-    # get device. TODO: support multi-gpu
-    device = 'cpu'
-    if torch.cuda.is_available():
-        device = 'cuda'
+    rank = get_rank()
+    world_size = get_world_size()
+    torch.cuda.set_device(rank)

     if args.hf_token is None:
         args.hf_token = os.environ['HF_API_TOKEN']

+    if rank == 0:
+        os.makedirs(args.output_path, exist_ok=True)
+
+        # remove hf_token from args so sneaky people don't steal it from the wandb logs
+        sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
+        run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
+
+    device = torch.device('cuda')

     print("DEVICE:", device)
@@ -489,18 +523,17 @@
     # load dataset
     store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
     bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
-    sampler = AspectBucketSampler(bucket)
+    sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)

     print(f'STORE_LEN: {len(store)}')

     train_dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_sampler=sampler,
-        num_workers=4,
+        num_workers=0,
         collate_fn=dataset.collate_fn
     )
@@ -516,6 +549,8 @@
     unet = unet.to(device, dtype=torch.float32)
     text_encoder = text_encoder.to(device, dtype=weight_dtype)

+    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+
     # create ema
     if args.use_ema:
         ema_unet = EMAModel(unet.parameters())
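# Aside (not part of this commit): with the DistributedDataParallel wrapper
# left commented out, each rank trains its own copy of the UNet and no
# gradients are exchanged; only the loss is all-reduced for logging (see the
# training loop below). Enabling the wrapper would make backward() average
# gradients across ranks automatically, keeping the per-rank models in sync.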
@@ -527,6 +562,7 @@
     global_step = 0

     def save_checkpoint():
+        if rank == 0:
             if args.use_ema:
                 ema_unet.copy_to(unet.parameters())
             pipeline = StableDiffusionPipeline(
@@ -541,13 +577,16 @@
                 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
             )
             pipeline.save_pretrained(args.output_path)
+        # barrier
+        torch.distributed.barrier()
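# Aside (not part of this commit): the rank-0 gate plus barrier is the standard
# pattern for distributed checkpointing -- exactly one process writes the
# pipeline to disk while every other rank waits at the barrier, so no rank
# races ahead or reads a half-written checkpoint.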
     # train!
+    loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
     for epoch in range(args.epochs):
         unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
             # Convert images to latent space
+            b_start = time.perf_counter()
             latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
             latents = latents * 0.18215
@@ -571,7 +610,7 @@
             loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

-            # Backprop
+            # Backprop and all reduce
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
@@ -582,12 +621,27 @@
             if args.use_ema:
                 ema_unet.step(unet.parameters())

+            # perf
+            b_end = time.perf_counter()
+            seconds_per_step = b_end - b_start
+            steps_per_second = 1 / seconds_per_step
+            rank_images_per_second = args.batch_size * steps_per_second
+            world_images_per_second = rank_images_per_second * world_size
+            samples_seen = global_step * args.batch_size * world_size
+
+            # All reduce loss
+            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
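# Aside (not part of this commit): all_reduce with ReduceOp.SUM leaves every
# rank holding the sum of the per-rank losses, which is why the logging below
# divides loss.detach().item() by world_size to recover the mean across ranks.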
+            if rank == 0:
                 progress_bar.update(1)
                 global_step += 1
                 logs = {
-                    "loss": loss.detach().item(),
-                    "lr": lr_scheduler.get_last_lr()[0],
-                    "epoch": epoch
+                    "train/loss": loss.detach().item() / world_size,
+                    "train/lr": lr_scheduler.get_last_lr()[0],
+                    "train/epoch": epoch,
+                    "train/samples_seen": samples_seen,
+                    "perf/rank_samples_per_second": rank_images_per_second,
+                    "perf/global_samples_per_second": world_images_per_second,
+                }
                 progress_bar.set_postfix(logs)
                 run.log(logs)
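# Aside (not part of this commit): a worked example of the samples_seen formula
# using the invocation at the top of the file (--batch_size=10 on 2 GPUs):
# after 100 optimizer steps, samples_seen = 100 * 10 * 2 = 2000, since every
# rank pushes batch_size images through the model on each step.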
@@ -596,6 +650,7 @@
                 save_checkpoint()
             if global_step % args.image_log_steps == 0:
+                if rank == 0:
                     # get prompt from random batch
                     prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
                     pipeline = StableDiffusionPipeline(
@@ -621,22 +676,17 @@
                     # cleanup so we don't run out of memory
                     del pipeline
                     gc.collect()
+                torch.distributed.barrier()

+    if rank == 0:
+        save_checkpoint()
+
+    torch.distributed.barrier()
+    cleanup()

     print(get_gpu_ram())
     print('Done!')

-if __name__ == '__main__':
+if __name__ == "__main__":
+    setup()
     main()
-"""
-import numpy as np
-# save a sample
-img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
-img = ((img + 1.0) * 127.5).astype(np.uint8)
-img = Image.fromarray(img)
-img.save('sample.png')
-break
-"""