Implement distributed training
This commit is contained in:
parent
ef25ccac36
commit
1a94c70736
|
@ -1,3 +1,6 @@
|
||||||
|
# Example Usage:
|
||||||
|
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import socket
|
import socket
|
||||||
import torch
|
import torch
|
||||||
|
@ -13,6 +16,7 @@ import psutil
|
||||||
import pynvml
|
import pynvml
|
||||||
import wandb
|
import wandb
|
||||||
import gc
|
import gc
|
||||||
|
import time
|
||||||
import itertools
|
import itertools
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -31,6 +35,8 @@ from PIL import Image
|
||||||
from typing import Dict, List, Generator, Tuple
|
from typing import Dict, List, Generator, Tuple
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# defaults should be good for everyone
|
# defaults should be good for everyone
|
||||||
# TODO: add custom VAE support. should be simple with diffusers
|
# TODO: add custom VAE support. should be simple with diffusers
|
||||||
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
|
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
|
||||||
|
@ -62,11 +68,21 @@ parser.add_argument('--image_log_steps', type=int, default=100, help='Number of
|
||||||
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
def setup():
|
||||||
|
torch.distributed.init_process_group("nccl", init_method="env://")
|
||||||
|
|
||||||
# remove hf_token from args so sneaky people don't steal it from the wandb logs
|
def cleanup():
|
||||||
sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
|
torch.distributed.destroy_process_group()
|
||||||
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
|
||||||
|
def get_rank() -> int:
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
return 0
|
||||||
|
return torch.distributed.get_rank()
|
||||||
|
|
||||||
|
def get_world_size() -> int:
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
return 1
|
||||||
|
return torch.distributed.get_world_size()
|
||||||
|
|
||||||
# Inform the user of host, and various versions -- useful for debugging issues.
|
# Inform the user of host, and various versions -- useful for debugging issues.
|
||||||
print("RUN_NAME:", args.run_name)
|
print("RUN_NAME:", args.run_name)
|
||||||
|
@ -322,18 +338,25 @@ class AspectBucket:
|
||||||
|
|
||||||
self.bucket_data[bucket].append(index)
|
self.bucket_data[bucket].append(index)
|
||||||
|
|
||||||
|
del entry
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
class AspectBucketSampler(torch.utils.data.Sampler):
|
class AspectBucketSampler(torch.utils.data.Sampler):
|
||||||
def __init__(self, bucket: AspectBucket):
|
def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
|
||||||
super().__init__(None)
|
super().__init__(None)
|
||||||
self.bucket = bucket
|
self.bucket = bucket
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
self.rank = rank
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
yield from self.bucket.get_batch_iterator()
|
# subsample the bucket to only include the elements that are assigned to this rank
|
||||||
|
indices = self.bucket.get_batch_iterator()
|
||||||
|
indices = list(indices)[self.rank::self.num_replicas]
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.bucket.get_batch_count()
|
return self.bucket.get_batch_count() // self.num_replicas
|
||||||
|
|
||||||
class AspectDataset(torch.utils.data.Dataset):
|
class AspectDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
|
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
|
||||||
|
@ -435,10 +458,21 @@ class EMAModel:
|
||||||
]
|
]
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# get device. TODO: support multi-gpu
|
rank = get_rank()
|
||||||
device = 'cpu'
|
world_size = get_world_size()
|
||||||
if torch.cuda.is_available():
|
torch.cuda.set_device(rank)
|
||||||
device = 'cuda'
|
|
||||||
|
if args.hf_token is None:
|
||||||
|
args.hf_token = os.environ['HF_API_TOKEN']
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
||||||
|
# remove hf_token from args so sneaky people don't steal it from the wandb logs
|
||||||
|
sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
|
||||||
|
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
||||||
|
|
||||||
|
device = torch.device('cuda')
|
||||||
|
|
||||||
print("DEVICE:", device)
|
print("DEVICE:", device)
|
||||||
|
|
||||||
|
@ -489,18 +523,17 @@ def main():
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
|
|
||||||
|
|
||||||
store = ImageStore(args.dataset)
|
store = ImageStore(args.dataset)
|
||||||
dataset = AspectDataset(store, tokenizer)
|
dataset = AspectDataset(store, tokenizer)
|
||||||
bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||||
sampler = AspectBucketSampler(bucket)
|
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
||||||
|
|
||||||
print(f'STORE_LEN: {len(store)}')
|
print(f'STORE_LEN: {len(store)}')
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_sampler=sampler,
|
batch_sampler=sampler,
|
||||||
num_workers=4,
|
num_workers=0,
|
||||||
collate_fn=dataset.collate_fn
|
collate_fn=dataset.collate_fn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -516,6 +549,8 @@ def main():
|
||||||
unet = unet.to(device, dtype=torch.float32)
|
unet = unet.to(device, dtype=torch.float32)
|
||||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
||||||
|
|
||||||
# create ema
|
# create ema
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
ema_unet = EMAModel(unet.parameters())
|
ema_unet = EMAModel(unet.parameters())
|
||||||
|
@ -527,27 +562,31 @@ def main():
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
def save_checkpoint():
|
def save_checkpoint():
|
||||||
if args.use_ema:
|
if rank == 0:
|
||||||
ema_unet.copy_to(unet.parameters())
|
if args.use_ema:
|
||||||
pipeline = StableDiffusionPipeline(
|
ema_unet.copy_to(unet.parameters())
|
||||||
text_encoder=text_encoder,
|
pipeline = StableDiffusionPipeline(
|
||||||
vae=vae,
|
text_encoder=text_encoder,
|
||||||
unet=unet,
|
vae=vae,
|
||||||
tokenizer=tokenizer,
|
unet=unet,
|
||||||
scheduler=PNDMScheduler(
|
tokenizer=tokenizer,
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
scheduler=PNDMScheduler(
|
||||||
),
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
),
|
||||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||||
)
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
pipeline.save_pretrained(args.output_path)
|
)
|
||||||
|
pipeline.save_pretrained(args.output_path)
|
||||||
|
# barrier
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
# train!
|
# train!
|
||||||
|
loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
unet.train()
|
unet.train()
|
||||||
train_loss = 0.0
|
train_loss = 0.0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
# Convert images to latent space
|
b_start = time.perf_counter()
|
||||||
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
|
|
||||||
|
@ -571,7 +610,7 @@ def main():
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||||
|
|
||||||
# Backprop
|
# Backprop and all reduce
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
@ -582,61 +621,72 @@ def main():
|
||||||
if args.use_ema:
|
if args.use_ema:
|
||||||
ema_unet.step(unet.parameters())
|
ema_unet.step(unet.parameters())
|
||||||
|
|
||||||
progress_bar.update(1)
|
# perf
|
||||||
global_step += 1
|
b_end = time.perf_counter()
|
||||||
logs = {
|
seconds_per_step = b_end - b_start
|
||||||
"loss": loss.detach().item(),
|
steps_per_second = 1 / seconds_per_step
|
||||||
"lr": lr_scheduler.get_last_lr()[0],
|
rank_images_per_second = args.batch_size * steps_per_second
|
||||||
"epoch": epoch
|
world_images_per_second = rank_images_per_second * world_size
|
||||||
}
|
samples_seen = global_step * args.batch_size * world_size
|
||||||
progress_bar.set_postfix(logs)
|
|
||||||
run.log(logs)
|
# All reduce loss
|
||||||
|
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
logs = {
|
||||||
|
"train/loss": loss.detach().item() / world_size,
|
||||||
|
"train/lr": lr_scheduler.get_last_lr()[0],
|
||||||
|
"train/epoch": epoch,
|
||||||
|
"train/samples_seen": samples_seen,
|
||||||
|
"perf/rank_samples_per_second": rank_images_per_second,
|
||||||
|
"perf/global_samples_per_second": world_images_per_second,
|
||||||
|
}
|
||||||
|
progress_bar.set_postfix(logs)
|
||||||
|
run.log(logs)
|
||||||
|
|
||||||
if global_step % args.save_steps == 0:
|
if global_step % args.save_steps == 0:
|
||||||
save_checkpoint()
|
save_checkpoint()
|
||||||
|
|
||||||
if global_step % args.image_log_steps == 0:
|
if global_step % args.image_log_steps == 0:
|
||||||
# get prompt from random batch
|
if rank == 0:
|
||||||
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
# get prompt from random batch
|
||||||
pipeline = StableDiffusionPipeline(
|
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
||||||
text_encoder=text_encoder,
|
pipeline = StableDiffusionPipeline(
|
||||||
vae=vae,
|
text_encoder=text_encoder,
|
||||||
unet=unet,
|
vae=vae,
|
||||||
tokenizer=tokenizer,
|
unet=unet,
|
||||||
scheduler=PNDMScheduler(
|
tokenizer=tokenizer,
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
scheduler=PNDMScheduler(
|
||||||
),
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||||
safety_checker=None, # disable safety checker to save memory
|
),
|
||||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
safety_checker=None, # disable safety checker to save memory
|
||||||
).to(device)
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
# inference
|
).to(device)
|
||||||
images = []
|
# inference
|
||||||
with torch.no_grad():
|
images = []
|
||||||
with torch.autocast('cuda', enabled=args.fp16):
|
with torch.no_grad():
|
||||||
for _ in range(args.image_log_amount):
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
|
for _ in range(args.image_log_amount):
|
||||||
# log images under single caption
|
images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
|
||||||
run.log({'images': images})
|
# log images under single caption
|
||||||
|
run.log({'images': images})
|
||||||
|
|
||||||
# cleanup so we don't run out of memory
|
# cleanup so we don't run out of memory
|
||||||
del pipeline
|
del pipeline
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
save_checkpoint()
|
if rank == 0:
|
||||||
|
save_checkpoint()
|
||||||
|
|
||||||
|
torch.distributed.barrier()
|
||||||
|
cleanup()
|
||||||
|
|
||||||
print(get_gpu_ram())
|
print(get_gpu_ram())
|
||||||
print('Done!')
|
print('Done!')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
if __name__ == '__main__':
|
setup()
|
||||||
main()
|
main()
|
||||||
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
# save a sample
|
|
||||||
img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
|
|
||||||
img = ((img + 1.0) * 127.5).astype(np.uint8)
|
|
||||||
img = Image.fromarray(img)
|
|
||||||
img.save('sample.png')
|
|
||||||
break
|
|
||||||
"""
|
|
||||||
|
|
Loading…
Reference in New Issue