Merge pull request #6620 from guaneec/varsize_batch

Enable batch_size>1 for mixed-sized training
This commit is contained in:
AUTOMATIC1111 2023-01-13 14:03:31 +03:00 committed by GitHub
commit 486bda9b33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 32 additions and 4 deletions

View File

@ -3,8 +3,10 @@ import numpy as np
import PIL import PIL
import torch import torch
from PIL import Image from PIL import Image
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices
import random import random
import tqdm import tqdm
@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified' assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out self.tag_drop_out = tag_drop_out
groups = defaultdict(list)
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast(): with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
groups[image.size].append(len(self.dataset))
self.dataset.append(entry) self.dataset.append(entry)
del torchdata del torchdata
del latent_dist del latent_dist
del latent_sample del latent_sample
self.length = len(self.dataset) self.length = len(self.dataset)
self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset." assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length) self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size) self.gradient_step = min(gradient_step, self.length // self.batch_size)
@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry return entry
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
self.base = [int(e) // batch_size for e in expected]
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
def __len__(self):
return self.len
def __iter__(self):
b = self.batch_size
for g in self.groups:
shuffle(g)
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))
shuffle(batches)
yield from batches
class PersonalizedDataLoader(DataLoader): class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random": if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random self.collate_fn = collate_wrapper_random
else: else: