Aspect ratio bucketing, Training overhaul, more new stuff
This commit is contained in:
parent
76efa00de3
commit
be702cce2e
|
@ -0,0 +1,636 @@
|
||||||
|
import argparse
|
||||||
|
import socket
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import transformers
|
||||||
|
import diffusers
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
import resource
|
||||||
|
import psutil
|
||||||
|
import pynvml
|
||||||
|
import wandb
|
||||||
|
import gc
|
||||||
|
import itertools
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
pynvml.nvmlInit()
|
||||||
|
except pynvml.nvml.NVMLError_LibraryNotFound:
|
||||||
|
pynvml = None
|
||||||
|
|
||||||
|
from typing import Iterable
|
||||||
|
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from typing import Dict, List, Generator, Tuple
|
||||||
|
from scipy.interpolate import interp1d
|
||||||
|
|
||||||
|
# defaults should be good for everyone
|
||||||
|
# TODO: add custom VAE support. should be simple with diffusers
|
||||||
|
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
|
||||||
|
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
|
||||||
|
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
|
||||||
|
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
|
||||||
|
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
|
||||||
|
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
|
||||||
|
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
|
||||||
|
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
|
||||||
|
parser.add_argument('--use_ema', type=bool, default=False, help='Use EMA for finetuning')
|
||||||
|
parser.add_argument('--ucg', type=float, default=0.1, help='Percentage chance of dropping out the text condition per batch. Ranges from 0.0 to 1.0 where 1.0 means 100% text condition dropout.') # 10% dropout probability
|
||||||
|
parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=bool, default=False, help='Enable gradient checkpointing')
|
||||||
|
parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=bool, default=False, help='Use 8-bit Adam optimizer')
|
||||||
|
parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
|
||||||
|
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
|
||||||
|
parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
|
||||||
|
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
|
||||||
|
parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproduceability purposes.')
|
||||||
|
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
|
||||||
|
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
|
||||||
|
parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
|
||||||
|
parser.add_argument('--shuffle', dest='shuffle', type=bool, default=True, help='Shuffle dataset')
|
||||||
|
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
|
||||||
|
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
|
||||||
|
parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
|
||||||
|
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
|
||||||
|
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
||||||
|
# remove hf_token from args so sneaky people don't steal it from the wandb logs
|
||||||
|
sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
|
||||||
|
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
|
||||||
|
|
||||||
|
# Inform the user of host, and various versions -- useful for debugging isseus.
|
||||||
|
print("RUN_NAME:", args.run_name)
|
||||||
|
print("HOST:", socket.gethostname())
|
||||||
|
print("CUDA:", torch.version.cuda)
|
||||||
|
print("TORCH:", torch.__version__)
|
||||||
|
print("TRANSFORMERS:", transformers.__version__)
|
||||||
|
print("DIFFUSERS:", diffusers.__version__)
|
||||||
|
print("MODEL:", args.model)
|
||||||
|
print("FP16:", args.fp16)
|
||||||
|
print("RESOLUTION:", args.resolution)
|
||||||
|
|
||||||
|
def get_gpu_ram() -> str:
|
||||||
|
"""
|
||||||
|
Returns memory usage statistics for the CPU, GPU, and Torch.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
gpu_str = ""
|
||||||
|
torch_str = ""
|
||||||
|
try:
|
||||||
|
cudadev = torch.cuda.current_device()
|
||||||
|
nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
|
||||||
|
gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
|
||||||
|
gpu_total = int(gpu_info.total / 1E6)
|
||||||
|
gpu_free = int(gpu_info.free / 1E6)
|
||||||
|
gpu_used = int(gpu_info.used / 1E6)
|
||||||
|
gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
|
||||||
|
f"T: {gpu_total:,}mb) "
|
||||||
|
torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
|
||||||
|
torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
|
||||||
|
torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
|
||||||
|
torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
|
||||||
|
torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
|
||||||
|
f"{torch_reserved_max:,}mb, " \
|
||||||
|
f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
|
||||||
|
except AssertionError:
|
||||||
|
pass
|
||||||
|
cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
|
||||||
|
resource.getrusage(
|
||||||
|
resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
|
||||||
|
cpu_vmem = psutil.virtual_memory()
|
||||||
|
cpu_free = int(cpu_vmem.free / 1E6)
|
||||||
|
return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
|
||||||
|
f"{gpu_str}" \
|
||||||
|
f"{torch_str}"
|
||||||
|
|
||||||
|
def _sort_by_ratio(bucket: tuple) -> float:
|
||||||
|
return bucket[0] / bucket[1]
|
||||||
|
|
||||||
|
def _sort_by_area(bucket: tuple) -> float:
|
||||||
|
return bucket[0] * bucket[1]
|
||||||
|
|
||||||
|
class ImageStore:
|
||||||
|
def __init__(self, data_dir: str) -> None:
|
||||||
|
self.data_dir = data_dir
|
||||||
|
|
||||||
|
self.image_files = []
|
||||||
|
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.image_files)
|
||||||
|
|
||||||
|
# iterator returns images as PIL images and their index in the store
|
||||||
|
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
||||||
|
for f in range(len(self)):
|
||||||
|
yield Image.open(self.image_files[f]), f
|
||||||
|
|
||||||
|
# get image by index
|
||||||
|
def get_image(self, index: int) -> Image.Image:
|
||||||
|
return Image.open(self.image_files[index])
|
||||||
|
|
||||||
|
# gets caption by removing the extension from the filename and replacing it with .txt
|
||||||
|
def get_caption(self, index: int) -> str:
|
||||||
|
filename = self.image_files[index].split('.')[0] + '.txt'
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
class AspectBucket:
|
||||||
|
def __init__(self, store: ImageStore,
|
||||||
|
num_buckets: int,
|
||||||
|
batch_size: int,
|
||||||
|
bucket_side_min: int = 256,
|
||||||
|
bucket_side_max: int = 768,
|
||||||
|
bucket_side_increment: int = 64,
|
||||||
|
max_image_area: int = 512 * 768,
|
||||||
|
max_ratio: float = 2):
|
||||||
|
|
||||||
|
self.requested_bucket_count = num_buckets
|
||||||
|
self.bucket_length_min = bucket_side_min
|
||||||
|
self.bucket_length_max = bucket_side_max
|
||||||
|
self.bucket_increment = bucket_side_increment
|
||||||
|
self.max_image_area = max_image_area
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.total_dropped = 0
|
||||||
|
|
||||||
|
if max_ratio <= 0:
|
||||||
|
self.max_ratio = float('inf')
|
||||||
|
else:
|
||||||
|
self.max_ratio = max_ratio
|
||||||
|
|
||||||
|
self.store = store
|
||||||
|
self.buckets = []
|
||||||
|
self._bucket_ratios = []
|
||||||
|
self._bucket_interp = None
|
||||||
|
self.bucket_data: Dict[tuple, List[int]] = dict()
|
||||||
|
self.init_buckets()
|
||||||
|
self.fill_buckets()
|
||||||
|
|
||||||
|
def init_buckets(self):
|
||||||
|
possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
|
||||||
|
possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
|
||||||
|
if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)
|
||||||
|
|
||||||
|
buckets_by_ratio = {}
|
||||||
|
|
||||||
|
# group the buckets by their aspect ratios
|
||||||
|
for bucket in possible_buckets:
|
||||||
|
w, h = bucket
|
||||||
|
# use precision to avoid spooky floats messing up your day
|
||||||
|
ratio = '{:.4e}'.format(w / h)
|
||||||
|
|
||||||
|
if ratio not in buckets_by_ratio:
|
||||||
|
group = set()
|
||||||
|
buckets_by_ratio[ratio] = group
|
||||||
|
else:
|
||||||
|
group = buckets_by_ratio[ratio]
|
||||||
|
|
||||||
|
group.add(bucket)
|
||||||
|
|
||||||
|
# now we take the list of buckets we generated and pick the largest by area for each (the first sorted)
|
||||||
|
# then we put all of those in a list, sorted by the aspect ratio
|
||||||
|
# the square bucket (LxL) will be the first
|
||||||
|
unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
|
||||||
|
for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)
|
||||||
|
|
||||||
|
# how many buckets to create for each side of the distribution
|
||||||
|
bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))
|
||||||
|
|
||||||
|
# we know that the requested_bucket_count must be an odd number, so the indices we calculate
|
||||||
|
# will include the square bucket and some linearly spaced buckets along the distribution
|
||||||
|
indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}
|
||||||
|
|
||||||
|
# make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
|
||||||
|
# here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
|
||||||
|
buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
|
||||||
|
*(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)
|
||||||
|
|
||||||
|
self.buckets = buckets
|
||||||
|
|
||||||
|
# cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
|
||||||
|
# the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
|
||||||
|
# and the output is the bucket index in the self.buckets array
|
||||||
|
# to find the best fit we can just round that number to get the index
|
||||||
|
self._bucket_ratios = [w / h for w, h in buckets]
|
||||||
|
self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
|
||||||
|
fill_value=None)
|
||||||
|
|
||||||
|
for b in buckets:
|
||||||
|
self.bucket_data[b] = []
|
||||||
|
|
||||||
|
def get_batch_count(self):
|
||||||
|
return sum(len(b) // self.batch_size for b in self.bucket_data.values())
|
||||||
|
|
||||||
|
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
||||||
|
"""
|
||||||
|
Generator that provides batches where the images in a batch fall on the same bucket
|
||||||
|
|
||||||
|
Each element generated will be:
|
||||||
|
((w, h), [image1, image2, ..., image{batch_size}])
|
||||||
|
|
||||||
|
where each image is an index into the dataset
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
max_bucket_len = max(len(b) for b in self.bucket_data.values())
|
||||||
|
index_schedule = list(range(max_bucket_len))
|
||||||
|
random.shuffle(index_schedule)
|
||||||
|
|
||||||
|
bucket_len_table = {
|
||||||
|
b: len(self.bucket_data[b]) for b in self.buckets
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket_schedule = []
|
||||||
|
for i, b in enumerate(self.buckets):
|
||||||
|
bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))
|
||||||
|
|
||||||
|
random.shuffle(bucket_schedule)
|
||||||
|
|
||||||
|
bucket_pos = {
|
||||||
|
b: 0 for b in self.buckets
|
||||||
|
}
|
||||||
|
|
||||||
|
total_generated_by_bucket = {
|
||||||
|
b: 0 for b in self.buckets
|
||||||
|
}
|
||||||
|
|
||||||
|
for bucket_index in bucket_schedule:
|
||||||
|
b = self.buckets[bucket_index]
|
||||||
|
i = bucket_pos[b]
|
||||||
|
bucket_len = bucket_len_table[b]
|
||||||
|
|
||||||
|
batch = []
|
||||||
|
while len(batch) != self.batch_size:
|
||||||
|
# advance in the schedule until we find an index that is contained in the bucket
|
||||||
|
k = index_schedule[i]
|
||||||
|
if k < bucket_len:
|
||||||
|
entry = self.bucket_data[b][k]
|
||||||
|
batch.append(entry)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
total_generated_by_bucket[b] += self.batch_size
|
||||||
|
bucket_pos[b] = i
|
||||||
|
yield [idx for idx in batch]
|
||||||
|
|
||||||
|
def fill_buckets(self):
|
||||||
|
entries = self.store.entries_iterator()
|
||||||
|
total_dropped = 0
|
||||||
|
|
||||||
|
for entry, index in tqdm.tqdm(entries, total=len(self.store)):
|
||||||
|
if not self._process_entry(entry, index):
|
||||||
|
total_dropped += 1
|
||||||
|
|
||||||
|
for b, values in self.bucket_data.items():
|
||||||
|
# shuffle the entries for extra randomness and to make sure dropped elements are also random
|
||||||
|
random.shuffle(values)
|
||||||
|
|
||||||
|
# make sure the buckets have an exact number of elements for the batch
|
||||||
|
to_drop = len(values) % self.batch_size
|
||||||
|
self.bucket_data[b] = list(values[:len(values) - to_drop])
|
||||||
|
total_dropped += to_drop
|
||||||
|
|
||||||
|
self.total_dropped = total_dropped
|
||||||
|
|
||||||
|
def _process_entry(self, entry: Image.Image, index: int) -> bool:
|
||||||
|
aspect = entry.width / entry.height
|
||||||
|
|
||||||
|
if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
|
||||||
|
return False
|
||||||
|
|
||||||
|
best_bucket = self._bucket_interp(aspect)
|
||||||
|
|
||||||
|
if best_bucket is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
bucket = self.buckets[round(float(best_bucket))]
|
||||||
|
|
||||||
|
self.bucket_data[bucket].append(index)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
class AspectBucketSampler(torch.utils.data.Sampler):
|
||||||
|
def __init__(self, bucket: AspectBucket):
|
||||||
|
super().__init__(None)
|
||||||
|
self.bucket = bucket
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
yield from self.bucket.get_batch_iterator()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.bucket.get_batch_count()
|
||||||
|
|
||||||
|
class AspectDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
|
||||||
|
self.store = store
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.ucg = ucg
|
||||||
|
|
||||||
|
self.transforms = torchvision.transforms.Compose([
|
||||||
|
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
||||||
|
torchvision.transforms.ToTensor(),
|
||||||
|
torchvision.transforms.Normalize([0.5], [0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.store)
|
||||||
|
|
||||||
|
def __getitem__(self, item: int):
|
||||||
|
return_dict = {'pixel_values': None, 'input_ids': None}
|
||||||
|
|
||||||
|
image_file = self.store.get_image(item)
|
||||||
|
return_dict['pixel_values'] = self.transforms(image_file)
|
||||||
|
if random.random() > self.ucg:
|
||||||
|
caption_file = self.store.get_caption(item)
|
||||||
|
else:
|
||||||
|
caption_file = ''
|
||||||
|
return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
|
||||||
|
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
def collate_fn(self, examples):
|
||||||
|
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
|
||||||
|
pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||||
|
input_ids = [example['input_ids'] for example in examples if example is not None]
|
||||||
|
padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
|
||||||
|
return {
|
||||||
|
'pixel_values': pixel_values,
|
||||||
|
'input_ids': padded_tokens.input_ids,
|
||||||
|
'attention_mask': padded_tokens.attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||||
|
class EMAModel:
|
||||||
|
"""
|
||||||
|
Exponential Moving Average of models weights
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
|
||||||
|
parameters = list(parameters)
|
||||||
|
self.shadow_params = [p.clone().detach() for p in parameters]
|
||||||
|
|
||||||
|
self.decay = decay
|
||||||
|
self.optimization_step = 0
|
||||||
|
|
||||||
|
def get_decay(self, optimization_step):
|
||||||
|
"""
|
||||||
|
Compute the decay factor for the exponential moving average.
|
||||||
|
"""
|
||||||
|
value = (1 + optimization_step) / (10 + optimization_step)
|
||||||
|
return 1 - min(self.decay, value)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, parameters):
|
||||||
|
parameters = list(parameters)
|
||||||
|
|
||||||
|
self.optimization_step += 1
|
||||||
|
self.decay = self.get_decay(self.optimization_step)
|
||||||
|
|
||||||
|
for s_param, param in zip(self.shadow_params, parameters):
|
||||||
|
if param.requires_grad:
|
||||||
|
tmp = self.decay * (s_param - param)
|
||||||
|
s_param.sub_(tmp)
|
||||||
|
else:
|
||||||
|
s_param.copy_(param)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
||||||
|
"""
|
||||||
|
Copy current averaged parameters into given collection of parameters.
|
||||||
|
Args:
|
||||||
|
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||||
|
updated with the stored moving averages. If `None`, the
|
||||||
|
parameters with which this `ExponentialMovingAverage` was
|
||||||
|
initialized will be used.
|
||||||
|
"""
|
||||||
|
parameters = list(parameters)
|
||||||
|
for s_param, param in zip(self.shadow_params, parameters):
|
||||||
|
param.data.copy_(s_param.data)
|
||||||
|
|
||||||
|
def to(self, device=None, dtype=None) -> None:
|
||||||
|
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
||||||
|
Args:
|
||||||
|
device: like `device` argument to `torch.Tensor.to`
|
||||||
|
"""
|
||||||
|
# .to() on the tensors handles None correctly
|
||||||
|
self.shadow_params = [
|
||||||
|
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
|
||||||
|
for p in self.shadow_params
|
||||||
|
]
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# get device. TODO: support multi-gpu
|
||||||
|
device = 'cpu'
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = 'cuda'
|
||||||
|
|
||||||
|
print("DEVICE:", device)
|
||||||
|
|
||||||
|
# setup fp16 stuff
|
||||||
|
scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
|
||||||
|
|
||||||
|
# Set seed
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
print('RANDOM SEED:', args.seed)
|
||||||
|
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
|
||||||
|
text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
|
||||||
|
vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
|
||||||
|
unet = UNet2DConditionModel.from_pretrained(args.model, subfolder='unet', use_auth_token=args.hf_token)
|
||||||
|
|
||||||
|
# Freeze vae and text_encoder
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
text_encoder.requires_grad_(False)
|
||||||
|
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
optimizer_cls = bnb.optim.AdamW8bit
|
||||||
|
except:
|
||||||
|
print('bitsandbytes not supported, using regular Adam optimizer')
|
||||||
|
optimizer_cls = torch.optim.AdamW
|
||||||
|
else:
|
||||||
|
optimizer_cls = torch.optim.AdamW
|
||||||
|
|
||||||
|
optimizer = optimizer_cls(
|
||||||
|
unet.parameters(),
|
||||||
|
lr=args.lr,
|
||||||
|
betas=(args.adam_beta1, args.adam_beta2),
|
||||||
|
eps=args.adam_epsilon,
|
||||||
|
weight_decay=args.adam_weight_decay,
|
||||||
|
)
|
||||||
|
|
||||||
|
noise_scheduler = DDPMScheduler(
|
||||||
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
beta_schedule='scaled_linear',
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
tensor_format='pt'
|
||||||
|
)
|
||||||
|
|
||||||
|
# load dataset
|
||||||
|
|
||||||
|
|
||||||
|
store = ImageStore(args.dataset)
|
||||||
|
dataset = AspectDataset(store, tokenizer)
|
||||||
|
bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||||
|
sampler = AspectBucketSampler(bucket)
|
||||||
|
|
||||||
|
print(f'STORE_LEN: {len(store)}')
|
||||||
|
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=sampler,
|
||||||
|
num_workers=4,
|
||||||
|
collate_fn=dataset.collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
'constant',
|
||||||
|
optimizer=optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||||
|
|
||||||
|
# move models to device
|
||||||
|
vae = vae.to(device, dtype=weight_dtype)
|
||||||
|
unet = unet.to(device, dtype=torch.float32)
|
||||||
|
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# create ema
|
||||||
|
if args.use_ema:
|
||||||
|
ema_unet = EMAModel(unet.parameters())
|
||||||
|
|
||||||
|
print(get_gpu_ram())
|
||||||
|
|
||||||
|
num_steps_per_epoch = len(train_dataloader)
|
||||||
|
progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False)
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
def save_checkpoint():
|
||||||
|
if args.use_ema:
|
||||||
|
ema_unet.copy_to(unet.parameters())
|
||||||
|
pipeline = StableDiffusionPipeline(
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
vae=vae,
|
||||||
|
unet=unet,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=PNDMScheduler(
|
||||||
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||||
|
),
|
||||||
|
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||||
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
|
)
|
||||||
|
pipeline.save_pretrained(args.output_path)
|
||||||
|
|
||||||
|
# train!
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
unet.train()
|
||||||
|
train_loss = 0.0
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
# Convert images to latent space
|
||||||
|
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
|
||||||
|
latents = latents * 0.18215
|
||||||
|
|
||||||
|
# Sample noise
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
bsz = latents.shape[0]
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
|
# (this is the forward diffusion process)
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
|
||||||
|
|
||||||
|
# Predict the noise residual and compute loss
|
||||||
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||||
|
|
||||||
|
# Backprop
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Update EMA
|
||||||
|
if args.use_ema:
|
||||||
|
ema_unet.step(unet.parameters())
|
||||||
|
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
logs = {
|
||||||
|
"loss": loss.detach().item(),
|
||||||
|
"lr": lr_scheduler.get_last_lr()[0],
|
||||||
|
"epoch": epoch
|
||||||
|
}
|
||||||
|
progress_bar.set_postfix(logs)
|
||||||
|
run.log(logs)
|
||||||
|
|
||||||
|
if global_step % args.save_steps == 0:
|
||||||
|
save_checkpoint()
|
||||||
|
|
||||||
|
if global_step % args.image_log_steps == 0:
|
||||||
|
# get prompt from random batch
|
||||||
|
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
||||||
|
pipeline = StableDiffusionPipeline(
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
vae=vae,
|
||||||
|
unet=unet,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=PNDMScheduler(
|
||||||
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||||
|
),
|
||||||
|
safety_checker=None, # display safety checker to save memory
|
||||||
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
|
).to(device)
|
||||||
|
# inference
|
||||||
|
images = []
|
||||||
|
with torch.no_grad():
|
||||||
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
|
for _ in range(args.image_log_amount):
|
||||||
|
images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
|
||||||
|
# log images under single caption
|
||||||
|
run.log({'images': images})
|
||||||
|
|
||||||
|
# cleanup so we don't run out of memory
|
||||||
|
del pipeline
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
save_checkpoint()
|
||||||
|
|
||||||
|
print(get_gpu_ram())
|
||||||
|
print('Done!')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
# save a sample
|
||||||
|
img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
|
||||||
|
img = ((img + 1.0) * 127.5).astype(np.uint8)
|
||||||
|
img = Image.fromarray(img)
|
||||||
|
img.save('sample.png')
|
||||||
|
break
|
||||||
|
"""
|
Loading…
Reference in New Issue