Fix padding in dreambooth (#1030)

This commit is contained in:
Yuta Hayashibe 2022-11-03 00:37:05 +09:00 committed by GitHub
parent 5cd29d623a
commit 33c487455e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -494,7 +494,12 @@ def main(args):
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids
batch = {
"input_ids": input_ids,