Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ Losses
BaseDualLoss <loss/base_dual_loss.rst>
LpLoss <loss/lp_loss.rst>
PowerLoss <loss/power_loss.rst>
SinkhornLoss <loss/sinkhorn_loss.rst>


Weighting Schemas
--------------------
Expand All @@ -343,4 +345,4 @@ Weighting Schemas
Neural-Tangent-Kernel Weighting <weighting/ntk_weighting.rst>
No Weighting <weighting/no_weighting.rst>
Scalar Weighting <weighting/scalar_weighting.rst>
Self-Adaptive Weighting <weighting/self_adaptive_weighting.rst>
Self-Adaptive Weighting <weighting/self_adaptive_weighting.rst>
11 changes: 11 additions & 0 deletions docs/source/_rst/loss/sinkhorn_loss.rst
Comment thread
guglielmopadula marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Sinkhorn Loss
===============

.. currentmodule:: pina.loss.sinkhorn_loss

.. automodule:: pina._src.loss.sinkhorn_loss
:no-members:

.. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss
:members:
:show-inheritance:
138 changes: 138 additions & 0 deletions pina/_src/loss/sinkhorn_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Module for the SinkhornLoss class."""

import torch
from pina._src.loss.base_dual_loss import BaseDualLoss
from pina._src.core.utils import check_consistency, check_positive_integer


class SinkhornLoss(BaseDualLoss):
    r"""
    Sinkhorn loss: an entropy-regularized optimal transport cost between two
    empirical distributions.

    Let :math:`x` be an input tensor holding :math:`N` samples and :math:`y` a
    target tensor holding :math:`M` samples, all living in
    :math:`\mathbb{R}^D`. Denoting by :math:`\mu` and :math:`\nu` the empirical
    distributions of :math:`x` and :math:`y`, the loss is the value of the
    entropy-regularized optimal transport problem:

    .. math::

        W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)}
        \langle C, \pi \rangle - \varepsilon H(\pi)

    Here :math:`\pi` is a transport plan and :math:`\Pi(\mu, \nu)` is the set
    of admissible plans, i.e. those whose marginals are :math:`\mu` and
    :math:`\nu`.

    The transport cost between samples is:

    .. math::

        C_{ij} = \left\| x_i - y_j \right\|_2^p

    while the entropy of a plan is:

    .. math::

        H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij}

    The parameter :math:`\varepsilon > 0` sets how strongly the problem is
    regularized by the entropy term.

    The optimal dual potentials :math:`f^\ast` and :math:`g^\ast` are
    approximated by Sinkhorn iterations carried out in log space, and the
    regularized cost is read off the dual formulation:

    .. math::

        W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle

    with :math:`a` and :math:`b` the uniform probability weights on the
    :math:`N` input samples and the :math:`M` target samples, respectively.

    Because whole empirical distributions are compared, rather than individual
    points, the result is always a scalar value.

    Decreasing ``eps`` brings the loss closer to the true Wasserstein distance,
    at the price of possibly needing more Sinkhorn iterations to converge.

    .. seealso::

        **Original reference:** Patrini, G., Carioni, M., Forré, P., Bhargav,
        S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019).
        *Sinkhorn AutoEncoders*.
        In Proceedings of the 35th Conference on Uncertainty in Artificial
        Intelligence.
        URL: `<https://openreview.net/forum?id=BygNqoR9tm>`_.
    """

    def __init__(self, p=2, eps=0.1, iterations=100):
        """
        Initialization of the :class:`SinkhornLoss` class.

        :param int p: The exponent applied to the pairwise distance in the
            cost function. Default is ``2``.
        :param eps: The strength of the entropic regularization. The smaller
            the value, the closer the loss is to the unregularized Wasserstein
            distance, but the more iterations may be needed for convergence.
            Default is ``0.1``.
        :type eps: int | float
        :param int iterations: How many Sinkhorn iterations to run.
            Default is ``100``.
        :raises AssertionError: If ``iterations`` is not a positive integer.
        :raises AssertionError: If ``p`` is not a positive integer.
        :raises ValueError: If ``eps`` is not a positive numeric value.
        """
        # The loss is a single scalar, so the base class uses mean reduction
        super().__init__(reduction="mean")

        # Validate the hyperparameters before storing anything
        check_positive_integer(iterations, strict=True)
        check_positive_integer(p, strict=True)
        check_consistency(eps, (int, float))
        if eps <= 0:
            raise ValueError(
                f"Expected 'eps' to be strictly positive, but got {eps}."
            )

        # Store the validated hyperparameters
        self.iterations = iterations
        self.eps = eps
        self.p = p

    def _update_potential(self, log_weights, other_potential, cost, dim):
        """
        Recompute one dual potential from the other one, in log space.

        :param torch.Tensor log_weights: The log-weights of the marginal
            associated with the potential being updated.
        :param torch.Tensor other_potential: The current value of the other
            dual potential.
        :param torch.Tensor cost: The cost matrix.
        :param int dim: The dimension of ``cost`` that is reduced: ``1`` when
            updating the input-side potential, ``0`` for the target side.
        :return: The updated dual potential.
        :rtype: torch.Tensor
        """
        # Broadcast the other potential along the dimension that is kept
        scaled = (other_potential.unsqueeze(1 - dim) - cost) / self.eps
        return self.eps * (log_weights - torch.logsumexp(scaled, dim=dim))

    def forward(self, input, target):
        """
        Forward method of the loss function.

        :param torch.Tensor input: The input tensor, one sample per row.
        :param torch.Tensor target: The target tensor, one sample per row.
        :return: The computed Sinkhorn loss value.
        :rtype: torch.Tensor
        """
        # Sample counts of the two empirical distributions
        n_input, n_target = input.shape[0], target.shape[0]

        # Log of the uniform weights, i.e. -log(n) repeated for every sample
        log_w_input = -input.new_tensor(n_input).log().expand(n_input)
        log_w_target = -target.new_tensor(n_target).log().expand(n_target)

        # Dual potentials start from zero
        pot_input = torch.zeros(
            n_input, dtype=input.dtype, device=input.device
        )
        pot_target = torch.zeros(
            n_target, dtype=target.dtype, device=target.device
        )

        # Pairwise cost matrix, shape (n_input, n_target)
        cost = torch.cdist(input, target, p=self.p) ** self.p

        # Alternate the two log-space updates for numerical stability
        for _ in range(self.iterations):
            pot_input = self._update_potential(
                log_w_input, pot_target, cost, dim=1
            )
            pot_target = self._update_potential(
                log_w_target, pot_input, cost, dim=0
            )

        # Dual value under uniform weights: mean of each potential, summed
        dual_value = pot_input.mean() + pot_target.mean()

        return self._reduction(dual_value.unsqueeze(0))
2 changes: 2 additions & 0 deletions pina/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
"BaseDualLoss",
"LpLoss",
"PowerLoss",
"SinkhornLoss",
]

from pina._src.loss.dual_loss_interface import DualLossInterface
from pina._src.loss.base_dual_loss import BaseDualLoss
from pina._src.loss.power_loss import PowerLoss
from pina._src.loss.lp_loss import LpLoss
from pina._src.loss.sinkhorn_loss import SinkhornLoss

# Back-compatibility with version 0.2, to be removed soon
import warnings
Expand Down
54 changes: 54 additions & 0 deletions tests/test_loss/test_sinkhorn_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import pytest
from pina.loss import SinkhornLoss


@pytest.mark.parametrize("p", [1, 2])
@pytest.mark.parametrize("eps", [0.01, 1])
@pytest.mark.parametrize("iterations", [2, 5])
def test_constructor(p, eps, iterations):
    """
    Check that :class:`SinkhornLoss` accepts valid hyperparameters and rejects
    invalid ones with the expected exception type.
    """

    # Define the loss
    SinkhornLoss(p=p, eps=eps, iterations=iterations)

    # Should fail if iterations is not a positive integer
    with pytest.raises(AssertionError):
        SinkhornLoss(p=p, eps=eps, iterations=0)

    # Should fail if p is not a positive integer
    with pytest.raises(AssertionError):
        SinkhornLoss(p=0, eps=eps, iterations=iterations)

    # Should fail if eps is not numeric
    with pytest.raises(ValueError):
        SinkhornLoss(p=p, eps="invalid", iterations=iterations)

    # Should fail if eps is negative
    with pytest.raises(ValueError):
        SinkhornLoss(p=p, eps=-0.1, iterations=iterations)

    # Should fail at the boundary too: eps must be strictly positive
    with pytest.raises(ValueError):
        SinkhornLoss(p=p, eps=0, iterations=iterations)


@pytest.mark.parametrize("p", [2, 3])
@pytest.mark.parametrize("eps", [0.1, 1])
@pytest.mark.parametrize("iterations", [2, 5])
@pytest.mark.parametrize(
    "input, target",
    [
        (torch.rand(10, 2), torch.rand(8, 2)),
        (torch.rand(5, 3), torch.rand(5, 3)),
        (torch.rand(1, 4), torch.rand(7, 4)),
        (torch.rand(6, 4), torch.rand(1, 4)),
        (torch.rand(3, 1), torch.rand(4, 1)),
    ],
)
def test_forward(p, eps, iterations, input, target):
    """
    Check that the forward pass yields a finite, single-element tensor for
    inputs and targets with equal and unequal sample counts.
    """

    # Build the loss with the parametrized hyperparameters
    criterion = SinkhornLoss(p=p, eps=eps, iterations=iterations)

    # Evaluate it on the sample pair
    result = criterion(input, target)

    # The output must hold exactly one element
    assert result.shape == torch.Size([1])

    # The output must be neither NaN nor infinite
    assert torch.isfinite(result).all()
Loading