# waifu-diffusion/trainer/diffusers_trainer.py
#
# 804 lines
# 34 KiB
# Python

# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
import torch
import torchvision
import transformers
import diffusers
import os
import glob
import random
import tqdm
import resource
import psutil
import pynvml
import wandb
import gc
import time
import itertools
import numpy as np
import json
import re
import traceback
# NVML is only used for the optional GPU memory report. If it cannot be
# initialised for any reason (library missing, driver not loaded, ...), the
# module handle is replaced with None so callers can detect that it is
# unavailable. `pynvml.NVMLError` is the base class of all NVML errors,
# including the "library not found" case.
try:
    pynvml.nvmlInit()
except pynvml.NVMLError:
    pynvml = None
from typing import Iterable
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d
# Allow TF32 to be used for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True


def _str2bool(value):
    """Convert a command line string such as ``True``/``false``/``1`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, which made it impossible to turn a
    flag off from the command line.

    :param value: the raw command line value (or an already parsed bool).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # the empty string is kept falsy, as it was with `type=bool`
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# defaults should be good for everyone
# TODO: add custom VAE support. should be simple with diffusers
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--use_ema', type=_str2bool, default=False, help='Use EMA for finetuning')
# NOTE: argparse %-formats help strings, so a literal percent sign must be written as `%%`.
parser.add_argument('--ucg', type=float, default=0.1, help='Percentage chance of dropping out the text condition per batch. Ranges from 0.0 to 1.0 where 1.0 means 100%% text condition dropout.') # 10% dropout probability
parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=_str2bool, default=False, help='Enable gradient checkpointing')
parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=_str2bool, default=False, help='Use 8-bit Adam optimizer')
parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
parser.add_argument('--lr_scheduler_warmup', type=float, default=0.05, help='Learning rate scheduler warmup steps. This is a percentage of the total number of steps in the training run. 0.1 means 10 percent of the total number of steps.')
parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproduceability purposes.')
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
parser.add_argument('--shuffle', dest='shuffle', type=_str2bool, default=True, help='Shuffle dataset')
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
parser.add_argument('--fp16', dest='fp16', type=_str2bool, default=False, help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Scheduler to use to log images [`PNDMScheduler`, `DDIMScheduler`].')
parser.add_argument('--clip_penultimate', type=_str2bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=_str2bool, default=False, help='Outputs bucket information and exits')
parser.add_argument('--resize', type=_str2bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=_str2bool, default=False, help='Use memory efficient attention')
# Parse the command line once at import time; the rest of this module reads the global `args`.
args = parser.parse_args()
def setup():
    """Create the default torch.distributed process group (NCCL backend, configured from the environment)."""
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
def cleanup():
    """Tear down the default torch.distributed process group."""
    torch.distributed.destroy_process_group()
def get_rank() -> int:
    """Return this process' rank, or 0 when no process group has been initialised."""
    dist = torch.distributed
    return dist.get_rank() if dist.is_initialized() else 0
def get_world_size() -> int:
    """Return the number of participating processes, or 1 when no process group has been initialised."""
    dist = torch.distributed
    return dist.get_world_size() if dist.is_initialized() else 1
