57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
import torch
|
|
|
|
|
|
# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
|
|
def torch_snr_error(
|
|
y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
|
|
) -> torch.Tensor:
|
|
"""
|
|
Compute SNR between y_pred(tensor) and y_real(tensor)
|
|
|
|
SNR can be calcualted as following equation:
|
|
|
|
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
|
|
|
|
if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.
|
|
|
|
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
|
|
|
|
Args:
|
|
y_pred (torch.Tensor): _description_
|
|
y_real (torch.Tensor): _description_
|
|
reduction (str, optional): _description_. Defaults to 'mean'.
|
|
|
|
Raises:
|
|
ValueError: _description_
|
|
ValueError: _description_
|
|
|
|
Returns:
|
|
torch.Tensor: _description_
|
|
"""
|
|
if y_pred.shape != y_real.shape:
|
|
raise ValueError(
|
|
f"Can not compute snr loss for tensors with different shape. "
|
|
f"({y_pred.shape} and {y_real.shape})"
|
|
)
|
|
reduction = str(reduction).lower()
|
|
|
|
if y_pred.ndim == 1:
|
|
y_pred = y_pred.unsqueeze(0)
|
|
y_real = y_real.unsqueeze(0)
|
|
|
|
y_pred = y_pred.flatten(start_dim=1)
|
|
y_real = y_real.flatten(start_dim=1)
|
|
|
|
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
|
|
signal_power = torch.pow(y_real, 2).sum(dim=-1)
|
|
snr = (noise_power) / (signal_power + 1e-7)
|
|
|
|
if reduction == "mean":
|
|
return torch.mean(snr)
|
|
elif reduction == "sum":
|
|
return torch.sum(snr)
|
|
elif reduction == "none":
|
|
return snr
|
|
else:
|
|
raise ValueError("Unsupported reduction method.")
|