Implement distributed training
This commit is contained in:
parent
ef25ccac36
commit
1a94c70736
|
@ -1,3 +1,6 @@
|
|||
# Example Usage:
|
||||
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||
|
||||
import argparse
|
||||
import socket
|
||||
import torch
|
||||
|
@ -13,6 +16,7 @@ import psutil
|
|||
import pynvml
|
||||
import wandb
|
||||
import gc
|
||||
import time
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
|
@ -31,6 +35,8 @@ from PIL import Image
|
|||
from typing import Dict, List, Generator, Tuple
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# defaults should be good for everyone
|
||||
# TODO: add custom VAE support. should be simple with diffusers
|
||||
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
|
||||
|
@ -62,11 +68,21 @@ parser.add_argument('--image_log_steps', type=int, default=100, help='Number of
|
|||
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
def setup():
    """Initialize the default torch.distributed process group.

    Uses the NCCL backend with the ``env://`` init method. Must run before
    any collective call (all_reduce, barrier) is issued.
    """
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
|
||||
# remove hf_token from args so sneaky people don't steal it from the wandb logs
|
||||
sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
|
||||
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
||||
def cleanup():
    """Tear down the default process group if one is active.

    Does nothing when torch.distributed is unavailable or no process group
    was initialized, so it is safe to call on single-process runs and on
    error paths where setup never completed.
    """
    # Check availability first: the remaining torch.distributed APIs are
    # only exposed when the distributed package is available.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
|
||||
|
||||
def get_rank() -> int:
    """Return this process's rank in the default process group.

    Returns:
        The rank reported by torch.distributed, or 0 when the distributed
        package is unavailable or no process group has been initialized
        (i.e. a plain single-process run).
    """
    # Check availability first: the remaining torch.distributed APIs are
    # only exposed when the distributed package is available.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank()
|
||||
|
||||
def get_world_size() -> int:
    """Return the number of processes in the default process group.

    Returns:
        The world size reported by torch.distributed, or 1 when the
        distributed package is unavailable or no process group has been
        initialized (i.e. a plain single-process run).
    """
    # Check availability first: the remaining torch.distributed APIs are
    # only exposed when the distributed package is available.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 1
    return torch.distributed.get_world_size()
|
||||
|
||||
# Inform the user of host, and various versions -- useful for debugging issues.
|
||||
print("RUN_NAME:", args.run_name)
|
||||
|
@ -322,18 +338,25 @@ class AspectBucket:
|
|||
|
||||
self.bucket_data[bucket].append(index)
|
||||
|
||||
del entry
|
||||
|
||||
return True
|
||||
|
||||
class AspectBucketSampler(torch.utils.data.Sampler):
    """Batch sampler that shards an AspectBucket's batches across ranks.

    Each rank takes every ``num_replicas``-th batch starting at its own
    ``rank``. Trailing batches that cannot be shared evenly are dropped so
    that every rank iterates exactly ``len(self)`` batches; ranks running a
    different number of steps would otherwise stall in per-step collectives.

    Args:
        bucket: Source of batches; must provide ``get_batch_iterator()`` and
            ``get_batch_count()``.
        num_replicas: Number of participating processes (default 1).
        rank: Index of this process, in ``[0, num_replicas)`` (default 0).

    Raises:
        ValueError: If ``num_replicas`` is less than 1 or ``rank`` is
            outside ``[0, num_replicas)``.
    """

    def __init__(self, bucket: "AspectBucket", num_replicas: int = 1, rank: int = 0):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"rank must be in [0, {num_replicas}), got {rank}"
            )
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # NOTE(review): this assumes every rank receives the same batch
        # ordering from get_batch_iterator(); if the bucket shuffles with an
        # unsynchronized RNG, ranks may overlap -- confirm against AspectBucket.
        batches = list(self.bucket.get_batch_iterator())
        # Drop the remainder so all ranks yield the same number of batches.
        usable = len(batches) - len(batches) % self.num_replicas
        # Keep only the batches assigned to this rank.
        return iter(batches[self.rank:usable:self.num_replicas])

    def __len__(self):
        # Presumably get_batch_count() matches the number of batches produced
        # by get_batch_iterator(); with that, this equals what __iter__ yields.
        return self.bucket.get_batch_count() // self.num_replicas