def get_gpu_ram() -> str:
    """
    Returns memory usage statistics for the CPU, GPU, and Torch.

    The GPU section is only included when NVML was initialised successfully
    (`pynvml` is set to None at import time otherwise), and both the GPU and
    Torch sections are skipped when no CUDA device can be queried.

    :return: a single-line, human readable summary string.
    """
    gpu_str = ""
    torch_str = ""
    try:
        cudadev = torch.cuda.current_device()
        if pynvml is not None:
            nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
            gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
            gpu_total = int(gpu_info.total / 1E6)
            gpu_free = int(gpu_info.free / 1E6)
            gpu_used = int(gpu_info.used / 1E6)
            gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
                      f"T: {gpu_total:,}mb) "
        torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
        torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
        torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
        torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
        torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
                    f"{torch_reserved_max:,}mb, " \
                    f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
    except (AssertionError, RuntimeError):
        # No usable CUDA device: this is a best-effort report, so only the
        # CPU statistics are returned.
        pass
    # combined peak resident set size of this process and its children
    cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
                     resource.getrusage(
                         resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
    cpu_vmem = psutil.virtual_memory()
    cpu_free = int(cpu_vmem.free / 1E6)
    return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
           f"{gpu_str}" \
           f"{torch_str}"
def _sort_by_ratio(bucket: tuple) -> float:
return bucket[0] / bucket[1]
def _sort_by_area(bucket: tuple) -> float:
return bucket[0] * bucket[1]
class ImageStore:
    """
    Index of the image files found directly inside a dataset directory.

    Files with a jpg/jpeg/png/bmp/webp extension are collected; files that
    PIL cannot open are reported and skipped. Captions are read from a
    sibling `.txt` file that shares the image's base name.
    """

    def __init__(self, data_dir: str) -> None:
        self.data_dir = data_dir

        candidates = []
        for ext in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            candidates.extend(glob.glob(f'{data_dir}' + '/*.' + ext))
        self.image_files = [x for x in candidates if self.__valid_file(x)]

    def __len__(self) -> int:
        return len(self.image_files)

    def __valid_file(self, f) -> bool:
        """Return True if PIL can open `f`; print a warning and return False otherwise."""
        try:
            # the context manager closes the file handle again right away
            with Image.open(f):
                return True
        except Exception:
            print(f'WARNING: Unable to open file: {f}')
            return False

    # iterator returns images as PIL images and their index in the store
    def entries_iterator(self) -> "Generator[Tuple[Image.Image, int], None, None]":
        for index, path in enumerate(self.image_files):
            # convert() returns a new, fully loaded image, so the source file
            # can be closed before the image is handed to the consumer
            with Image.open(path) as handle:
                image = handle.convert(mode='RGB')
            yield image, index

    # get image by index
    def get_image(self, ref: Tuple[int, int, int]) -> "Image.Image":
        # only ref[0] (the index into the store) is used here
        with Image.open(self.image_files[ref[0]]) as handle:
            return handle.convert(mode='RGB')

    # gets caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, ref: Tuple[int, int, int]) -> str:
        filename = re.sub(r'\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
        with open(filename, 'r', encoding='UTF-8') as f:
            return f.read()
# ====================================== #
# Bucketing code stolen from hasuwoof: #
# https://github.com/hasuwoof/huskystack #
# ====================================== #
class AspectBucket:
    """
    Groups the images of an ImageStore into resolution buckets by aspect ratio.

    A set of (width, height) buckets is generated from the side/area/ratio
    limits, every image of the store is assigned to the bucket whose aspect
    ratio is closest to its own, and batches are then produced so that all
    images in a batch belong to the same bucket.
    """

    def __init__(self, store: "ImageStore",
                 num_buckets: int,
                 batch_size: int,
                 bucket_side_min: int = 256,
                 bucket_side_max: int = 768,
                 bucket_side_increment: int = 64,
                 max_image_area: int = 512 * 768,
                 max_ratio: float = 2):
        self.requested_bucket_count = num_buckets
        self.bucket_length_min = bucket_side_min
        self.bucket_length_max = bucket_side_max
        self.bucket_increment = bucket_side_increment
        self.max_image_area = max_image_area
        self.batch_size = batch_size
        self.total_dropped = 0

        # a non-positive max_ratio disables the ratio limit
        if max_ratio <= 0:
            self.max_ratio = float('inf')
        else:
            self.max_ratio = max_ratio

        self.store = store
        self.buckets = []
        self._bucket_ratios = []
        self._bucket_interp = None
        self.bucket_data: Dict[tuple, List[int]] = dict()
        self.init_buckets()
        self.fill_buckets()

    def init_buckets(self):
        """Generate the list of buckets and the ratio -> bucket index interpolator."""
        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)

        buckets_by_ratio = {}

        # group the buckets by their aspect ratios
        for bucket in possible_buckets:
            w, h = bucket
            # use precision to avoid spooky floats messing up your day
            ratio = '{:.4e}'.format(w / h)

            if ratio not in buckets_by_ratio:
                group = set()
                buckets_by_ratio[ratio] = group
            else:
                group = buckets_by_ratio[ratio]

            group.add(bucket)

        # now we take the list of buckets we generated and pick the largest by area for each (the last sorted)
        # then we put all of those in a list, sorted by the aspect ratio
        # the square bucket (LxL) will be the first
        unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
                                       for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)

        # how many buckets to create for each side of the distribution
        bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))

        # we know that the requested_bucket_count must be an odd number, so the indices we calculate
        # will include the square bucket and some linearly spaced buckets along the distribution
        indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}

        # make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
        # here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
        buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
                          *(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)

        self.buckets = buckets

        # cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
        # the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
        # and the output is the bucket index in the self.buckets array
        # to find the best fit we can just round that number to get the index
        self._bucket_ratios = [w / h for w, h in buckets]
        self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
                                       fill_value=None)

        for b in buckets:
            self.bucket_data[b] = []

    def get_batch_count(self):
        """Return the total number of full batches over all buckets."""
        return sum(len(b) // self.batch_size for b in self.bucket_data.values())

    def get_bucket_info(self):
        """Return the buckets and their ratios as a JSON string."""
        return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })

    def get_batch_iterator(self) -> Generator[List[Tuple[int, int, int]], None, None]:
        """
        Generator that provides batches where the images in a batch fall on the same bucket

        Each element generated will be a list of:
            (index, w, h)

        where each image is an index into the dataset
        :return:
        """
        max_bucket_len = max((len(b) for b in self.bucket_data.values()), default=0)
        # one shared random visiting order; every bucket only uses the entries that are inside its own length
        index_schedule = list(range(max_bucket_len))
        random.shuffle(index_schedule)

        bucket_len_table = {
            b: len(self.bucket_data[b]) for b in self.buckets
        }

        # one schedule slot per full batch of every bucket, in random order
        bucket_schedule = []
        for i, b in enumerate(self.buckets):
            bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))

        random.shuffle(bucket_schedule)

        # position of every bucket inside index_schedule
        bucket_pos = {
            b: 0 for b in self.buckets
        }

        for bucket_index in bucket_schedule:
            b = self.buckets[bucket_index]
            i = bucket_pos[b]
            bucket_len = bucket_len_table[b]

            batch = []
            while len(batch) != self.batch_size:
                # advance in the schedule until we find an index that is contained in the bucket
                k = index_schedule[i]
                if k < bucket_len:
                    entry = self.bucket_data[b][k]
                    batch.append(entry)

                i += 1

            bucket_pos[b] = i
            yield [(idx, *b) for idx in batch]

    def fill_buckets(self):
        """Assign every image of the store to a bucket and trim the buckets to whole batches."""
        entries = self.store.entries_iterator()
        total_dropped = 0

        for entry, index in tqdm.tqdm(entries, total=len(self.store)):
            if not self._process_entry(entry, index):
                total_dropped += 1

        for b, values in self.bucket_data.items():
            # shuffle the entries for extra randomness and to make sure dropped elements are also random
            random.shuffle(values)

            # make sure the buckets have an exact number of elements for the batch
            to_drop = len(values) % self.batch_size
            self.bucket_data[b] = list(values[:len(values) - to_drop])
            total_dropped += to_drop

        self.total_dropped = total_dropped

    def _process_entry(self, entry: "Image.Image", index: int) -> bool:
        """
        Put the image with the given store index into its best fitting bucket.

        :param entry: object exposing `width` and `height`.
        :param index: index of the image in the store.
        :return: True if the image was assigned to a bucket, False if it was dropped.
        """
        aspect = entry.width / entry.height

        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
            return False

        try:
            best_bucket = self._bucket_interp(aspect)
        except ValueError:
            # the interpolator is bounds-checked: an aspect ratio outside of the
            # range spanned by the generated buckets has no bucket, drop the image
            return False

        bucket = self.buckets[round(float(best_bucket))]

        self.bucket_data[bucket].append(index)

        return True
class AspectBucketSampler(torch.utils.data.Sampler):
    """
    Batch sampler that hands every rank its share of the batches of an AspectBucket.

    NOTE(review): the split is only disjoint if every rank generates the same
    batch schedule, i.e. if Python's `random` module is seeded identically on
    all ranks before the bucket is built and iterated - confirm in the caller.
    """

    def __init__(self, bucket: "AspectBucket", num_replicas: int = 1, rank: int = 0):
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        batches = list(self.bucket.get_batch_iterator())
        # Drop the trailing batches that cannot be shared evenly. Every rank must
        # run the same number of steps (matching __len__), otherwise ranks with an
        # extra batch would wait forever in collective operations.
        usable = (len(batches) // self.num_replicas) * self.num_replicas
        # subsample the bucket to only include the elements that are assigned to this rank
        return iter(batches[:usable][self.rank::self.num_replicas])

    def __len__(self):
        return self.bucket.get_batch_count() // self.num_replicas
class AspectDataset(torch.utils.data.Dataset):
    """
    Dataset that turns (index, width, height) references into training examples.

    Every example holds the transformed image tensor (`pixel_values`) and the
    token ids of its caption (`input_ids`). With probability `ucg` the caption
    is replaced by the empty string (text condition dropout).
    """

    def __init__(self, store: "ImageStore", tokenizer: "CLIPTokenizer", ucg: float = 0.1):
        self.store = store
        self.tokenizer = tokenizer
        self.ucg = ucg
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.store)

    def __getitem__(self, item: Tuple[int, int, int]):
        """
        :param item: (index, width, height) as produced by the bucket batch iterator.
        :return: dict with `pixel_values` and (unpadded) `input_ids`.
        """
        return_dict = {'pixel_values': None, 'input_ids': None}

        image_file = self.store.get_image(item)
        # resizing is controlled by the global command line arguments
        if args.resize:
            image_file = ImageOps.fit(
                image_file,
                (item[1], item[2]),
                bleed=0.0,
                centering=(0.5, 0.5),
                method=Image.Resampling.LANCZOS
            )

        return_dict['pixel_values'] = self.transforms(image_file)
        if random.random() > self.ucg:
            caption_file = self.store.get_caption(item)
        else:
            caption_file = ''
        # padding is deferred to collate_fn so that a batch is only padded to its longest caption
        return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
        return return_dict

    def collate_fn(self, examples):
        """
        Stack the examples of one batch; `None` examples are skipped.

        :return: dict with `pixel_values`, padded `input_ids` and `attention_mask`.
        """
        pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
        # Tensor.to()/.float() are not in-place: the result has to be assigned
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = [example['input_ids'] for example in examples if example is not None]
        padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
        return {
            'pixel_values': pixel_values,
            'input_ids': padded_tokens.input_ids,
            'attention_mask': padded_tokens.attention_mask,
        }
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; their current values
                initialise the averaged (shadow) parameters.
            decay: upper bound of the decay used for the moving average.
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the update factor for the exponential moving average.

        The effective decay is warmed up as `(1 + step) / (10 + step)` and is
        capped at `self.decay`. Note that the value returned is ONE MINUS that
        decay, i.e. the fraction of the distance between the averaged and the
        current parameter that is applied in one step.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        """
        Update the averaged parameters from the current `parameters`.

        Parameters that do not require gradients are copied verbatim.
        """
        parameters = list(parameters)

        self.optimization_step += 1
        # Keep the factor in a local variable. Writing it back to `self.decay`
        # would replace the configured decay with a "one minus decay" value,
        # which get_decay() would then use as its cap on the next step.
        one_minus_decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                tmp = one_minus_decay * (s_param - param)
                s_param.sub_(tmp)
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    # From CompVis LitEMA implementation
    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

        # the stored copy is released after a single restore
        del self.collected_params
        gc.collect()

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to
                floating point buffers.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
