Remove the usage of numpy in up/down sample_2d (#503)
* Fix PT up/down sample_2d * empty commit * style * style Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
c727a6a5fb
commit
c0493723f7
|
@ -1,6 +1,5 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -134,10 +133,10 @@ class FirUpsample2D(nn.Module):
|
||||||
kernel = [1] * factor
|
kernel = [1] * factor
|
||||||
|
|
||||||
# setup kernel
|
# setup kernel
|
||||||
kernel = np.asarray(kernel, dtype=np.float32)
|
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||||
if kernel.ndim == 1:
|
if kernel.ndim == 1:
|
||||||
kernel = np.outer(kernel, kernel)
|
kernel = torch.outer(kernel, kernel)
|
||||||
kernel /= np.sum(kernel)
|
kernel /= torch.sum(kernel)
|
||||||
|
|
||||||
kernel = kernel * (gain * (factor**2))
|
kernel = kernel * (gain * (factor**2))
|
||||||
|
|
||||||
|
@ -219,10 +218,10 @@ class FirDownsample2D(nn.Module):
|
||||||
kernel = [1] * factor
|
kernel = [1] * factor
|
||||||
|
|
||||||
# setup kernel
|
# setup kernel
|
||||||
kernel = np.asarray(kernel, dtype=np.float32)
|
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||||
if kernel.ndim == 1:
|
if kernel.ndim == 1:
|
||||||
kernel = np.outer(kernel, kernel)
|
kernel = torch.outer(kernel, kernel)
|
||||||
kernel /= np.sum(kernel)
|
kernel /= torch.sum(kernel)
|
||||||
|
|
||||||
kernel = kernel * gain
|
kernel = kernel * gain
|
||||||
|
|
||||||
|
@ -391,16 +390,14 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
|
||||||
if kernel is None:
|
if kernel is None:
|
||||||
kernel = [1] * factor
|
kernel = [1] * factor
|
||||||
|
|
||||||
kernel = np.asarray(kernel, dtype=np.float32)
|
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||||
if kernel.ndim == 1:
|
if kernel.ndim == 1:
|
||||||
kernel = np.outer(kernel, kernel)
|
kernel = torch.outer(kernel, kernel)
|
||||||
kernel /= np.sum(kernel)
|
kernel /= torch.sum(kernel)
|
||||||
|
|
||||||
kernel = kernel * (gain * (factor**2))
|
kernel = kernel * (gain * (factor**2))
|
||||||
p = kernel.shape[0] - factor
|
p = kernel.shape[0] - factor
|
||||||
return upfirdn2d_native(
|
return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
||||||
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downsample_2d(x, kernel=None, factor=2, gain=1):
|
def downsample_2d(x, kernel=None, factor=2, gain=1):
|
||||||
|
@ -425,14 +422,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
|
||||||
if kernel is None:
|
if kernel is None:
|
||||||
kernel = [1] * factor
|
kernel = [1] * factor
|
||||||
|
|
||||||
kernel = np.asarray(kernel, dtype=np.float32)
|
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||||
if kernel.ndim == 1:
|
if kernel.ndim == 1:
|
||||||
kernel = np.outer(kernel, kernel)
|
kernel = torch.outer(kernel, kernel)
|
||||||
kernel /= np.sum(kernel)
|
kernel /= torch.sum(kernel)
|
||||||
|
|
||||||
kernel = kernel * gain
|
kernel = kernel * gain
|
||||||
p = kernel.shape[0] - factor
|
p = kernel.shape[0] - factor
|
||||||
return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
||||||
|
|
||||||
|
|
||||||
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
|
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||||
|
|
Loading…
Reference in New Issue