|
||||
|
||||
class AspectDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
|
||||
|
@ -435,10 +458,21 @@ class EMAModel:
|
|||
]
|
||||
|
||||
def main():
|
||||
# get device. TODO: support multi-gpu
|
||||
device = 'cpu'
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
rank = get_rank()
|
||||
world_size = get_world_size()
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
if args.hf_token is None:
|
||||
args.hf_token = os.environ['HF_API_TOKEN']
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
|
||||
# remove hf_token from args so sneaky people don't steal it from the wandb logs
|
||||
sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
|
||||
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
print("DEVICE:", device)
|
||||
|
||||
|
@ -489,18 +523,17 @@ def main():
|
|||
|
||||
# load dataset
|
||||
|
||||
|
||||
store = ImageStore(args.dataset)
|
||||
dataset = AspectDataset(store, tokenizer)
|
||||
bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||
sampler = AspectBucketSampler(bucket)
|
||||
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
||||
|
||||
print(f'STORE_LEN: {len(store)}')
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
num_workers=4,
|
||||
num_workers=0,
|
||||
collate_fn=dataset.collate_fn
|
||||
)
|
||||
|
||||
|
@ -516,6 +549,8 @@ def main():
|
|||
unet = unet.to(device, dtype=torch.float32)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
|
||||
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
||||
|
||||
# create ema
|
||||
if args.use_ema:
|
||||
ema_unet = EMAModel(unet.parameters())
|
||||
|
@ -527,6 +562,7 @@ def main():
|
|||
global_step = 0
|
||||
|
||||
def save_checkpoint():
|
||||
if rank == 0:
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
pipeline = StableDiffusionPipeline(
|
||||
|
@ -541,13 +577,16 @@ def main():
|
|||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
pipeline.save_pretrained(args.output_path)
|
||||
# barrier
|
||||
torch.distributed.barrier()
|
||||
|
||||
# train!
|
||||
loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
|
||||
for epoch in range(args.epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Convert images to latent space
|
||||
b_start = time.perf_counter()
|
||||
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
||||
|
@ -571,7 +610,7 @@ def main():
|
|||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Backprop
|
||||
# Backprop and all reduce
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
@ -582,12 +621,27 @@ def main():
|
|||
if args.use_ema:
|
||||
ema_unet.step(unet.parameters())
|
||||
|
||||
# perf
|
||||
b_end = time.perf_counter()
|
||||
seconds_per_step = b_end - b_start
|
||||
steps_per_second = 1 / seconds_per_step
|
||||
rank_images_per_second = args.batch_size * steps_per_second
|
||||
world_images_per_second = rank_images_per_second * world_size
|
||||
samples_seen = global_step * args.batch_size * world_size
|
||||
|
||||
# All reduce loss
|
||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"epoch": epoch
|
||||
"train/loss": loss.detach().item() / world_size,
|
||||
"train/lr": lr_scheduler.get_last_lr()[0],
|
||||
"train/epoch": epoch,
|
||||
"train/samples_seen": samples_seen,
|
||||
"perf/rank_samples_per_second": rank_images_per_second,
|
||||
"perf/global_samples_per_second": world_images_per_second,
|
||||
}
|
||||
progress_bar.set_postfix(logs)
|
||||
run.log(logs)
|
||||
|
@ -596,6 +650,7 @@ def main():
|
|||
save_checkpoint()
|
||||
|
||||
if global_step % args.image_log_steps == 0:
|
||||
if rank == 0:
|
||||
# get prompt from random batch
|
||||
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
||||
pipeline = StableDiffusionPipeline(
|
||||
|
@ -621,22 +676,17 @@ def main():
|
|||
# cleanup so we don't run out of memory
|
||||
del pipeline
|
||||
gc.collect()
|
||||
torch.distributed.barrier()
|
||||
|
||||
if rank == 0:
|
||||
save_checkpoint()
|
||||
|
||||
torch.distributed.barrier()
|
||||
cleanup()
|
||||
|
||||
print(get_gpu_ram())
|
||||
print('Done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
setup()
|
||||
main()
|
||||
|
||||
"""
|
||||
import numpy as np
|
||||
# save a sample
|
||||
img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
|
||||
img = ((img + 1.0) * 127.5).astype(np.uint8)
|
||||
img = Image.fromarray(img)
|
||||
img.save('sample.png')
|
||||
break
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue