sup3r.utilities.loss_metrics.SlicedWassersteinLoss

class SlicedWassersteinLoss(n_projections=1024)

Bases: Loss

Loss class for the sliced Wasserstein distance

Parameters:

n_projections (int) – number of random 1D projections to use

Note

Experimentally, the SW metric is stable when n_projections is at least 30% of the number of projection dimensions, which for us is HWT (the product of the two spatial dimensions and the temporal dimension). That many projections can be computationally expensive for large spatial/temporal sizes, so n_projections defaults to 1024.
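As a rough illustration of the 30% guideline in this note, the sketch below computes the smallest n_projections that meets it for a given sample shape. The helper name min_stable_projections and the example shapes are hypothetical and not part of sup3r; the sketch assumes HWT means height × width × time steps.

```python
import math


def min_stable_projections(height, width, time, fraction=0.3):
    """Smallest integer n_projections that is at least `fraction` of H*W*T.

    Hypothetical helper for illustration only; it is not part of sup3r.
    """
    return math.ceil(fraction * height * width * time)


# Compare the guideline against the default of 1024 for two example shapes.
for shape in [(16, 16, 8), (64, 64, 24)]:
    needed = min_stable_projections(*shape)
    print(shape, needed, "default 1024 suffices:", needed <= 1024)
```

For small samples the default already exceeds the guideline; for large ones it falls well short, which is the trade-off between stability and cost that the note describes.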

Methods

call(y_true, y_pred)

Invokes the Loss instance.

from_config(config)

Instantiates a Loss from its config (output of get_config()).

get_config()

Returns the config dictionary for a Loss instance.

abstract call(y_true, y_pred)

Invokes the Loss instance.

Args:

y_true: Ground truth values. shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]

y_pred: The predicted values. shape = [batch_size, d0, .. dN]

Returns:

Loss values with the shape [batch_size, d0, .. dN-1].

classmethod from_config(config)

Instantiates a Loss from its config (output of get_config()).

Args:

config: Output of get_config().

Returns:

A Loss instance.

get_config()

Returns the config dictionary for a Loss instance.

__call__(x1, x2)

Sliced Wasserstein distance based on random 1D projections

Parameters:
  • x1 (tf.Tensor) – synthetic generator output (n_observations, spatial_1, spatial_2, temporal, features)

  • x2 (tf.Tensor) – high resolution data (n_observations, spatial_1, spatial_2, temporal, features)

Returns:

tf.Tensor – 0D tensor loss value
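To make the idea of "random 1D projections" concrete, here is a minimal NumPy sketch of a sliced Wasserstein estimate. It is not the sup3r implementation: the function name sliced_wasserstein, the seed argument, the choice to flatten every non-batch axis (including features), and the use of a mean squared difference between sorted projections are all assumptions made for illustration.

```python
import numpy as np


def sliced_wasserstein(x1, x2, n_projections=1024, seed=0):
    """Monte Carlo estimate of a squared sliced Wasserstein-2 distance.

    Illustrative sketch only (not sup3r's code). Both inputs are expected
    to have the same shape, with observations along the first axis.
    """
    rng = np.random.default_rng(seed)

    # Flatten each observation into one vector: (n_observations, n_dims).
    a = x1.reshape(x1.shape[0], -1)
    b = x2.reshape(x2.shape[0], -1)

    # Draw random directions and normalize each one to unit length.
    theta = rng.normal(size=(a.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)

    # Project onto every direction, then sort along the observation axis.
    # In 1D, optimal transport pairs the sorted samples with each other.
    proj_a = np.sort(a @ theta, axis=0)
    proj_b = np.sort(b @ theta, axis=0)

    # Average the squared differences over observations and projections.
    return float(np.mean((proj_a - proj_b) ** 2))


# Usage with the 5D layout documented above:
# (n_observations, spatial_1, spatial_2, temporal, features)
rng = np.random.default_rng(42)
synthetic = rng.normal(size=(8, 4, 4, 2, 1))
high_res = synthetic + 1.0  # a shifted copy of the same samples

print(sliced_wasserstein(synthetic, synthetic))
print(sliced_wasserstein(synthetic, high_res))
```

Sorting is what makes each 1D problem cheap, which is why the cost grows with n_projections and the flattened dimension rather than requiring a full optimal transport solve.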