Fairness Metrics

A key insight that makes this library so flexible with respect to how fairness is defined is that many fairness definitions simply compare a statistic between groups.

In this notebook, we express many well-known fairness definitions in terms of statistics using the LinearFractionalParity class. The class is implemented as a torchmetrics Metric, which allows it to be integrated into any training or evaluation loop in PyTorch. All it needs is a statistic to compare.

To start, make sure torchmetrics is installed:

[1]:
import torchmetrics

If it isn’t, you can install it with:

pip install torchmetrics

[2]:
from fairret.metric import LinearFractionalParity

The simplest statistic to compare is the positive rate, i.e. the rate at which positive predictions are made for each sensitive feature. Equality in these positive rates is typically referred to as demographic parity (a.k.a. statistical parity or equal acceptance rate).
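Formally, demographic parity requires each group's positive rate to equal the overall positive rate:

P(ŷ = 1 | s = k) = P(ŷ = 1) for every sensitive feature k

where ŷ denotes the model's prediction and s the (one-hot) sensitive feature.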

Hence, we can express the extent to which demographic parity holds by passing the PositiveRate statistic to a LinearFractionalParity metric:

[3]:
from fairret.statistic import PositiveRate
demographic_parity = LinearFractionalParity(PositiveRate())
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 2
      1 from fairret.statistic import PositiveRate
----> 2 demographic_parity = LinearFractionalParity(PositiveRate())

TypeError: LinearFractionalParity.__init__() missing 1 required positional argument: 'stat_shape'

Oops, we also need to provide the shape of the statistic. Most often, this will simply be the number of sensitive features. For example, the PositiveRate statistic just computes a single value for each sensitive feature.

To know the number of sensitive features, we need to define some data first:

[4]:
import torch
torch.manual_seed(0)

feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])   # input features
sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])   # one-hot sensitive features
label = torch.tensor([[0.], [1.], [0.], [1.]])                  # binary labels

n_sensitive_features = sens.shape[1]
print(f"Number of sensitive features: {n_sensitive_features}")
Number of sensitive features: 2

We can now construct the metric:

[5]:
demographic_parity = LinearFractionalParity(PositiveRate(), stat_shape=(n_sensitive_features,))

LinearFractionalParity follows exactly the same interface as all other Metric classes in torchmetrics.

For all details of this interface, check out the torchmetrics documentation.

Basically, it follows a three-step approach:

1. Call metric.update(args), where args in our case are the arguments necessary to compute the statistic.
2. Call metric.compute(), which returns the violation of the fairness definition.
3. Call metric.reset(), which resets the metric to its initial state.
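For instance, a minimal sketch of one such cycle on the toy data above (the values in pred are made up for illustration):

pred = torch.tensor([[0.8], [0.3], [0.6], [0.1]])  # hypothetical predictions, one per sample
demographic_parity.update(pred, sens)              # 1. accumulate the statistic per group
violation = demographic_parity.compute()           # 2. violation of demographic parity
demographic_parity.reset()                         # 3. clear the state, e.g. for the next epoch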

If you need to use any other torchmetrics settings, such as compute_with_cache, you can pass them as keyword arguments to the LinearFractionalParity class upon initialization.
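For example, a sketch with the compute_with_cache flag (the name dp_no_cache is ours; compute_with_cache requires a reasonably recent torchmetrics version):

dp_no_cache = LinearFractionalParity(
    PositiveRate(), stat_shape=(n_sensitive_features,), compute_with_cache=False
)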

Example

Let’s train a model and keep track of the demographic parity.

Without fairret:

[6]:
h_layer_dim = 16
lr = 1e-3
batch_size = 1024

torch.manual_seed(0)

def build_model():
    # A small two-layer network producing one logit per sample.
    model = torch.nn.Sequential(
        torch.nn.Linear(feat.shape[1], h_layer_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_layer_dim, 1)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)
[7]:
import numpy as np

nb_epochs = 100
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        # Track demographic parity on this batch's training predictions.
        pred = torch.sigmoid(logit)
        demographic_parity.update(pred, batch_sens)

        optimizer.step()
        losses.append(loss.item())
    dp_for_epoch = demographic_parity.compute()
    demographic_parity.reset()
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}, dp: {dp_for_epoch}")
Epoch: 0, loss: 0.7091795206069946, dp: 0.03126168251037598
Epoch: 1, loss: 0.7061765193939209, dp: 0.02563762664794922
Epoch: 2, loss: 0.7033581733703613, dp: 0.020147204399108887
Epoch: 3, loss: 0.7007156610488892, dp: 0.014800786972045898
Epoch: 4, loss: 0.6982340812683105, dp: 0.009598910808563232
Epoch: 5, loss: 0.6959078907966614, dp: 0.00453948974609375
Epoch: 6, loss: 0.6937355995178223, dp: 0.00037485361099243164
Epoch: 7, loss: 0.6917158365249634, dp: 0.005139470100402832
Epoch: 8, loss: 0.6898466944694519, dp: 0.009749293327331543
Epoch: 9, loss: 0.6881252527236938, dp: 0.014199256896972656
Epoch: 10, loss: 0.6865478754043579, dp: 0.01848423480987549
Epoch: 11, loss: 0.6851094961166382, dp: 0.022599458694458008
Epoch: 12, loss: 0.6838041543960571, dp: 0.02654099464416504
Epoch: 13, loss: 0.6826250553131104, dp: 0.030305147171020508
Epoch: 14, loss: 0.6815641522407532, dp: 0.03388887643814087
Epoch: 15, loss: 0.6806124448776245, dp: 0.037290215492248535
Epoch: 16, loss: 0.6797604560852051, dp: 0.040507614612579346
Epoch: 17, loss: 0.6789975762367249, dp: 0.04354041814804077
Epoch: 18, loss: 0.6783132553100586, dp: 0.04638934135437012
Epoch: 19, loss: 0.6776963472366333, dp: 0.04905581474304199
Epoch: 20, loss: 0.6771360039710999, dp: 0.051542043685913086
Epoch: 21, loss: 0.6766215562820435, dp: 0.05385148525238037
Epoch: 22, loss: 0.6761429309844971, dp: 0.055988192558288574
Epoch: 23, loss: 0.6756909489631653, dp: 0.05795770883560181
Epoch: 24, loss: 0.6752569675445557, dp: 0.05976557731628418
Epoch: 25, loss: 0.6748337745666504, dp: 0.06141817569732666
Epoch: 26, loss: 0.674415111541748, dp: 0.06292301416397095
Epoch: 27, loss: 0.673996090888977, dp: 0.06428724527359009
Epoch: 28, loss: 0.6735726594924927, dp: 0.06551891565322876
Epoch: 29, loss: 0.6731564998626709, dp: 0.06661409139633179
Epoch: 30, loss: 0.6727579236030579, dp: 0.06756091117858887
Epoch: 31, loss: 0.672345757484436, dp: 0.06839364767074585
Epoch: 32, loss: 0.6719199419021606, dp: 0.0691213607788086
Epoch: 33, loss: 0.6714813709259033, dp: 0.06975340843200684
Epoch: 34, loss: 0.6710319519042969, dp: 0.0702979564666748
Epoch: 35, loss: 0.6705741882324219, dp: 0.07076394557952881
Epoch: 36, loss: 0.6701083779335022, dp: 0.07116150856018066
Epoch: 37, loss: 0.669636607170105, dp: 0.07149994373321533
Epoch: 38, loss: 0.6691610217094421, dp: 0.07178771495819092
Epoch: 39, loss: 0.6686834096908569, dp: 0.07203388214111328
Epoch: 40, loss: 0.6682056188583374, dp: 0.0722469687461853
Epoch: 41, loss: 0.6677289009094238, dp: 0.07243525981903076
Epoch: 42, loss: 0.667254626750946, dp: 0.07260710000991821
Epoch: 43, loss: 0.6667835712432861, dp: 0.07276999950408936
Epoch: 44, loss: 0.6663164496421814, dp: 0.07293164730072021
Epoch: 45, loss: 0.6658533811569214, dp: 0.07309883832931519
Epoch: 46, loss: 0.6653945446014404, dp: 0.07327830791473389
Epoch: 47, loss: 0.6649397015571594, dp: 0.0734756588935852
Epoch: 48, loss: 0.6644884347915649, dp: 0.07369661331176758
Epoch: 49, loss: 0.6640403270721436, dp: 0.07394576072692871
Epoch: 50, loss: 0.6635947227478027, dp: 0.07422733306884766
Epoch: 51, loss: 0.6631510257720947, dp: 0.07454437017440796
Epoch: 52, loss: 0.6628453135490417, dp: 0.07504117488861084
Epoch: 53, loss: 0.6625917553901672, dp: 0.07558834552764893
Epoch: 54, loss: 0.6623181104660034, dp: 0.07614338397979736
Epoch: 55, loss: 0.6620256900787354, dp: 0.07671010494232178
Epoch: 56, loss: 0.6617173552513123, dp: 0.07728838920593262
Epoch: 57, loss: 0.6614043116569519, dp: 0.07785654067993164
Epoch: 58, loss: 0.6610796451568604, dp: 0.07842350006103516
Epoch: 59, loss: 0.6607442498207092, dp: 0.07899010181427002
Epoch: 60, loss: 0.6603990793228149, dp: 0.07955652475357056
Epoch: 61, loss: 0.6600450277328491, dp: 0.08012241125106812
Epoch: 62, loss: 0.6596829295158386, dp: 0.08068704605102539
Epoch: 63, loss: 0.6593135595321655, dp: 0.08124935626983643
Epoch: 64, loss: 0.6589376330375671, dp: 0.08180773258209229
Epoch: 65, loss: 0.6585558652877808, dp: 0.08236086368560791
Epoch: 66, loss: 0.6581688523292542, dp: 0.08290684223175049
Epoch: 67, loss: 0.6577771306037903, dp: 0.08344399929046631
Epoch: 68, loss: 0.6574320793151855, dp: 0.08401530981063843
Epoch: 69, loss: 0.6571431756019592, dp: 0.08470660448074341
Epoch: 70, loss: 0.6568371653556824, dp: 0.08542871475219727
Epoch: 71, loss: 0.6565203666687012, dp: 0.08615553379058838
Epoch: 72, loss: 0.6561905145645142, dp: 0.08688771724700928
Epoch: 73, loss: 0.6558488607406616, dp: 0.08762180805206299
Epoch: 74, loss: 0.65549635887146, dp: 0.08835643529891968
Epoch: 75, loss: 0.6551340818405151, dp: 0.08908939361572266
Epoch: 76, loss: 0.6547629237174988, dp: 0.08981943130493164
Epoch: 77, loss: 0.6544535160064697, dp: 0.09061932563781738
Epoch: 78, loss: 0.6541627645492554, dp: 0.09137928485870361
Epoch: 79, loss: 0.6538523435592651, dp: 0.09210765361785889
Epoch: 80, loss: 0.6535260677337646, dp: 0.09280455112457275
Epoch: 81, loss: 0.6531944274902344, dp: 0.09344327449798584
Epoch: 82, loss: 0.6528521776199341, dp: 0.09403371810913086
Epoch: 83, loss: 0.6525000333786011, dp: 0.09457921981811523
Epoch: 84, loss: 0.652138888835907, dp: 0.09508275985717773
Epoch: 85, loss: 0.6518597602844238, dp: 0.09564340114593506
Epoch: 86, loss: 0.6515651345252991, dp: 0.09620797634124756
Epoch: 87, loss: 0.6512539982795715, dp: 0.09679919481277466
Epoch: 88, loss: 0.6509299874305725, dp: 0.09739184379577637
Epoch: 89, loss: 0.650594174861908, dp: 0.09798645973205566
Epoch: 90, loss: 0.6502466797828674, dp: 0.09860563278198242
Epoch: 91, loss: 0.6498894691467285, dp: 0.09922534227371216
Epoch: 92, loss: 0.6495950818061829, dp: 0.09992420673370361
Epoch: 93, loss: 0.6493034362792969, dp: 0.10058420896530151
Epoch: 94, loss: 0.6489962339401245, dp: 0.10119062662124634
Epoch: 95, loss: 0.6486777067184448, dp: 0.10174638032913208
Epoch: 96, loss: 0.6483432650566101, dp: 0.10228461027145386
Epoch: 97, loss: 0.6479994058609009, dp: 0.10277950763702393
Epoch: 98, loss: 0.6476455330848694, dp: 0.10323858261108398
Epoch: 99, loss: 0.6473514437675476, dp: 0.10378742218017578

With fairret:

[8]:
from fairret.loss import NormLoss
fairness_strength = 1
norm_loss = NormLoss(PositiveRate())

model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        # Add the fairret penalty on the positive rates, weighted by fairness_strength.
        loss += fairness_strength * norm_loss(logit, batch_sens)
        loss.backward()

        pred = torch.sigmoid(logit)
        demographic_parity.update(pred, batch_sens)

        optimizer.step()
        losses.append(loss.item())
    dp_for_epoch = demographic_parity.compute()
    demographic_parity.reset()
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}, dp: {dp_for_epoch}")
Epoch: 0, loss: 0.7429860830307007, dp: 0.0319594144821167
Epoch: 1, loss: 0.7368547916412354, dp: 0.02823483943939209
Epoch: 2, loss: 0.7306909561157227, dp: 0.024389266967773438
Epoch: 3, loss: 0.7244950532913208, dp: 0.020473718643188477
Epoch: 4, loss: 0.7182670831680298, dp: 0.016491293907165527
Epoch: 5, loss: 0.7120081186294556, dp: 0.012444257736206055
Epoch: 6, loss: 0.7057176828384399, dp: 0.008333325386047363
Epoch: 7, loss: 0.6993964910507202, dp: 0.004159450531005859
Epoch: 8, loss: 0.6933507323265076, dp: 7.653236389160156e-05
Epoch: 9, loss: 0.6987630128860474, dp: 0.0022329092025756836
Epoch: 10, loss: 0.700627326965332, dp: 0.0029762983322143555
Epoch: 11, loss: 0.6999425292015076, dp: 0.0027066469192504883
Epoch: 12, loss: 0.697374701499939, dp: 0.0016865730285644531
Epoch: 13, loss: 0.693393886089325, dp: 9.85860824584961e-05
Epoch: 14, loss: 0.6960413455963135, dp: 0.0019237995147705078
Epoch: 15, loss: 0.6981306076049805, dp: 0.003303050994873047
Epoch: 16, loss: 0.6993909478187561, dp: 0.004129886627197266
Epoch: 17, loss: 0.6999239921569824, dp: 0.004477024078369141
Epoch: 18, loss: 0.6998161673545837, dp: 0.004403471946716309
Epoch: 19, loss: 0.6991404891014099, dp: 0.0039577484130859375
Epoch: 20, loss: 0.697959303855896, dp: 0.0031795501708984375
Epoch: 21, loss: 0.696326732635498, dp: 0.0021026134490966797
Epoch: 22, loss: 0.6942894458770752, dp: 0.0007546544075012207
Epoch: 23, loss: 0.6952502727508545, dp: 0.0008406639099121094
Epoch: 24, loss: 0.6972097754478455, dp: 0.0016245245933532715
Epoch: 25, loss: 0.6974166631698608, dp: 0.0017071962356567383
Epoch: 26, loss: 0.6961178779602051, dp: 0.001187741756439209
Epoch: 27, loss: 0.6935380101203918, dp: 0.00015437602996826172
Epoch: 28, loss: 0.6951382160186768, dp: 0.0013152360916137695
Epoch: 29, loss: 0.6966246366500854, dp: 0.002297043800354004
Epoch: 30, loss: 0.6974535584449768, dp: 0.002843618392944336
Epoch: 31, loss: 0.6976898908615112, dp: 0.0030000805854797363
Epoch: 32, loss: 0.697391152381897, dp: 0.0028044581413269043
Epoch: 33, loss: 0.6966089010238647, dp: 0.0022897720336914062
Epoch: 34, loss: 0.6953877210617065, dp: 0.0014837980270385742
Epoch: 35, loss: 0.6937685608863831, dp: 0.00041115283966064453
Epoch: 36, loss: 0.6954146027565002, dp: 0.000906825065612793
Epoch: 37, loss: 0.6969265937805176, dp: 0.0015110969543457031
Epoch: 38, loss: 0.6968545317649841, dp: 0.001482248306274414
Epoch: 39, loss: 0.6953892707824707, dp: 0.0008966922760009766
Epoch: 40, loss: 0.693409264087677, dp: 0.00017368793487548828
Epoch: 41, loss: 0.6943631172180176, dp: 0.000807642936706543
Epoch: 42, loss: 0.6947269439697266, dp: 0.001049339771270752
Epoch: 43, loss: 0.6945565342903137, dp: 0.0009363889694213867
Epoch: 44, loss: 0.6939018368721008, dp: 0.0005016326904296875
Epoch: 45, loss: 0.6937127113342285, dp: 0.00022614002227783203
Epoch: 46, loss: 0.6939553618431091, dp: 0.00032335519790649414
Epoch: 47, loss: 0.6933559775352478, dp: 0.00013881921768188477
Epoch: 48, loss: 0.6934863924980164, dp: 0.0002257227897644043
Epoch: 49, loss: 0.6932146549224854, dp: 2.6941299438476562e-05
Epoch: 50, loss: 0.6935896873474121, dp: 0.00029468536376953125
Epoch: 51, loss: 0.6935305595397949, dp: 0.00025534629821777344
Epoch: 52, loss: 0.6934249401092529, dp: 0.00011110305786132812
Epoch: 53, loss: 0.6933110952377319, dp: 0.0001093149185180664
Epoch: 54, loss: 0.6932008266448975, dp: 2.1457672119140625e-05
Epoch: 55, loss: 0.6937585473060608, dp: 0.00040733814239501953
Epoch: 56, loss: 0.6938467025756836, dp: 0.0004661083221435547
Epoch: 57, loss: 0.6934322118759155, dp: 0.0001900196075439453
Epoch: 58, loss: 0.6941245198249817, dp: 0.00039076805114746094
Epoch: 59, loss: 0.6940488815307617, dp: 0.0003604888916015625
Epoch: 60, loss: 0.6934659481048584, dp: 0.00021249055862426758
Epoch: 61, loss: 0.6937500834465027, dp: 0.0004018545150756836
Epoch: 62, loss: 0.6935136318206787, dp: 0.0002442598342895508
Epoch: 63, loss: 0.6937164068222046, dp: 0.0002276301383972168
Epoch: 64, loss: 0.6934037208557129, dp: 0.00010263919830322266
Epoch: 65, loss: 0.6939764618873596, dp: 0.0005524754524230957
Epoch: 66, loss: 0.6943734884262085, dp: 0.000816643238067627
Epoch: 67, loss: 0.6942408680915833, dp: 0.0007283687591552734
Epoch: 68, loss: 0.6936281323432922, dp: 0.0003205537796020508
Epoch: 69, loss: 0.6940924525260925, dp: 0.00037801265716552734
Epoch: 70, loss: 0.6942865252494812, dp: 0.00045561790466308594
Epoch: 71, loss: 0.6931780576705933, dp: 2.0623207092285156e-05
Epoch: 72, loss: 0.6933311820030212, dp: 0.00012260675430297852
Epoch: 73, loss: 0.6934307217597961, dp: 0.00011348724365234375
Epoch: 74, loss: 0.6934766173362732, dp: 0.0002194046974182129
Epoch: 75, loss: 0.6934362053871155, dp: 0.00019252300262451172
Epoch: 76, loss: 0.693547248840332, dp: 0.00016003847122192383
Epoch: 77, loss: 0.6932504177093506, dp: 6.872415542602539e-05
Epoch: 78, loss: 0.6932784914970398, dp: 5.257129669189453e-05
Epoch: 79, loss: 0.6937205791473389, dp: 0.00038182735443115234
Epoch: 80, loss: 0.6938185095787048, dp: 0.0004469156265258789
Epoch: 81, loss: 0.6934152841567993, dp: 0.0001785755157470703
Epoch: 82, loss: 0.6941288113594055, dp: 0.00039261579513549805
Epoch: 83, loss: 0.6940430402755737, dp: 0.00035834312438964844
Epoch: 84, loss: 0.6934714317321777, dp: 0.00021594762802124023
Epoch: 85, loss: 0.6937586069107056, dp: 0.00040709972381591797
Epoch: 86, loss: 0.6935272812843323, dp: 0.0002530813217163086
Epoch: 87, loss: 0.6936823129653931, dp: 0.00021409988403320312
Epoch: 88, loss: 0.6933677196502686, dp: 8.821487426757812e-05
Epoch: 89, loss: 0.6939976215362549, dp: 0.0005662441253662109
Epoch: 90, loss: 0.6943947076797485, dp: 0.0008302927017211914
Epoch: 91, loss: 0.6942631006240845, dp: 0.000742793083190918
Epoch: 92, loss: 0.693653404712677, dp: 0.00033724308013916016
Epoch: 93, loss: 0.6940414309501648, dp: 0.00035768747329711914
Epoch: 94, loss: 0.6942345499992371, dp: 0.00043487548828125
Epoch: 95, loss: 0.6932075023651123, dp: 4.029273986816406e-05
Epoch: 96, loss: 0.6933603882789612, dp: 0.00014215707778930664
Epoch: 97, loss: 0.6933803558349609, dp: 9.328126907348633e-05
Epoch: 98, loss: 0.6935058832168579, dp: 0.00023895502090454102
Epoch: 99, loss: 0.6934655904769897, dp: 0.00021207332611083984

Clarification on the exact value of the metric

Though it is generally agreed that demographic parity is achieved when the positive rates are equal, there is a lot of ambiguity in how to measure the extent to which demographic parity is violated. In fairret, we assess this violation by comparing the statistic value for each sensitive feature to the statistic value for the entire population.

In our specific example here, it just means that we compare the mean prediction value for each of the two groups to the overall mean prediction value.
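To make this concrete, here is a hand computation of the two quantities being compared on our toy data. This is a sketch of the semantics only, not the library's internal code:

with torch.no_grad():
    pred = torch.sigmoid(model(feat))                     # predictions for all samples
overall_rate = pred.mean()                                # statistic for the entire population
group_rates = (sens * pred).sum(dim=0) / sens.sum(dim=0)  # statistic per one-hot group
print(f"Group positive rates: {group_rates}, overall: {overall_rate}")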

There is then one more ingredient: how the gap between these values is actually computed. We provide a few options, such as the absolute difference (gap_abs_max) and the relative absolute difference (gap_relative_abs_max). The former takes the maximum absolute value of the gap over the sensitive features, while the latter divides this maximum by the overall statistic. The latter is the default behavior of the LinearFractionalParity class.

To use another gap function, simply pass it as an argument to the LinearFractionalParity class. Of course, you’re also free to implement your own!
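For instance, a sketch that swaps in the absolute gap. We assume here that gap_abs_max can be imported from fairret.metric and that the keyword argument is named gap_fn; check the API reference if these assumptions don't hold:

from fairret.metric import gap_abs_max  # assumed import location

dp_abs = LinearFractionalParity(
    PositiveRate(),
    stat_shape=(n_sensitive_features,),
    gap_fn=gap_abs_max,  # absolute rather than relative gap (assumed keyword name)
)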

What’s next?

In larger pipelines, you will likely want to define separate metrics for train, validation, and test set results.
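One convenient way to do this is the standard torchmetrics clone() method, e.g.:

train_dp = demographic_parity.clone()
val_dp = demographic_parity.clone()
test_dp = demographic_parity.clone()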

Also, you may want to assess many fairness definitions at once. These could all be defined as separate metrics, or you could make use of the StackedLinearFractionalStatistic, which keeps track of many statistics at the same time (see Stacked Statistic.ipynb). However, keep in mind that you then won’t get scalar values out of the compute method, but a tensor that stacks the violations of all statistics.
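As a sketch of the separate-metrics route, assuming fairret also provides a TruePositiveRate statistic (the statistic behind equal opportunity):

from fairret.statistic import TruePositiveRate  # assumed to exist alongside PositiveRate

dp = LinearFractionalParity(PositiveRate(), stat_shape=(n_sensitive_features,))
eo = LinearFractionalParity(TruePositiveRate(), stat_shape=(n_sensitive_features,))
# dp.update(pred, sens), but eo.update(pred, sens, label): the true positive
# rate also depends on the labels.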