fix a bug in the new version (#957)
remove tensor_format in the new version
This commit is contained in:
parent
d9cfe325a5
commit
d7d6841406
|
@ -375,7 +375,7 @@ def main():
|
||||||
|
|
||||||
# TODO (patil-suraj): load scheduler using args
|
# TODO (patil-suraj): load scheduler using args
|
||||||
noise_scheduler = DDPMScheduler(
|
noise_scheduler = DDPMScheduler(
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||||
|
|
Loading…
Reference in New Issue