early work on pinned image tensor

This commit is contained in:
Victor Hall 2023-09-21 13:44:36 -04:00
parent 6c8d15daab
commit a47d65799f
1 changed file with 53 additions and 2 deletions

View File

@@ -138,7 +138,59 @@ class EveryDreamBatch(Dataset):
def __update_image_train_items(self, dropout_fraction: float):
    """
    Refreshes self.image_train_items from the data loader's shuffled image buckets

    :param dropout_fraction: forwarded unchanged to get_shuffled_image_buckets
    """
    shuffled_buckets = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
    self.image_train_items = shuffled_buckets
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
class DataLoaderWithFixedBuffer(torch.utils.data.DataLoader):
    """
    DataLoader that packs every batch of images into one preallocated staging buffer

    The buffer holds batch_size * 3 * max_pixels elements and is allocated in pinned
    (page-locked) host memory when batches are headed for a CUDA device and CUDA is
    available. The "image" entry of each batch is a view of that buffer, so a batch
    is only valid until the next one is requested; consume or copy it first.

    Worker processes only load the samples. Packing happens in the process that
    iterates the loader, because a collate function run inside a worker would fill
    that worker's own copy of the buffer instead of this one.
    """

    def __init__(self, dataset, buffer_tensor, batch_size: int, max_pixels: int, buffer_dtype: torch.dtype, device="cuda"):
        """
        :param dataset: dataset whose items are dicts with "image", "caption", "tokens" and "runt_size"
        :param buffer_tensor: accepted for backward compatibility and ignored; the buffer is always allocated here
        :param batch_size: number of samples per batch, also used to size the buffer
        :param max_pixels: largest per-image pixel count (height * width) the buffer must hold
        :param buffer_dtype: dtype of the staging buffer and therefore of the returned image batch
        :param device: device the batches are destined for; only decides whether the buffer is pinned
        """
        color_channels = 3
        buffer_size = batch_size * color_channels * max_pixels
        self.buffer_size = buffer_size

        # Only CPU tensors can be pinned, so the buffer always lives in host memory.
        # Pinning needs a CUDA runtime, hence the availability check.
        wants_pinned = torch.device(device).type == "cuda" and torch.cuda.is_available()
        staging = torch.empty(buffer_size, dtype=buffer_dtype, pin_memory=wants_pinned)
        self.buffer_tensor = staging
        logging.info(f"buffer_tensor created with shape: {staging.shape}")

        # os.cpu_count() may return None when the count cannot be determined
        worker_count = min(batch_size, os.cpu_count() or 1)
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=worker_count,
            collate_fn=self._passthrough_collate,
        )

    @staticmethod
    def _passthrough_collate(batch):
        """
        Hands the list of samples back untouched so packing can happen in the iterating process
        """
        return batch

    def __iter__(self):
        """
        Yields collated batches whose "image" entry is a view of the fixed buffer
        """
        for samples in super().__iter__():
            yield self.fixed_collate_fn(samples)

    def fixed_collate_fn(self, batch):
        """
        Collates images to the fixed buffer, returned as a view using the actual resolution shape

        :param batch: non-empty list of sample dicts; every "image" is expected to be a tensor of the same shape
        :return: dict with "tokens" (stacked tensor), "image" (view of the buffer, in the buffer's dtype),
                 "captions" (list) and "runt_size" (taken from the first sample)
        :raises ValueError: if the batch does not fit in the buffer
        """
        images = [example["image"] for example in batch]
        captions = [example["caption"] for example in batch]
        tokens = [example["tokens"] for example in batch]
        runt_size = batch[0]["runt_size"]

        # The final batch may be shorter than batch_size, so size everything from the batch itself
        image_count = len(images)
        image_shape = tuple(images[0].shape)
        elements_needed = image_count * images[0].numel()
        if elements_needed > self.buffer_tensor.numel():
            raise ValueError(
                f"batch of {image_count} images with shape {image_shape} needs {elements_needed} elements, "
                f"but the fixed buffer only holds {self.buffer_tensor.numel()}"
            )

        # Use only the leading part of the buffer so the view matches the real resolution
        packed = self.buffer_tensor[:elements_needed].view(image_count, *image_shape)
        for index, image in enumerate(images):
            # copy_ writes into the buffer in place and converts dtype if required
            packed[index].copy_(image)

        return {
            "tokens": torch.stack(tuple(tokens)),
            "image": packed,
            "captions": captions,
            "runt_size": runt_size,
        }
def build_torch_dataloader2(dataset, batch_size, max_pixels) -> torch.utils.data.DataLoader:
    """
    Builds a DataLoaderWithFixedBuffer for the dataset

    :param dataset: dataset to load batches from
    :param batch_size: number of samples per batch
    :param max_pixels: largest per-image pixel count the fixed buffer must hold
    :return: the configured fixed-buffer dataloader
    """
    # shuffle, num_workers and collate_fn are set by DataLoaderWithFixedBuffer itself
    # and are not accepted by its constructor, so they are not passed here.
    dataloader = DataLoaderWithFixedBuffer(
        dataset,
        buffer_tensor=None,  # required by the constructor signature; the loader allocates its own buffer
        batch_size=batch_size,
        max_pixels=max_pixels,
        buffer_dtype=torch.float32,  # the collate step produces float32 image batches
    )
    return dataloader
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size= batch_size,
@@ -148,7 +200,6 @@ def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
)
return dataloader
def collate_fn(batch):
"""
Collates batches