finish vp
This commit is contained in:
parent
dc6d028654
commit
ba264419f4
|
@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||||
continuous=continuous,
|
continuous=continuous,
|
||||||
)
|
)
|
||||||
self.act = act = get_act(nonlinearity)
|
self.act = act = get_act(nonlinearity)
|
||||||
self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10)))
|
|
||||||
|
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = num_res_blocks
|
||||||
|
@ -939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||||
|
|
||||||
self.all_modules = nn.ModuleList(modules)
|
self.all_modules = nn.ModuleList(modules)
|
||||||
|
|
||||||
def forward(self, x, time_cond):
|
def forward(self, x, time_cond, sigmas=None):
|
||||||
# timestep/noise_level embedding; only for continuous training
|
# timestep/noise_level embedding; only for continuous training
|
||||||
modules = self.all_modules
|
modules = self.all_modules
|
||||||
m_idx = 0
|
m_idx = 0
|
||||||
|
@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||||
elif self.embedding_type == "positional":
|
elif self.embedding_type == "positional":
|
||||||
# Sinusoidal positional embeddings.
|
# Sinusoidal positional embeddings.
|
||||||
timesteps = time_cond
|
timesteps = time_cond
|
||||||
used_sigmas = self.sigmas[time_cond.long()]
|
used_sigmas = sigmas
|
||||||
temb = get_timestep_embedding(timesteps, self.nf)
|
temb = get_timestep_embedding(timesteps, self.nf)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,27 +17,21 @@ class ScoreSdeVpPipeline(DiffusionPipeline):
|
||||||
channels = self.model.config.num_channels
|
channels = self.model.config.num_channels
|
||||||
shape = (1, channels, img_size, img_size)
|
shape = (1, channels, img_size, img_size)
|
||||||
|
|
||||||
beta_min, beta_max = 0.1, 20
|
|
||||||
|
|
||||||
model = self.model.to(device)
|
model = self.model.to(device)
|
||||||
|
|
||||||
x = torch.randn(*shape).to(device)
|
x = torch.randn(*shape).to(device)
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
for i, t in enumerate(self.scheduler.timesteps):
|
for t in self.scheduler.timesteps:
|
||||||
t = t * torch.ones(shape[0], device=device)
|
t = t * torch.ones(shape[0], device=device)
|
||||||
sigma_t = t * (num_inference_steps - 1)
|
scaled_t = t * (num_inference_steps - 1)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
result = model(x, sigma_t)
|
result = model(x, scaled_t)
|
||||||
|
|
||||||
log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
|
|
||||||
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
|
|
||||||
result = -result / std[:, None, None, None]
|
|
||||||
|
|
||||||
x, x_mean = self.scheduler.step_pred(result, x, t)
|
x, x_mean = self.scheduler.step_pred(result, x, t)
|
||||||
|
|
||||||
x_mean = (x_mean + 1.) / 2.
|
x_mean = (x_mean + 1.0) / 2.0
|
||||||
|
|
||||||
return x_mean
|
return x_mean
|
||||||
|
|
|
@ -20,6 +20,6 @@ from .scheduling_ddim import DDIMScheduler
|
||||||
from .scheduling_ddpm import DDPMScheduler
|
from .scheduling_ddpm import DDPMScheduler
|
||||||
from .scheduling_grad_tts import GradTTSScheduler
|
from .scheduling_grad_tts import GradTTSScheduler
|
||||||
from .scheduling_pndm import PNDMScheduler
|
from .scheduling_pndm import PNDMScheduler
|
||||||
from .scheduling_utils import SchedulerMixin
|
|
||||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||||
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
||||||
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
|
@ -52,6 +52,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
def step_pred(self, result, x, t):
|
def step_pred(self, result, x, t):
|
||||||
|
# TODO(Patrick) better comments + non-PyTorch
|
||||||
t = t * torch.ones(x.shape[0], device=x.device)
|
t = t * torch.ones(x.shape[0], device=x.device)
|
||||||
timestep = (t * (len(self.timesteps) - 1)).long()
|
timestep = (t * (len(self.timesteps) - 1)).long()
|
||||||
|
|
||||||
|
@ -70,6 +71,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
return x, x_mean
|
return x, x_mean
|
||||||
|
|
||||||
def step_correct(self, result, x):
|
def step_correct(self, result, x):
|
||||||
|
# TODO(Patrick) better comments + non-PyTorch
|
||||||
noise = torch.randn_like(x)
|
noise = torch.randn_like(x)
|
||||||
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
|
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
|
||||||
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
|
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
|
||||||
|
|
|
@ -40,16 +40,25 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||||
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
|
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
|
||||||
|
|
||||||
def step_pred(self, result, x, t):
|
def step_pred(self, result, x, t):
|
||||||
dt = -1. / len(self.timesteps)
|
# TODO(Patrick) better comments + non-PyTorch
|
||||||
z = torch.randn_like(x)
|
# postprocess model result
|
||||||
|
log_mean_coeff = (
|
||||||
|
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
|
||||||
|
)
|
||||||
|
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
|
||||||
|
result = -result / std[:, None, None, None]
|
||||||
|
|
||||||
beta_t = self.beta_min + t * (self.beta_max - self.beta_min)
|
# compute
|
||||||
|
dt = -1.0 / len(self.timesteps)
|
||||||
|
|
||||||
|
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
|
||||||
drift = -0.5 * beta_t[:, None, None, None] * x
|
drift = -0.5 * beta_t[:, None, None, None] * x
|
||||||
diffusion = torch.sqrt(beta_t)
|
diffusion = torch.sqrt(beta_t)
|
||||||
|
|
||||||
drift = drift - diffusion[:, None, None, None] ** 2 * result
|
drift = drift - diffusion[:, None, None, None] ** 2 * result
|
||||||
|
|
||||||
x_mean = x + drift * dt
|
x_mean = x + drift * dt
|
||||||
|
|
||||||
|
# add noise
|
||||||
|
z = torch.randn_like(x)
|
||||||
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
|
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
|
||||||
|
|
||||||
return x, x_mean
|
return x, x_mean
|
||||||
|
|
|
@ -746,8 +746,8 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_score_sde_vp_pipeline(self):
|
def test_score_sde_vp_pipeline(self):
|
||||||
|
|
||||||
model = NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp")
|
model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
|
||||||
scheduler = ScoreSdeVpScheduler()
|
scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")
|
||||||
|
|
||||||
sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler)
|
sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue