diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 407303a..b130561 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -718,6 +718,7 @@ def main(): # Set seed torch.manual_seed(args.seed) random.seed(args.seed) + np.random.seed(args.seed) print('RANDOM SEED:', args.seed) if args.resume: