Implement distributed training

Anthony Mercurio 2022-11-01 13:28:12 -07:00 committed by GitHub
parent ef25ccac36
commit 1a94c70736
1 changed file with 127 additions and 77 deletions


@@ -1,3 +1,6 @@
# Example Usage:
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
import torch
@@ -13,6 +16,7 @@ import psutil
import pynvml
import wandb
import gc
import time
import itertools
import numpy as np
@@ -31,6 +35,8 @@ from PIL import Image
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d
torch.backends.cuda.matmul.allow_tf32 = True
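# (TF32 matmuls trade a few mantissa bits for a large throughput gain on
# Ampere and newer GPUs, and are generally considered safe for training.)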
# defaults should be good for everyone
# TODO: add custom VAE support. should be simple with diffusers
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
@@ -62,11 +68,21 @@ parser.add_argument('--image_log_steps', type=int, default=100, help='Number of
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')

args = parser.parse_args()
def setup():
    torch.distributed.init_process_group("nccl", init_method="env://")

def cleanup():
    torch.distributed.destroy_process_group()

def get_rank() -> int:
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()

def get_world_size() -> int:
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
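# Note: with init_method="env://", init_process_group reads its rendezvous
# configuration from environment variables that torchrun sets for each worker:
# RANK (global rank), LOCAL_RANK (rank on this node), WORLD_SIZE, MASTER_ADDR
# and MASTER_PORT. The example invocation at the top of this file therefore
# starts two copies of this script with RANK=0/1 and WORLD_SIZE=2; each worker
# could verify its identity with, e.g.:
#   print(os.environ['RANK'], os.environ['WORLD_SIZE'])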
# Inform the user of the host and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
@@ -322,18 +338,25 @@ class AspectBucket:
        self.bucket_data[bucket].append(index)
        del entry
        return True
class AspectBucketSampler(torch.utils.data.Sampler):
    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # subsample the bucket to only include the elements that are assigned to this rank
        indices = self.bucket.get_batch_iterator()
        indices = list(indices)[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.bucket.get_batch_count() // self.num_replicas
class AspectDataset(torch.utils.data.Dataset):
    def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
@@ -435,10 +458,21 @@ class EMAModel:
        ]
def main():
    rank = get_rank()
    world_size = get_world_size()
    torch.cuda.set_device(rank)

    if args.hf_token is None:
        args.hf_token = os.environ['HF_API_TOKEN']

    if rank == 0:
        os.makedirs(args.output_path, exist_ok=True)

        # remove hf_token from args so sneaky people don't steal it from the wandb logs
        sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
        run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')

    device = torch.device('cuda')

    print("DEVICE:", device)
@@ -489,18 +523,17 @@ def main():
    # load dataset
    store = ImageStore(args.dataset)
    dataset = AspectDataset(store, tokenizer)
    bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
    sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)

    print(f'STORE_LEN: {len(store)}')

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=0,
        collate_fn=dataset.collate_fn
    )
@@ -516,6 +549,8 @@ def main():
    unet = unet.to(device, dtype=torch.float32)
    text_encoder = text_encoder.to(device, dtype=weight_dtype)

    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
    # create ema
    if args.use_ema:
        ema_unet = EMAModel(unet.parameters())
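    # With the DistributedDataParallel wrapper above left commented out,
    # gradients are not synchronized across ranks. A minimal sketch of enabling
    # it (assumes the checkpointing code would then unwrap the model via .module):
    #
    #   unet = torch.nn.parallel.DistributedDataParallel(
    #       unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True
    #   )
    #   ...
    #   pipeline = StableDiffusionPipeline(unet=unet.module, ...)  # unwrap to save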
@@ -527,27 +562,31 @@ def main():
    global_step = 0

    def save_checkpoint():
        if rank == 0:
            if args.use_ema:
                ema_unet.copy_to(unet.parameters())
            pipeline = StableDiffusionPipeline(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=PNDMScheduler(
                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                ),
                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
            )
            pipeline.save_pretrained(args.output_path)
        # barrier
        torch.distributed.barrier()
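    # Rank-0-only checkpointing pattern: every rank calls save_checkpoint(), but
    # only rank 0 builds and writes the pipeline; the barrier at the end keeps
    # the other ranks parked until the files are on disk, so no rank races ahead
    # while rank 0 is still serializing.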
    # train!
    loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
    for epoch in range(args.epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            b_start = time.perf_counter()
            latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215
@@ -571,7 +610,7 @@ def main():
            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

            # Backprop and all reduce
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
@@ -582,61 +621,72 @@ def main():
            if args.use_ema:
                ema_unet.step(unet.parameters())

            # perf
            b_end = time.perf_counter()
            seconds_per_step = b_end - b_start
            steps_per_second = 1 / seconds_per_step
            rank_images_per_second = args.batch_size * steps_per_second
            world_images_per_second = rank_images_per_second * world_size
            samples_seen = global_step * args.batch_size * world_size

            # All reduce loss
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)

            if rank == 0:
                progress_bar.update(1)
                global_step += 1
                logs = {
                    "train/loss": loss.detach().item() / world_size,
                    "train/lr": lr_scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                    "train/samples_seen": samples_seen,
                    "perf/rank_samples_per_second": rank_images_per_second,
                    "perf/global_samples_per_second": world_images_per_second,
                }
                progress_bar.set_postfix(logs)
                run.log(logs)
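            # Worked example of the reduction above (hypothetical values): if
            # rank 0 computes loss 0.50 and rank 1 computes 0.30, all_reduce
            # with ReduceOp.SUM leaves 0.80 on both ranks, and dividing by
            # world_size (2) in the log dict recovers the mean, 0.40.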
            if global_step % args.save_steps == 0:
                save_checkpoint()

            if global_step % args.image_log_steps == 0:
                if rank == 0:
                    # get prompt from random batch
                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
                    pipeline = StableDiffusionPipeline(
                        text_encoder=text_encoder,
                        vae=vae,
                        unet=unet,
                        tokenizer=tokenizer,
                        scheduler=PNDMScheduler(
                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                        ),
                        safety_checker=None, # disable safety checker to save memory
                        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
                    ).to(device)
                    # inference
                    images = []
                    with torch.no_grad():
                        with torch.autocast('cuda', enabled=args.fp16):
                            for _ in range(args.image_log_amount):
                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
                    # log images under single caption
                    run.log({'images': images})
                    # cleanup so we don't run out of memory
                    del pipeline
                    gc.collect()
                torch.distributed.barrier()

    if rank == 0:
        save_checkpoint()

    torch.distributed.barrier()
    cleanup()
    print(get_gpu_ram())
    print('Done!')
if __name__ == "__main__":
if __name__ == '__main__': setup()
main() main()
"""
import numpy as np
# save a sample
img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
img = ((img + 1.0) * 127.5).astype(np.uint8)
img = Image.fromarray(img)
img.save('sample.png')
break
"""