[dreambooth] make collate_fn global (#1547)

make collate_fn global
Suraj Patil 2022-12-06 14:41:53 +01:00 committed by GitHub
parent c228331068
commit 9e1102990a
1 changed file with 31 additions and 29 deletions


@@ -304,9 +304,10 @@ class DreamBoothDataset(Dataset):
         example["instance_images"] = self.image_transforms(instance_image)
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
-            padding="do_not_pad",
             truncation=True,
+            padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
         ).input_ids
 
         if self.class_data_root:
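
Aside (not part of the commit): with the new padding="max_length" and return_tensors="pt" arguments, every prompt comes back as a fixed-size (1, model_max_length) tensor rather than an unpadded id list, which is what lets the module-level collate function introduced below concatenate the ids directly instead of calling tokenizer.pad. A minimal sketch, assuming the CLIP tokenizer that the DreamBooth script loads; the prompt string is only an example:

# Illustration only, not from the commit. Assumes the CLIP tokenizer used by
# Stable Diffusion; the prompt is a made-up example.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

input_ids = tokenizer(
    "a photo of sks dog",                   # hypothetical instance prompt
    truncation=True,
    padding="max_length",                   # pad every prompt to the same length
    max_length=tokenizer.model_max_length,
    return_tensors="pt",                    # return a (1, model_max_length) LongTensor
).input_ids

print(input_ids.shape)  # torch.Size([1, 77]) for CLIP's 77-token context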
@@ -316,14 +317,37 @@ class DreamBoothDataset(Dataset):
             example["class_images"] = self.image_transforms(class_image)
             example["class_prompt_ids"] = self.tokenizer(
                 self.class_prompt,
-                padding="do_not_pad",
                 truncation=True,
+                padding="max_length",
                 max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             ).input_ids
 
         return example
 
 
+def collate_fn(examples, with_prior_preservation=False):
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # Concat class and instance examples for prior preservation.
+    # We do this to avoid doing two forward passes.
+    if with_prior_preservation:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    input_ids = torch.cat(input_ids, dim=0)
+
+    batch = {
+        "input_ids": input_ids,
+        "pixel_values": pixel_values,
+    }
+    return batch
+
+
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
@@ -514,34 +538,12 @@ def main(args):
         center_crop=args.center_crop,
     )
 
-    def collate_fn(examples):
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
-        input_ids = tokenizer.pad(
-            {"input_ids": input_ids},
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids
-
-        batch = {
-            "input_ids": input_ids,
-            "pixel_values": pixel_values,
-        }
-        return batch
-
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+        train_dataset,
+        batch_size=args.train_batch_size,
+        shuffle=True,
+        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+        num_workers=1,
     )
 
     # Scheduler and math around the number of training steps.
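
Aside (not part of the commit): a rough, self-contained sketch of how the now-global collate_fn behaves and how it can be wired into a DataLoader. The function body is copied from the hunk above; the dummy tensors stand in for real DreamBoothDataset examples, and the 77 / 512 shapes assume CLIP's text context length and 512x512 training images. The functools.partial call at the end is only an equivalent alternative to the lambda used in the script.

# Illustration only, not from the commit.
import functools

import torch
from torch.utils.data import DataLoader


def collate_fn(examples, with_prior_preservation=False):
    # Same body as the module-level collate_fn added above.
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch


# Dummy examples shaped like DreamBoothDataset.__getitem__ output:
# prompt ids are already padded to (1, 77), images are 3x512x512 tensors.
examples = [
    {
        "instance_prompt_ids": torch.ones(1, 77, dtype=torch.long),
        "instance_images": torch.randn(3, 512, 512),
        "class_prompt_ids": torch.ones(1, 77, dtype=torch.long),
        "class_images": torch.randn(3, 512, 512),
    }
    for _ in range(2)
]

batch = collate_fn(examples)
print(batch["input_ids"].shape, batch["pixel_values"].shape)
# torch.Size([2, 77]) torch.Size([2, 3, 512, 512])

batch = collate_fn(examples, with_prior_preservation=True)
print(batch["input_ids"].shape, batch["pixel_values"].shape)
# torch.Size([4, 77]) torch.Size([4, 3, 512, 512])

# functools.partial binds the prior-preservation flag just like the lambda in
# the script does; a plain list of dicts works as a map-style dataset here.
loader = DataLoader(
    examples,
    batch_size=2,
    shuffle=True,
    collate_fn=functools.partial(collate_fn, with_prior_preservation=True),
)
batch = next(iter(loader))
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])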