being pipeline
This commit is contained in:
parent
986cc9b2f4
commit
7b55d334d5
|
@ -7,6 +7,9 @@ from torch import nn
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.modeling_utils import ModelMixin
|
from diffusers.modeling_utils import ModelMixin
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
from .grad_tts_utils import text_to_sequence
|
||||||
|
|
||||||
|
|
||||||
def sequence_mask(length, max_length=None):
|
def sequence_mask(length, max_length=None):
|
||||||
|
@ -383,3 +386,18 @@ class TextEncoder(ModelMixin, ConfigMixin):
|
||||||
logw = self.proj_w(x_dp, x_mask)
|
logw = self.proj_w(x_dp, x_mask)
|
||||||
|
|
||||||
return mu, logw, x_mask
|
return mu, logw, x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class GradTTS(DiffusionPipeline):
|
||||||
|
def __init__(self, unet, noise_scheduler):
|
||||||
|
super().__init__()
|
||||||
|
noise_scheduler = noise_scheduler.set_format("pt")
|
||||||
|
self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
|
||||||
|
if torch_device is None:
|
||||||
|
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue