parent
c228331068
commit
9e1102990a
|
@ -304,9 +304,10 @@ class DreamBoothDataset(Dataset):
|
||||||
example["instance_images"] = self.image_transforms(instance_image)
|
example["instance_images"] = self.image_transforms(instance_image)
|
||||||
example["instance_prompt_ids"] = self.tokenizer(
|
example["instance_prompt_ids"] = self.tokenizer(
|
||||||
self.instance_prompt,
|
self.instance_prompt,
|
||||||
padding="do_not_pad",
|
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
padding="max_length",
|
||||||
max_length=self.tokenizer.model_max_length,
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
).input_ids
|
).input_ids
|
||||||
|
|
||||||
if self.class_data_root:
|
if self.class_data_root:
|
||||||
|
@ -316,14 +317,37 @@ class DreamBoothDataset(Dataset):
|
||||||
example["class_images"] = self.image_transforms(class_image)
|
example["class_images"] = self.image_transforms(class_image)
|
||||||
example["class_prompt_ids"] = self.tokenizer(
|
example["class_prompt_ids"] = self.tokenizer(
|
||||||
self.class_prompt,
|
self.class_prompt,
|
||||||
padding="do_not_pad",
|
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
padding="max_length",
|
||||||
max_length=self.tokenizer.model_max_length,
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
).input_ids
|
).input_ids
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(examples, with_prior_preservation=False):
|
||||||
|
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||||
|
pixel_values = [example["instance_images"] for example in examples]
|
||||||
|
|
||||||
|
# Concat class and instance examples for prior preservation.
|
||||||
|
# We do this to avoid doing two forward passes.
|
||||||
|
if with_prior_preservation:
|
||||||
|
input_ids += [example["class_prompt_ids"] for example in examples]
|
||||||
|
pixel_values += [example["class_images"] for example in examples]
|
||||||
|
|
||||||
|
pixel_values = torch.stack(pixel_values)
|
||||||
|
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||||
|
|
||||||
|
input_ids = torch.cat(input_ids, dim=0)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
class PromptDataset(Dataset):
|
class PromptDataset(Dataset):
|
||||||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
||||||
|
|
||||||
|
@ -514,34 +538,12 @@ def main(args):
|
||||||
center_crop=args.center_crop,
|
center_crop=args.center_crop,
|
||||||
)
|
)
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
|
||||||
pixel_values = [example["instance_images"] for example in examples]
|
|
||||||
|
|
||||||
# Concat class and instance examples for prior preservation.
|
|
||||||
# We do this to avoid doing two forward passes.
|
|
||||||
if args.with_prior_preservation:
|
|
||||||
input_ids += [example["class_prompt_ids"] for example in examples]
|
|
||||||
pixel_values += [example["class_images"] for example in examples]
|
|
||||||
|
|
||||||
pixel_values = torch.stack(pixel_values)
|
|
||||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
|
||||||
|
|
||||||
input_ids = tokenizer.pad(
|
|
||||||
{"input_ids": input_ids},
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
return_tensors="pt",
|
|
||||||
).input_ids
|
|
||||||
|
|
||||||
batch = {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"pixel_values": pixel_values,
|
|
||||||
}
|
|
||||||
return batch
|
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
|
train_dataset,
|
||||||
|
batch_size=args.train_batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
|
||||||
|
num_workers=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Scheduler and math around the number of training steps.
|
# Scheduler and math around the number of training steps.
|
||||||
|
|
Loading…
Reference in New Issue