parent
c228331068
commit
9e1102990a
|
@ -304,9 +304,10 @@ class DreamBoothDataset(Dataset):
|
|||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
example["instance_prompt_ids"] = self.tokenizer(
|
||||
self.instance_prompt,
|
||||
padding="do_not_pad",
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
if self.class_data_root:
|
||||
|
@ -316,14 +317,37 @@ class DreamBoothDataset(Dataset):
|
|||
example["class_images"] = self.image_transforms(class_image)
|
||||
example["class_prompt_ids"] = self.tokenizer(
|
||||
self.class_prompt,
|
||||
padding="do_not_pad",
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
if with_prior_preservation:
|
||||
input_ids += [example["class_prompt_ids"] for example in examples]
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = torch.cat(input_ids, dim=0)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
||||
|
||||
|
@ -514,34 +538,12 @@ def main(args):
|
|||
center_crop=args.center_crop,
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
if args.with_prior_preservation:
|
||||
input_ids += [example["class_prompt_ids"] for example in examples]
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad(
|
||||
{"input_ids": input_ids},
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return batch
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
|
|
Loading…
Reference in New Issue