Fix padding in dreambooth (#1030)
This commit is contained in:
parent
5cd29d623a
commit
33c487455e
|
@ -494,7 +494,12 @@ def main(args):
|
||||||
pixel_values = torch.stack(pixel_values)
|
pixel_values = torch.stack(pixel_values)
|
||||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||||
|
|
||||||
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
input_ids = tokenizer.pad(
|
||||||
|
{"input_ids": input_ids},
|
||||||
|
padding="max_length",
|
||||||
|
max_length=tokenizer.model_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
).input_ids
|
||||||
|
|
||||||
batch = {
|
batch = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
|
Loading…
Reference in New Issue