diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 2d8f6946..c32d77e7 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -7,6 +7,9 @@ from torch import nn from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin +from diffusers import DiffusionPipeline + +from .grad_tts_utils import text_to_sequence def sequence_mask(length, max_length=None): @@ -383,3 +386,18 @@ class TextEncoder(ModelMixin, ConfigMixin): logw = self.proj_w(x_dp, x_mask) return mu, logw, x_mask + + +class GradTTS(DiffusionPipeline): + def __init__(self, unet, noise_scheduler): + super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") + self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler) + + @torch.no_grad() + def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None): + if torch_device is None: + torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + pass + + \ No newline at end of file