add conversion script for BDDMPipeline

This commit is contained in:
patil-suraj 2022-07-01 17:44:38 +02:00
parent a1b5ef5ddc
commit ab946575b1
1 changed files with 40 additions and 0 deletions

View File

@ -0,0 +1,40 @@
import argparse
import torch
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")
model = DiffWave()
model.load_state_dict(sd, strict=False)
ts, _, betas, _ = noise_scheduler_sd
ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())
noise_scheduler = DDPMScheduler(
timesteps=12,
trained_betas=betas,
timestep_values=ts,
clip_sample=False,
tensor_format="np",
)
pipeline = BDDMPipeline(model, noise_scheduler)
pipeline.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()
convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)