Stacked Statistic

In this notebook, we’ll show how to use the StackedLinearFractionalStatistic, a helper class that combines multiple linear-fractional statistics. For example, the well-known fairness definition of equalised odds enforces equality of both the true positive rate and the false positive rate across sensitive groups. Hence, we can keep track of both statistics in a single vector-valued statistic:

[1]:
from fairret.statistic import TruePositiveRate, FalsePositiveRate, StackedLinearFractionalStatistic

equalised_odds_stats = StackedLinearFractionalStatistic(TruePositiveRate(), FalsePositiveRate())
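
As a reminder, both rates are computed per sensitive group: the true positive rate TPR = P(ŷ = 1 | y = 1) is the fraction of positives that receive a positive prediction, and the false positive rate FPR = P(ŷ = 1 | y = 0) is the fraction of negatives that do.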

Let’s quickly try it out…

[2]:
import torch
torch.manual_seed(0)

# A tiny toy dataset: 4 samples with 2 features, a one-hot sensitive attribute, and a binary label.
feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
label = torch.tensor([[0.], [1.], [0.], [1.]])

from fairret.loss import NormLoss

# The fairret loss term, built on the stacked equalised odds statistic.
norm_loss = NormLoss(equalised_odds_stats)

# Hyperparameters; the batch size exceeds the 4 toy samples, so each epoch runs a single batch.
h_layer_dim = 16
lr = 1e-3
batch_size = 1024

# Build a small MLP with one hidden layer, together with its Adam optimizer.
def build_model():
    _model = torch.nn.Sequential(
        torch.nn.Linear(feat.shape[1], h_layer_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_layer_dim, 1)
    )
    _optimizer = torch.optim.Adam(_model.parameters(), lr=lr)
    return _model, _optimizer

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)
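
If the stacked statistic simply evaluates each of its component statistics and stacks the results, then its first row should match the TPR per group and its second row the FPR per group. A quick sanity check along those lines might look like this (a sketch; it assumes the individual statistics accept the same (pred, sens, label) call as the stacked one):

# Sketch: compare the stacked statistic against its components on random predictions.
dummy_pred = torch.rand(4, 1)

stacked = equalised_odds_stats(dummy_pred, sens, label)
tpr_per_group = TruePositiveRate()(dummy_pred, sens, label)
fpr_per_group = FalsePositiveRate()(dummy_pred, sens, label)

print(stacked.shape)                               # expected: one row per statistic, one column per group
print(torch.allclose(stacked[0], tpr_per_group))   # expected: True
print(torch.allclose(stacked[1], fpr_per_group))   # expected: True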

First, we train without the fairret…

[3]:
import numpy as np

nb_epochs = 100
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

pred = torch.sigmoid(model(feat))
eo_per_group = equalised_odds_stats(pred, sens, label)
absolute_diff = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])

print(f"The TPR and FPR for group 0 are {eo_per_group[:, 0]}")
print(f"The TPR and FPR for group 1 are {eo_per_group[:, 1]}")
print(f"The absolute differences are {torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])}")
Epoch: 0, loss: 0.7091795206069946
Epoch: 1, loss: 0.7061765193939209
Epoch: 2, loss: 0.7033581733703613
Epoch: 3, loss: 0.7007156610488892
Epoch: 4, loss: 0.6982340812683105
Epoch: 5, loss: 0.6959078907966614
Epoch: 6, loss: 0.6937355995178223
Epoch: 7, loss: 0.6917158365249634
Epoch: 8, loss: 0.6898466944694519
Epoch: 9, loss: 0.6881252527236938
Epoch: 10, loss: 0.6865478754043579
Epoch: 11, loss: 0.6851094961166382
Epoch: 12, loss: 0.6838041543960571
Epoch: 13, loss: 0.6826250553131104
Epoch: 14, loss: 0.6815641522407532
Epoch: 15, loss: 0.6806124448776245
Epoch: 16, loss: 0.6797604560852051
Epoch: 17, loss: 0.6789975762367249
Epoch: 18, loss: 0.6783132553100586
Epoch: 19, loss: 0.6776963472366333
Epoch: 20, loss: 0.6771360039710999
Epoch: 21, loss: 0.6766215562820435
Epoch: 22, loss: 0.6761429309844971
Epoch: 23, loss: 0.6756909489631653
Epoch: 24, loss: 0.6752569675445557
Epoch: 25, loss: 0.6748337745666504
Epoch: 26, loss: 0.674415111541748
Epoch: 27, loss: 0.673996090888977
Epoch: 28, loss: 0.6735726594924927
Epoch: 29, loss: 0.6731564998626709
Epoch: 30, loss: 0.6727579236030579
Epoch: 31, loss: 0.672345757484436
Epoch: 32, loss: 0.6719199419021606
Epoch: 33, loss: 0.6714813709259033
Epoch: 34, loss: 0.6710319519042969
Epoch: 35, loss: 0.6705741882324219
Epoch: 36, loss: 0.6701083779335022
Epoch: 37, loss: 0.669636607170105
Epoch: 38, loss: 0.6691610217094421
Epoch: 39, loss: 0.6686834096908569
Epoch: 40, loss: 0.6682056188583374
Epoch: 41, loss: 0.6677289009094238
Epoch: 42, loss: 0.667254626750946
Epoch: 43, loss: 0.6667835712432861
Epoch: 44, loss: 0.6663164496421814
Epoch: 45, loss: 0.6658533811569214
Epoch: 46, loss: 0.6653945446014404
Epoch: 47, loss: 0.6649397015571594
Epoch: 48, loss: 0.6644884347915649
Epoch: 49, loss: 0.6640403270721436
Epoch: 50, loss: 0.6635947227478027
Epoch: 51, loss: 0.6631510257720947
Epoch: 52, loss: 0.6628453135490417
Epoch: 53, loss: 0.6625917553901672
Epoch: 54, loss: 0.6623181104660034
Epoch: 55, loss: 0.6620256900787354
Epoch: 56, loss: 0.6617173552513123
Epoch: 57, loss: 0.6614043116569519
Epoch: 58, loss: 0.6610796451568604
Epoch: 59, loss: 0.6607442498207092
Epoch: 60, loss: 0.6603990793228149
Epoch: 61, loss: 0.6600450277328491
Epoch: 62, loss: 0.6596829295158386
Epoch: 63, loss: 0.6593135595321655
Epoch: 64, loss: 0.6589376330375671
Epoch: 65, loss: 0.6585558652877808
Epoch: 66, loss: 0.6581688523292542
Epoch: 67, loss: 0.6577771306037903
Epoch: 68, loss: 0.6574320793151855
Epoch: 69, loss: 0.6571431756019592
Epoch: 70, loss: 0.6568371653556824
Epoch: 71, loss: 0.6565203666687012
Epoch: 72, loss: 0.6561905145645142
Epoch: 73, loss: 0.6558488607406616
Epoch: 74, loss: 0.65549635887146
Epoch: 75, loss: 0.6551340818405151
Epoch: 76, loss: 0.6547629237174988
Epoch: 77, loss: 0.6544535160064697
Epoch: 78, loss: 0.6541627645492554
Epoch: 79, loss: 0.6538523435592651
Epoch: 80, loss: 0.6535260677337646
Epoch: 81, loss: 0.6531944274902344
Epoch: 82, loss: 0.6528521776199341
Epoch: 83, loss: 0.6525000333786011
Epoch: 84, loss: 0.652138888835907
Epoch: 85, loss: 0.6518597602844238
Epoch: 86, loss: 0.6515651345252991
Epoch: 87, loss: 0.6512539982795715
Epoch: 88, loss: 0.6509299874305725
Epoch: 89, loss: 0.650594174861908
Epoch: 90, loss: 0.6502466797828674
Epoch: 91, loss: 0.6498894691467285
Epoch: 92, loss: 0.6495950818061829
Epoch: 93, loss: 0.6493034362792969
Epoch: 94, loss: 0.6489962339401245
Epoch: 95, loss: 0.6486777067184448
Epoch: 96, loss: 0.6483432650566101
Epoch: 97, loss: 0.6479994058609009
Epoch: 98, loss: 0.6476455330848694
Epoch: 99, loss: 0.6473514437675476
The TPR and FPR for group 0 are tensor([0.5537, 0.4752], grad_fn=<SelectBackward0>)
The TPR and FPR for group 1 are tensor([0.6604, 0.6083], grad_fn=<SelectBackward0>)
The absolute differences are tensor([0.1067, 0.1331], grad_fn=<AbsBackward0>)
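
So without any fairness term, the two groups differ by about 0.11 in TPR and 0.13 in FPR.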

And now with the fairret added to the loss…

[4]:
import numpy as np

nb_epochs = 100
fairness_strength = 1
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)  # add the weighted fairret term
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

pred = torch.sigmoid(model(feat))
eo_per_group = equalised_odds_stats(pred, sens, label)
absolute_diff = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])

print(f"The TPR and FPR for group 0 are {eo_per_group[:, 0]}")
print(f"The TPR and FPR for group 1 are {eo_per_group[:, 1]}")
print(f"The absolute differences are {torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])}")
Epoch: 0, loss: 0.8069422245025635
Epoch: 1, loss: 0.7932361960411072
Epoch: 2, loss: 0.7793688178062439
Epoch: 3, loss: 0.7653393149375916
Epoch: 4, loss: 0.7511466145515442
Epoch: 5, loss: 0.7367900013923645
Epoch: 6, loss: 0.7222684025764465
Epoch: 7, loss: 0.7075802683830261
Epoch: 8, loss: 0.693701446056366
Epoch: 9, loss: 0.7050990462303162
Epoch: 10, loss: 0.7107114195823669
Epoch: 11, loss: 0.7118827104568481
Epoch: 12, loss: 0.7095987200737
Epoch: 13, loss: 0.704603374004364
Epoch: 14, loss: 0.6974690556526184
Epoch: 15, loss: 0.6966591477394104
Epoch: 16, loss: 0.70110023021698
Epoch: 17, loss: 0.7034574151039124
Epoch: 18, loss: 0.7040092945098877
Epoch: 19, loss: 0.7029833793640137
Epoch: 20, loss: 0.7005670070648193
Epoch: 21, loss: 0.6969154477119446
Epoch: 22, loss: 0.6944364905357361
Epoch: 23, loss: 0.6974213719367981
Epoch: 24, loss: 0.6976830959320068
Epoch: 25, loss: 0.695574164390564
Epoch: 26, loss: 0.6945269703865051
Epoch: 27, loss: 0.6960493922233582
Epoch: 28, loss: 0.6960320472717285
Epoch: 29, loss: 0.6946355700492859
Epoch: 30, loss: 0.6946574449539185
Epoch: 31, loss: 0.6954229474067688
Epoch: 32, loss: 0.6938773989677429
Epoch: 33, loss: 0.6954055428504944
Epoch: 34, loss: 0.6965784430503845
Epoch: 35, loss: 0.6962950229644775
Epoch: 36, loss: 0.6946983933448792
Epoch: 37, loss: 0.6947646141052246
Epoch: 38, loss: 0.6957594752311707
Epoch: 39, loss: 0.6944624781608582
Epoch: 40, loss: 0.6947464942932129
Epoch: 41, loss: 0.6957638263702393
Epoch: 42, loss: 0.6953589916229248
Epoch: 43, loss: 0.6936694383621216
Epoch: 44, loss: 0.6961650252342224
Epoch: 45, loss: 0.6972649693489075
Epoch: 46, loss: 0.6960750222206116
Epoch: 47, loss: 0.6933992505073547
Epoch: 48, loss: 0.6943512558937073
Epoch: 49, loss: 0.6938944458961487
Epoch: 50, loss: 0.6944283843040466
Epoch: 51, loss: 0.6942562460899353
Epoch: 52, loss: 0.6940963864326477
Epoch: 53, loss: 0.6944071054458618
Epoch: 54, loss: 0.693374752998352
Epoch: 55, loss: 0.6957550048828125
Epoch: 56, loss: 0.6961789727210999
Epoch: 57, loss: 0.6943976879119873
Epoch: 58, loss: 0.6950995922088623
Epoch: 59, loss: 0.6964230537414551
Epoch: 60, loss: 0.6963126063346863
Epoch: 61, loss: 0.6949063539505005
Epoch: 62, loss: 0.6942113041877747
Epoch: 63, loss: 0.6950266361236572
Epoch: 64, loss: 0.6936017870903015
Epoch: 65, loss: 0.6954660415649414
Epoch: 66, loss: 0.6965658664703369
Epoch: 67, loss: 0.696255624294281
Epoch: 68, loss: 0.6946702003479004
Epoch: 69, loss: 0.694717288017273
Epoch: 70, loss: 0.6957194209098816
Epoch: 71, loss: 0.6944576501846313
Epoch: 72, loss: 0.6946854591369629
Epoch: 73, loss: 0.6956840753555298
Epoch: 74, loss: 0.6952812671661377
Epoch: 75, loss: 0.6936118006706238
Epoch: 76, loss: 0.6961743235588074
Epoch: 77, loss: 0.6972628831863403
Epoch: 78, loss: 0.6960713267326355
Epoch: 79, loss: 0.6933831572532654
Epoch: 80, loss: 0.6943389773368835
Epoch: 81, loss: 0.6938958168029785
Epoch: 82, loss: 0.6943807601928711
Epoch: 83, loss: 0.6941943168640137
Epoch: 84, loss: 0.6941384673118591
Epoch: 85, loss: 0.6944611668586731
Epoch: 86, loss: 0.6934477686882019
Epoch: 87, loss: 0.6956196427345276
Epoch: 88, loss: 0.6960253715515137
Epoch: 89, loss: 0.6942300200462341
Epoch: 90, loss: 0.6952304840087891
Epoch: 91, loss: 0.6965633630752563
Epoch: 92, loss: 0.6964669823646545
Epoch: 93, loss: 0.6950795650482178
Epoch: 94, loss: 0.693950891494751
Epoch: 95, loss: 0.6947492361068726
Epoch: 96, loss: 0.6933133602142334
Epoch: 97, loss: 0.6956915855407715
Epoch: 98, loss: 0.6967988014221191
Epoch: 99, loss: 0.6964995861053467
The TPR and FPR for group 0 are tensor([0.5012, 0.5009], grad_fn=<SelectBackward0>)
The TPR and FPR for group 1 are tensor([0.5017, 0.5014], grad_fn=<SelectBackward0>)
The absolute differences are tensor([0.0005, 0.0005], grad_fn=<AbsBackward0>)
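
Adding the fairret term shrinks the absolute TPR and FPR gaps between the groups from roughly 0.11 and 0.13 to about 0.0005 each. The fairness_strength of 1 used above is just one choice: a larger value pushes the group statistics closer together, while a smaller one favours the binary cross-entropy objective. One way to explore this trade-off is to rerun the loop above for a few strengths (a sketch reusing the objects defined earlier; the example values are arbitrary):

# Sketch: compare a few (arbitrary) fairness strengths using the same toy setup.
for strength in [0.1, 1., 10.]:
    model, optimizer = build_model()
    for epoch in range(nb_epochs):
        for batch_feat, batch_sens, batch_label in dataloader:
            optimizer.zero_grad()
            logit = model(batch_feat)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
            loss += strength * norm_loss(logit, batch_sens, batch_label)
            loss.backward()
            optimizer.step()

    pred = torch.sigmoid(model(feat))
    eo_per_group = equalised_odds_stats(pred, sens, label)
    gap = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])
    print(f"strength={strength}: absolute TPR/FPR gaps = {gap.detach()}")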