Metrics

class fairret.metric.LinearFractionalParity

Bases: torchmetrics.metric.Metric

Metric that assesses the fairness of a model’s predictions by measuring the gap between the provided LinearFractionalStatistic computed for each sensitive feature and the same statistic computed over all samples.

The metric maintains two pairs of running sums: one for the statistic for every sensitive feature, and one for the overall statistic. Each pair of running sums consists of the numerator and the denominator for those statistics. Observations are added to these sums by calling the update() method. The final fairness gap is computed by calling the compute() method.
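The bookkeeping described above can be sketched in plain Python. This is an illustrative stand-in, not the fairret implementation: the class name `RunningRatio`, its list-based state, and the signature of its `update()` are hypothetical, whereas the real metric derives the per-batch sums from the statistic itself. The point of the sketch is that numerators and denominators are accumulated separately and divided only in `compute()`, which leaves the state untouched until `reset()` is called.

```python
class RunningRatio:
    """Hypothetical stand-in that accumulates numerators and denominators."""

    def __init__(self, num_groups):
        self.num_groups = num_groups
        self.reset()

    def reset(self):
        # One (numerator, denominator) pair per group, plus one overall pair.
        self.group_num = [0.0] * self.num_groups
        self.group_den = [0.0] * self.num_groups
        self.overall_num = 0.0
        self.overall_den = 0.0

    def update(self, batch_group_num, batch_group_den, batch_num, batch_den):
        # Add this batch's sums to the running sums; no division yet.
        for s in range(self.num_groups):
            self.group_num[s] += batch_group_num[s]
            self.group_den[s] += batch_group_den[s]
        self.overall_num += batch_num
        self.overall_den += batch_den

    def compute(self):
        # Division happens only here, and the running sums are not cleared.
        group_stats = [n / d for n, d in zip(self.group_num, self.group_den)]
        overall_stat = self.overall_num / self.overall_den
        return group_stats, overall_stat


acc = RunningRatio(num_groups=2)
# Batch 1: group 0 has 1 positive out of 2, group 1 has 3 out of 4.
acc.update([1.0, 3.0], [2.0, 4.0], 4.0, 6.0)
# Batch 2: group 0 has 3 positives out of 6, group 1 has 1 out of 2.
acc.update([3.0, 1.0], [6.0, 2.0], 4.0, 8.0)

group_stats, overall_stat = acc.compute()
# Pooled ratio for group 1 is 4/6, which differs from the mean of the
# per-batch ratios (0.75 + 0.5) / 2 = 0.625.
mean_of_batch_ratios = (3.0 / 4.0 + 1.0 / 2.0) / 2

# Start of a new epoch: clear the running sums explicitly.
state_before_reset = list(acc.group_num)
acc.reset()
state_after_reset = list(acc.group_num)
```

Accumulating sums rather than per-batch ratios means the result does not depend on how the data happens to be split into batches.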

The class is implemented as a subclass of torchmetrics.Metric, so the torchmetrics package is required.

Warning

A separate reset() call is required to reset the internal state of the metric between epochs.

Warning

It is advised not to mix metrics of this class with different statistics in a single torchmetrics.MetricCollection with compute_groups=True, as this can lead to hard-to-debug errors.

__init__(statistic, stat_shape, gap_fn=gap_relative_abs_max, **torchmetrics_kwargs)
Parameters:
  • statistic (fairret.statistic.linear_fractional.LinearFractionalStatistic) – the LinearFractionalStatistic that should be evaluated.

  • stat_shape (int | Tuple[int]) – the shape of the statistic, excluding the batch dimension. For example, a single statistic computed for every sensitive feature would have a shape of (S,) with S the number of sensitive features. If the statistic is a stacked statistic, the shape should be (K, S) with K the number of statistics in the stack.

  • gap_fn (Callable[[torch.Tensor, float], float]) – the function that computes the gaps between the statistic for every sensitive feature and the overall statistic. The default is the absolute maximum of the relative gaps.

  • **torchmetrics_kwargs (Any) – Any additional keyword arguments that should be passed to torchmetrics.Metric.
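As a rough illustration of what a gap function does, the sketch below implements one reading of "the absolute maximum of the relative gaps": each groupwise statistic is divided by the overall statistic, and the largest absolute deviation from 1 is returned. The function name `relative_abs_max_gap` is hypothetical and it works on plain Python lists rather than a torch.Tensor; it is an assumption about the default's behavior, not a copy of gap_relative_abs_max.

```python
def relative_abs_max_gap(group_stats, overall_stat):
    """Hypothetical gap function: largest |group / overall - 1| over all groups."""
    return max(abs(s / overall_stat - 1.0) for s in group_stats)


# Three groups with statistics 0.5, 0.4 and 0.6 against an overall value of 0.5.
gap = relative_abs_max_gap([0.5, 0.4, 0.6], 0.5)

# When every group matches the overall statistic, the gap is zero.
no_gap = relative_abs_max_gap([0.5, 0.5], 0.5)
```

Under this reading, a gap of 0 means every group has the same statistic as the population as a whole, and larger values indicate a larger relative disparity for at least one group.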

update(pred, sens, *stat_args, **stat_kwargs)

Update the running sums for the numerator and denominator of the groupwise and overall statistics with a new batch of predictions and sensitive features.

Parameters:
  • pred (torch.Tensor) – Predictions of shape (N, 1), with N the batch size. A single prediction per sample is expected because the task is assumed to be binary classification or regression.

  • sens (torch.Tensor) – Sensitive features of shape (N, S) with S the number of sensitive features.

  • *stat_args (Any) – All arguments used by the statistic that this metric computes.

  • **stat_kwargs (Any) – All keyword arguments used by the statistic that this metric computes.
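To make the expected shapes concrete, the following sketch shows how one batch with pred of shape (N, 1) and sens of shape (N, S) could contribute to the groupwise sums. It assumes a positive-rate-like statistic whose numerator is the prediction and whose denominator is a sample count, both weighted by group membership. The helper `positive_rate_sums` is hypothetical and uses nested lists in place of tensors, so it only illustrates the idea and is not how the library computes it.

```python
def positive_rate_sums(pred, sens):
    """Per-group numerator and denominator for one batch.

    pred: N rows of [p], i.e. shape (N, 1)
    sens: N rows of S membership values, i.e. shape (N, S)
    """
    num_groups = len(sens[0])
    numerator = [0.0] * num_groups
    denominator = [0.0] * num_groups
    for (p,), membership in zip(pred, sens):
        for s, m in enumerate(membership):
            numerator[s] += m * p  # membership-weighted prediction
            denominator[s] += m    # membership-weighted count
    return numerator, denominator


# N = 4 samples, S = 2 one-hot encoded sensitive features.
pred = [[0.9], [0.2], [0.6], [0.3]]
sens = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
numerator, denominator = positive_rate_sums(pred, sens)
```

Each call to update() would add sums of this kind to the running totals; the division into per-group statistics is deferred to compute().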

compute()

Divide the running numerator sums by the running denominator sums to obtain the groupwise and overall statistics, then apply gap_fn to these statistics to compute the final fairness gap.

Warning

This does NOT reset the internal state of the metric. A separate reset() call is required to do so.

Returns:

The final fairness gap.

Return type:

float