def main():
    """
    Run the finetuning loop described by the global command line arguments.

    Loads tokenizer, text encoder, VAE and UNet from `args.model` (or from
    `args.resume`), builds the aspect-ratio bucketed dataloader, trains the
    UNet on the noise prediction objective, periodically saves pipeline
    checkpoints and logs sample images, and finally saves a last checkpoint
    and destroys the process group. Rank 0 is the only rank that talks to
    WandB and writes checkpoints.
    """
    rank = get_rank()
    world_size = get_world_size()
    torch.cuda.set_device(rank)

    if rank == 0:
        os.makedirs(args.output_path, exist_ok=True)
        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')

        # Inform the user of host, and various versions -- useful for debugging issues.
        print("RUN_NAME:", args.run_name)
        print("HOST:", socket.gethostname())
        print("CUDA:", torch.version.cuda)
        print("TORCH:", torch.__version__)
        print("TRANSFORMERS:", transformers.__version__)
        print("DIFFUSERS:", diffusers.__version__)
        print("MODEL:", args.model)
        print("FP16:", args.fp16)
        print("RESOLUTION:", args.resolution)

    if args.hf_token is None:
        # The token is optional (only private models need it), so a missing
        # environment variable must not abort the run.
        args.hf_token = os.environ.get('HF_API_TOKEN')
        print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')

    device = torch.device('cuda')

    print("DEVICE:", device)

    # setup fp16 stuff
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

    # Set seed
    # Python's and NumPy's generators are seeded as well: the bucket code shuffles
    # with `random`, and every rank has to generate the same batch schedule so that
    # the per-rank subsampling in the sampler yields disjoint batches.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    print('RANDOM SEED:', args.seed)

    if args.resume:
        args.model = args.resume

    tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
    text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
    vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
    unet = UNet2DConditionModel.from_pretrained(args.model, subfolder='unet', use_auth_token=args.hf_token)

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.use_xformers:
        unet.set_use_memory_efficient_attention_xformers(True)

    if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
        try:
            import bitsandbytes as bnb
            optimizer_cls = bnb.optim.AdamW8bit
        except Exception:
            print('bitsandbytes not supported, using regular Adam optimizer')
            optimizer_cls = torch.optim.AdamW
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.lr,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule='scaled_linear',
        num_train_timesteps=1000,
    )

    # load dataset
    store = ImageStore(args.dataset)
    # forward the configured text condition dropout instead of silently using the default
    dataset = AspectDataset(store, tokenizer, ucg=args.ucg)
    bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
    sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)

    print(f'STORE_LEN: {len(store)}')

    if args.output_bucket_info:
        print(bucket.get_bucket_info())
        raise SystemExit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=0,
        collate_fn=dataset.collate_fn
    )

    weight_dtype = torch.float16 if args.fp16 else torch.float32

    # move models to device
    # the frozen models use the (possibly half precision) weight dtype, the trained unet stays in fp32
    vae = vae.to(device, dtype=weight_dtype)
    unet = unet.to(device, dtype=torch.float32)
    text_encoder = text_encoder.to(device, dtype=weight_dtype)

    # NOTE(review): the DistributedDataParallel wrapper is disabled, so gradients are
    # presumably not synchronised between ranks - confirm before multi-GPU use.
    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)

    # create ema
    if args.use_ema:
        ema_unet = EMAModel(unet.parameters())

    print(get_gpu_ram())

    num_steps_per_epoch = len(train_dataloader)
    progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False)
    global_step = 0

    if args.resume:
        # checkpoints are saved as `<run_name>_<global_step>`, the step is recovered from the path
        target_global_step = int(args.resume.split('_')[-1])
        print(f'resuming from {args.resume}...')

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs),
        num_training_steps=args.epochs * num_steps_per_epoch,
        #last_epoch=(global_step // num_steps_per_epoch) - 1,
    )

    def save_checkpoint(global_step):
        """Save the pipeline (with EMA weights if enabled) on rank 0; all ranks meet at a barrier."""
        if rank == 0:
            if args.use_ema:
                # temporarily swap the EMA weights into the unet for saving
                ema_unet.store(unet.parameters())
                ema_unet.copy_to(unet.parameters())
            pipeline = StableDiffusionPipeline(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=PNDMScheduler(
                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                ),
                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
            )
            print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
            pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')

            if args.use_ema:
                ema_unet.restore(unet.parameters())
        # barrier
        torch.distributed.barrier()

    # train!
    try:
        loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
        for epoch in range(args.epochs):
            unet.train()
            for _, batch in enumerate(train_dataloader):
                # when resuming, fast-forward through the batches that were already trained on
                if args.resume and global_step < target_global_step:
                    if rank == 0:
                        progress_bar.update(1)
                    global_step += 1
                    continue

                b_start = time.perf_counter()
                latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
                if args.clip_penultimate:
                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
                else:
                    encoder_hidden_states = encoder_hidden_states.last_hidden_state

                # Predict the noise residual and compute loss
                with torch.autocast('cuda', enabled=args.fp16):
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Backprop and all reduce
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Update EMA
                if args.use_ema:
                    ema_unet.step(unet.parameters())

                # perf
                b_end = time.perf_counter()
                seconds_per_step = b_end - b_start
                steps_per_second = 1 / seconds_per_step
                rank_images_per_second = args.batch_size * steps_per_second
                world_images_per_second = rank_images_per_second * world_size
                samples_seen = global_step * args.batch_size * world_size

                # All reduce loss
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)

                # The step counter has to advance on EVERY rank: the checkpoint and
                # image logging branches below end in barriers, so all ranks must
                # take them at the same steps.
                global_step += 1

                if rank == 0:
                    progress_bar.update(1)
                    logs = {
                        "train/loss": loss.detach().item() / world_size,
                        "train/lr": lr_scheduler.get_last_lr()[0],
                        "train/epoch": epoch,
                        "train/step": global_step,
                        "train/samples_seen": samples_seen,
                        "perf/rank_samples_per_second": rank_images_per_second,
                        "perf/global_samples_per_second": world_images_per_second,
                    }
                    progress_bar.set_postfix(logs)
                    run.log(logs, step=global_step)

                if global_step % args.save_steps == 0:
                    save_checkpoint(global_step)

                if global_step % args.image_log_steps == 0:
                    if rank == 0:
                        # get prompt from random batch
                        prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())

                        if args.image_log_scheduler == 'DDIMScheduler':
                            print('using DDIMScheduler scheduler')
                            scheduler = DDIMScheduler(
                                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
                            )
                        else:
                            print('using PNDMScheduler scheduler')
                            scheduler = PNDMScheduler(
                                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                            )

                        pipeline = StableDiffusionPipeline(
                            text_encoder=text_encoder,
                            vae=vae,
                            unet=unet,
                            tokenizer=tokenizer,
                            scheduler=scheduler,
                            safety_checker=None, # disable safety checker to save memory
                            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
                        ).to(device)
                        # inference
                        images = []
                        with torch.no_grad():
                            with torch.autocast('cuda', enabled=args.fp16):
                                for _ in range(args.image_log_amount):
                                    images.append(
                                        wandb.Image(pipeline(
                                            prompt, num_inference_steps=args.image_log_inference_steps
                                        ).images[0],
                                        caption=prompt)
                                    )
                        # log images under single caption
                        run.log({'images': images}, step=global_step)

                        # cleanup so we don't run out of memory
                        del pipeline
                        gc.collect()
                    torch.distributed.barrier()
    except Exception as e:
        # best effort: report the failure and still fall through to the final checkpoint below
        print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')

    save_checkpoint(global_step)

    torch.distributed.barrier()
    cleanup()

    print(get_gpu_ram())
    print('Done!')
# Script entry point: initialise the distributed process group, then run the training loop.
if __name__ == "__main__":
    setup()
    main()