"""Convert an original BDDM checkpoint into a diffusers ``BDDMPipeline``.

Script location per the originating diff: ``scripts/conversion_bddm.py``.
"""

import argparse
import torch

from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler


def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    """Build a ``BDDMPipeline`` from two checkpoint files and save it.

    Args:
        checkpoint_path: File passed to ``torch.load``; the loaded object is
            indexed with ``"model_state_dict"`` to obtain the model weights.
        noise_scheduler_checkpoint_path: File passed to ``torch.load``; the
            loaded object is unpacked as exactly four items, of which the
            first (timestep values) and third (betas) are used.
            NOTE(review): both are presumably torch tensors since ``.numpy()``
            is called on them -- confirm against the checkpoint producer.
        output_path: Destination handed to ``pipeline.save_pretrained``.
    """
    # Both checkpoints are mapped onto the CPU while loading.
    model_checkpoint = torch.load(checkpoint_path, map_location="cpu")
    weights = model_checkpoint["model_state_dict"]
    schedule_checkpoint = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    diffwave = DiffWave()
    # strict=False: key mismatches between checkpoint and model do not raise.
    diffwave.load_state_dict(weights, strict=False)

    # The second and fourth entries of the schedule checkpoint are unused.
    step_tensor, _, beta_tensor, _ = schedule_checkpoint
    timestep_values = list(step_tensor.numpy().tolist())
    trained_betas = list(beta_tensor.numpy().tolist())

    scheduler_kwargs = {
        "timesteps": 12,
        "trained_betas": trained_betas,
        "timestep_values": timestep_values,
        "clip_sample": False,
        "tensor_format": "np",
    }
    scheduler = DDPMScheduler(**scheduler_kwargs)

    BDDMPipeline(diffwave, scheduler).save_pretrained(output_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # All three options are required string arguments.
    for option in (
        "--checkpoint_path",
        "--noise_scheduler_checkpoint_path",
        "--output_path",
    ):
        cli.add_argument(option, type=str, required=True)
    parsed = cli.parse_args()

    convert_bddm_orginal(
        parsed.checkpoint_path,
        parsed.noise_scheduler_checkpoint_path,
        parsed.output_path,
    )