add conversion script for BDDMPipeline
This commit is contained in:
parent
a1b5ef5ddc
commit
ab946575b1
|
@ -0,0 +1,40 @@
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
|
||||||
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
|
||||||
|
sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
|
||||||
|
noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")
|
||||||
|
|
||||||
|
model = DiffWave()
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
ts, _, betas, _ = noise_scheduler_sd
|
||||||
|
ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())
|
||||||
|
|
||||||
|
noise_scheduler = DDPMScheduler(
|
||||||
|
timesteps=12,
|
||||||
|
trained_betas=betas,
|
||||||
|
timestep_values=ts,
|
||||||
|
clip_sample=False,
|
||||||
|
tensor_format="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = BDDMPipeline(model, noise_scheduler)
|
||||||
|
pipeline.save_pretrained(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--checkpoint_path", type=str, required=True)
|
||||||
|
parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
|
||||||
|
parser.add_argument("--output_path", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
convert_bddm_orginal(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue