early work on pinned image tensor
This commit is contained in:
parent
6c8d15daab
commit
a47d65799f
|
@ -138,7 +138,59 @@ class EveryDreamBatch(Dataset):
|
|||
def __update_image_train_items(self, dropout_fraction: float):
    """
    Replace ``self.image_train_items`` with a freshly shuffled set of image buckets.

    ``dropout_fraction`` is forwarded unchanged to
    ``self.data_loader.get_shuffled_image_buckets``.
    """
    shuffled_buckets = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
    self.image_train_items = shuffled_buckets
|
||||
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
class DataLoaderWithFixedBuffer(torch.utils.data.DataLoader):
    """
    DataLoader that collates every image batch into one reusable, pre-allocated buffer.

    The buffer is a flat host tensor holding ``batch_size * 3 * max_pixels`` elements.
    Each batch is copied into the front of that buffer and returned as a view shaped
    like the actual images, so no new image tensor is allocated per batch.

    Because the same storage is reused, the ``"image"`` tensor of a batch is only
    valid until the next batch is requested; clone it if it must outlive that.
    """

    def __init__(self, dataset, buffer_tensor, batch_size:int, max_pixels: int, buffer_dtype: torch.dtype, device="cuda"):
        """
        :param dataset: map-style dataset whose items are dicts with the keys
            ``"image"``, ``"caption"``, ``"tokens"`` and ``"runt_size"``
        :param buffer_tensor: optional existing contiguous tensor to use as the buffer;
            pass ``None`` to have one allocated here
        :param batch_size: number of items per batch
        :param max_pixels: largest ``height * width`` any image in the dataset may have
        :param buffer_dtype: dtype of the buffer when it is allocated here
        :param device: device the batches are destined for; the buffer is only pinned
            when this is a CUDA device and CUDA is available
        :raises ValueError: if a supplied ``buffer_tensor`` is too small or not contiguous
        """
        color_channels = 3
        buffer_size = batch_size * color_channels * max_pixels
        self.buffer_size = buffer_size

        if buffer_tensor is None:
            # Pinned (page-locked) memory is host memory, so the buffer must live on
            # the CPU; calling pin_memory() on a CUDA tensor raises. Only pin when a
            # CUDA device is actually usable, otherwise fall back to pageable memory.
            should_pin = torch.device(device).type == "cuda" and torch.cuda.is_available()
            buffer_tensor = torch.empty(buffer_size, dtype=buffer_dtype, pin_memory=should_pin)
        elif buffer_tensor.numel() < buffer_size or not buffer_tensor.is_contiguous():
            raise ValueError(
                f"buffer_tensor must be contiguous and hold at least {buffer_size} elements, "
                f"got {buffer_tensor.numel()}"
            )
        # keep a flat view so batches of any resolution can be carved from the front
        self.buffer_tensor = buffer_tensor.view(-1)
        logging.info(f"buffer_tensor created with shape: {buffer_tensor.shape}")

        # num_workers must stay 0: collate_fn runs inside worker processes when workers
        # are used, and each worker would fill its own copy of the buffer (and prefetched
        # batches would overwrite one another) instead of this process's buffer.
        super().__init__(dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=self.fixed_collate_fn)

    def fixed_collate_fn(self, batch):
        """
        Collates images to the fixed buffer, returned as a view using actual resolution shape

        :param batch: list of dataset items (dicts), all with images of the same shape
        :return: dict with ``"tokens"`` (stacked tensor), ``"image"`` (float tensor of
            shape ``(len(batch), *image.shape)``), ``"captions"`` (list) and
            ``"runt_size"`` (taken from the first item)
        :raises ValueError: if the images differ in shape or do not fit in the buffer
        """
        images = [example["image"] for example in batch]
        captions = [example["caption"] for example in batch]
        tokens = [example["tokens"] for example in batch]
        runt_size = batch[0]["runt_size"]

        image_shape = images[0].shape
        if any(image.shape != image_shape for image in images):
            raise ValueError("all images in a batch must share the same shape")

        # use the real batch length: the final batch of an epoch may be short
        image_count = len(images)
        needed = image_count * images[0].numel()
        if needed > self.buffer_tensor.numel():
            raise ValueError(
                f"batch needs {needed} elements but the fixed buffer holds {self.buffer_tensor.numel()}"
            )

        # map the image data to the fixed buffer view; images are packed back to back
        # so the view matches the actual resolution rather than max_pixels
        buffer_view = self.buffer_tensor[:needed].view(image_count, *image_shape)
        for slot, image in zip(buffer_view, images):
            slot.copy_(image)

        ret = {
            "tokens": torch.stack(tuple(tokens)),
            # .float() returns the same tensor for a float32 buffer, so the result is
            # still a view of the buffer; other buffer dtypes are converted (copied)
            "image": buffer_view.float(),
            "captions": captions,
            "runt_size": runt_size,
        }
        return ret
|
||||
|
||||
def build_torch_dataloader2(dataset, batch_size, max_pixels) -> torch.utils.data.DataLoader:
    """
    Build a DataLoaderWithFixedBuffer over ``dataset``.

    :param dataset: dataset to draw batches from
    :param batch_size: number of items per batch
    :param max_pixels: largest ``height * width`` of any image, used to size the buffer
    :return: the configured dataloader
    """
    # DataLoaderWithFixedBuffer sets shuffle, num_workers and collate_fn itself and
    # does not accept them as arguments; it does require buffer_tensor and buffer_dtype.
    dataloader = DataLoaderWithFixedBuffer(
        dataset,
        buffer_tensor=None,  # let the loader allocate its own buffer
        batch_size=batch_size,
        max_pixels=max_pixels,
        buffer_dtype=torch.float32,  # collated images are converted with .float()
    )
    return dataloader
|
||||
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size= batch_size,
|
||||
|
@ -148,7 +200,6 @@ def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
|||
)
|
||||
return dataloader
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collates batches
|
||||
|
|
Loading…
Reference in New Issue