fix a bug in the new version (#957)

remove tensor_format in the new version
This commit is contained in:
Hu Ye 2022-10-26 20:26:17 +08:00 committed by GitHub
parent d9cfe325a5
commit d7d6841406
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -375,7 +375,7 @@ def main():
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
# Get the datasets: you can either provide your own training and evaluation files (see below)