# Source code for vathos.model.loss.loss

import kornia
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    r"""Criterion that computes the Sørensen-Dice coefficient loss on logits.

    According to [1], the Sørensen-Dice coefficient is

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    where:
       - :math:`X` is the sigmoid of the predicted logits (soft class scores).
       - :math:`Y` is the one-hot / binary target tensor.

    The soft intersection :math:`|X \cap Y|` is ``sum(X * Y)`` and
    :math:`|X| + |Y|` is ``sum(X + Y)``, both reduced over dims 1, 2 and 3,
    so inputs must have at least 4 dims (presumably ``(B, C, H, W)`` --
    TODO confirm with callers). The returned value is the mean, over the
    remaining dims, of

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    Args:
        eps: small constant added to the denominator to avoid division by
            zero when both prediction and target are empty.

    [1] reference URL not preserved in this copy of the source.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super(DiceLoss, self).__init__()
        self.eps: float = eps

    def forward(  # type: ignore
            self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        input_sig = torch.sigmoid(input)

        # compute the soft dice score per remaining (batch) element
        dims = (1, 2, 3)
        intersection = torch.sum(input_sig * target, dims)
        cardinality = torch.sum(input_sig + target, dims)

        dice_score = 2. * intersection / (cardinality + self.eps)
        # A Python scalar avoids allocating a fresh CPU tensor on every call.
        return torch.mean(1. - dice_score)
class TverskyLoss(nn.Module):
    r"""Performs Tversky Loss on Logits.

    According to [1], the Tversky coefficient is

    .. math::

        \text{S}(P, G; \alpha, \beta) =
            \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}

    where:
       - :math:`P` and :math:`G` are the predicted and ground-truth labels
         (here ``P = sigmoid(input)`` and ``G = target``, used as soft sets).
       - :math:`\alpha` and :math:`\beta` weight the penalties for false
         positives and false negatives, respectively.

    Notes:
       - :math:`\alpha = \beta = 0.5` => dice coeff
       - :math:`\alpha = \beta = 1` => tanimoto coeff
       - :math:`\alpha + \beta = 1` => F beta coeff

    Sums run over dims 1, 2 and 3, so inputs must have at least 4 dims
    (presumably ``(B, C, H, W)`` -- TODO confirm). The result is the mean
    of ``1 - S`` over the remaining dims.

    Args:
        alpha: weight of the false-positive term.
        beta: weight of the false-negative term.
        eps: small constant added to the denominator to avoid division by zero.

    Reference:
        [1] reference URL not preserved in this copy of the source.
    """

    def __init__(self, alpha: float, beta: float, eps: float = 1e-6) -> None:
        super(TverskyLoss, self).__init__()
        self.alpha: float = alpha
        self.beta: float = beta
        self.eps: float = eps

    def forward(  # type: ignore
            self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        input_sig = torch.sigmoid(input)

        # soft true positives, false positives and false negatives
        dims = (1, 2, 3)
        intersection = torch.sum(input_sig * target, dims)
        fps = torch.sum(input_sig * (1. - target), dims)
        fns = torch.sum((1. - input_sig) * target, dims)

        denominator = intersection + self.alpha * fps + self.beta * fns
        tversky_index = intersection / (denominator + self.eps)
        return torch.mean(1. - tversky_index)
class BCEDiceLoss(nn.Module):
    r"""Sum of binary cross-entropy and twice the Dice loss, both on logits.

    loss = bce_loss + 2 * dice_loss

    The BCE term uses ``F.binary_cross_entropy_with_logits`` (default mean
    reduction); the Dice term uses :class:`DiceLoss`, which applies the
    sigmoid to ``input`` itself.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): not read by forward(); kept so the attribute still exists.
        self.eps = 1e-6
        self.dice_loss = DiceLoss()

    def forward(  # type: ignore
            self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return (F.binary_cross_entropy_with_logits(input, target)
                + 2 * self.dice_loss(input, target))
class BCETverskyLoss(nn.Module):
    r"""Sum of binary cross-entropy and twice the Tversky loss, both on logits.

    loss = bce_loss + 2 * tversky_loss

    The Tversky term is :class:`TverskyLoss` with ``alpha=0.6`` (false-positive
    weight) and ``beta=0.5`` (false-negative weight).
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): not read by forward(); kept so the attribute still exists.
        self.eps = 1e-6
        self.tversky_loss = TverskyLoss(0.6, 0.5)

    def forward(  # type: ignore
            self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return (F.binary_cross_entropy_with_logits(input, target)
                + 2 * self.tversky_loss(input, target))
class BerHuLoss(nn.Module):
    r'''Reverse Huber (BerHu) loss from [1], on sigmoid-activated predictions.

    .. math::

        B(y, y') = \begin{cases}
            |y' - y| & \text{if } |y' - y| \le c \\
            \frac{(y' - y)^2 + c^2}{2c} & \text{otherwise}
        \end{cases}

    with :math:`c = \text{threshold} \cdot \max |y' - y|` (``threshold``
    defaults to 1/5), averaged over all elements. Here ``y' = sigmoid(input)``,
    consistent with the other losses in this module. The two branches agree
    at :math:`|y' - y| = c`, so the loss is continuous.

    ``c`` is computed as a Python float, so no gradient flows through the max.

    Args:
        threshold: fraction of the maximum absolute residual used as the
            L1/L2 switch point ``c``. Must be positive.

    Raises:
        ValueError: if ``threshold`` is not positive.

    [1] [2] reference URLs not preserved in this copy of the source.
    '''

    def __init__(self, threshold: float = 1./5):
        super(BerHuLoss, self).__init__()
        if threshold <= 0:
            raise ValueError(f"threshold must be positive, got {threshold}")
        self.threshold = threshold

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        input_sig = torch.sigmoid(input)
        diff = torch.abs(target - input_sig)
        c = self.threshold * torch.max(diff).item()

        if c == 0.:
            # Every residual is zero: the L1 branch applies everywhere, and the
            # quadratic branch would divide by zero (NaN values/gradients).
            return torch.mean(diff)

        # c > 0 here, so the quadratic branch is finite everywhere and
        # torch.where cannot leak NaN gradients from the unselected branch.
        quadratic = (diff ** 2 + c ** 2) / (2 * c)
        loss = torch.where(diff <= c, diff, quadratic)
        return torch.mean(loss)
class GradLoss(nn.Module):
    r"""Mean L1 distance between the spatial gradients of prediction and target.

    The sigmoid of ``input`` and the raw ``target`` are each passed through a
    ``kornia.filters.SpatialGradient`` operator (B, C, H, W -> B, C, 2, H, W,
    i.e. x and y gradients). The mean absolute difference between the two
    gradient tensors is returned.
    """

    def __init__(self):
        super().__init__()
        # one gradient operator per operand; B, C, H, W -> B, C, 2, H, W
        self.input_grad = kornia.filters.SpatialGradient()
        self.target_grad = kornia.filters.SpatialGradient()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred_grad = self.input_grad(torch.sigmoid(input))
        true_grad = self.target_grad(target)
        return (true_grad - pred_grad).abs().mean()
class SSIMLoss(nn.Module):
    r"""Structural-similarity loss evaluated with 5x5 and 11x11 windows.

    The sigmoid of ``input`` is compared with ``target`` using
    ``kornia.losses.SSIM`` (``reduction='none'``) at each window size.
    Each per-element loss map is averaged, and the two averages are
    **summed**; the result is not a mean over window sizes.

    .. note::
        Other window sizes were tried as well, without a significant benefit.
    """

    def __init__(self):
        super().__init__()
        self.ssim_loss_5x5 = kornia.losses.SSIM(5, reduction='none')
        self.ssim_loss_11x11 = kornia.losses.SSIM(11, reduction='none')

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(input)
        small_window = self.ssim_loss_5x5(probs, target).mean()
        large_window = self.ssim_loss_11x11(probs, target).mean()
        return small_window + large_window
class RMSEwSSIMLoss(nn.Module):
    r'''Combined SSIM and RMSE loss.

    loss = :math:`\sqrt{\text{ssim\_loss}} + 4\times \text{rmse\_loss}`

    ``ssim_loss`` comes from :class:`SSIMLoss` and ``rmse_loss`` from
    :class:`RMSELoss`. Both apply the sigmoid to ``input`` themselves, so
    ``input`` is expected to be logits.
    '''

    def __init__(self):
        super().__init__()
        self.ssim_loss = SSIMLoss()
        self.rmse_loss = RMSELoss()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        ssim_term = torch.sqrt(self.ssim_loss(input, target))
        rmse_term = self.rmse_loss(input, target)
        return ssim_term + 4 * rmse_term
def iou(outputs: torch.Tensor, labels: torch.Tensor,
        eps: float = 1e-6) -> torch.Tensor:
    """Return the mean soft Intersection-over-Union (Jaccard index).

    ``outputs`` are logits and are passed through a sigmoid. The intersection
    and union are reduced over dims 1, 2 and 3 (so at least 4-D inputs are
    required; presumably ``(B, C, H, W)`` -- TODO confirm). The per-element
    IoU values are then averaged.

    Args:
        outputs: predicted logits.
        labels: binary ground-truth tensor, broadcastable to ``outputs``.
        eps: small constant added to the union to avoid division by zero.

    Returns:
        0-dim tensor with the mean IoU.
    """
    dims = (1, 2, 3)
    probs = torch.sigmoid(outputs)
    intersection = torch.sum(probs * labels, dims)
    # |A u B| = |A| + |B| - |A n B|; the plain sum alone would make this Dice / 2.
    union = torch.sum(probs + labels, dims) - intersection
    per_sample_iou = intersection / (union + eps)
    return torch.mean(per_sample_iou)
def rmse(outputs: torch.Tensor, labels: torch.Tensor):
    """Return the RMSE between ``sigmoid(outputs)`` and ``labels``.

    This is a convenience wrapper that builds a fresh :class:`RMSELoss`
    (default ``eps``) on every call and applies it.
    """
    return RMSELoss()(outputs, labels)
class RMSELoss(nn.Module):
    r"""Root-mean-squared error on sigmoid-activated predictions.

    ``loss = sqrt(MSELoss()(sigmoid(input), target) + eps)``

    ``eps`` keeps the square root, and its gradient, finite when the error
    is exactly zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, input, target):
        mean_sq_err = self.mse(torch.sigmoid(input), target)
        return (mean_sq_err + self.eps).sqrt()