Custom Statistic

In this notebook, we’ll implement a new statistic as a subclass of LinearFractionalStatistic. All fairrets work with any LinearFractionalStatistic, so they will also work with our new statistic.

We take inspiration from the paper “Generalizing Group Fairness in Machine Learning via Utilities” by Blandin and Kash. For the well-known German Credit Dataset, they propose the following cost for predictions \(\hat{Y}\) and ground truth labels \(Y\):

\[\begin{split}C = \begin{cases} 0 & \text{ if } \hat{Y} = Y \\ 1 & \text{ if } \hat{Y} = 0 \wedge Y = 1 \\ 5 & \text{ if } \hat{Y} = 1 \wedge Y = 0 \end{cases}\end{split}\]

The costs are motivated by the fact that a loan applicant who receives a loan \((\hat{Y} = 1)\) but will not repay it \((Y = 0)\) will have to default, which is considered far worse than rejecting an applicant \((\hat{Y} = 0)\) who would have repaid the loan \((Y = 1)\).

The statistic in this case is the average cost \(C\) incurred over all individuals in a sensitive group. Hence, the statistic is canonically formalized as

\[\gamma(k, f) = \frac{\mathbb{E}[SC]}{\mathbb{E}[S]} = \frac{\mathbb{E}[S(1 \cdot Y(1 - f(X)) + 5 \cdot (1 - Y)f(X))]}{\mathbb{E}[S]} = \frac{\mathbb{E}[S(Y + (5 - 6Y)f(X))]}{\mathbb{E}[S]}\]

where we filled in \(\hat{Y}\) with the probabilistic \(f(X)\).
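
As a quick sanity check of this simplification, the expanded and simplified expressions inside the expectation agree elementwise. This is a minimal sketch; the values of \(Y\) and \(f(X)\) below are made up purely for illustration.

import torch

# Made-up labels and predicted probabilities, only to check the algebra numerically.
y = torch.tensor([0., 1., 0., 1.])
f = torch.tensor([0.2, 0.9, 0.7, 0.4])

expanded = 1 * y * (1 - f) + 5 * (1 - y) * f
simplified = y + (5 - 6 * y) * f
assert torch.allclose(expanded, simplified)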

The canonical form allows us to identify how the statistic has a linear-fractional form with respect to \(f\). Ignoring \(S\) for a moment, the intercept of the numerator is \(Y\) and its slope is \((5 - 6Y)\). The denominator does not depend on \(f\) at all, so its intercept is \(1\) and its slope is \(0\).
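
Concretely, if we write a general linear-fractional statistic as

\[\gamma(k, f) = \frac{\mathbb{E}[S(a(Y) + b(Y)\,f(X))]}{\mathbb{E}[S(c(Y) + d(Y)\,f(X))]}\]

then our statistic corresponds to \(a(Y) = Y\), \(b(Y) = 5 - 6Y\), \(c(Y) = 1\), and \(d(Y) = 0\). The notation \(a, b, c, d\) is only introduced here for clarity; these four terms map onto the num_intercept, num_slope, denom_intercept, and denom_slope methods implemented below.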

The statistic is then implemented as:

[1]:
import torch
from fairret.statistic import LinearFractionalStatistic

class CustomCost(LinearFractionalStatistic):
    # Numerator E[S(Y + (5 - 6Y)f(X))]: the expected cost within the group.
    def num_intercept(self, label: torch.Tensor) -> torch.Tensor:
        return label

    def num_slope(self, label: torch.Tensor) -> torch.Tensor:
        return 5 - 6 * label

    # Denominator E[S]: the group size, which does not depend on f(X).
    def denom_intercept(self, label: torch.Tensor) -> torch.Tensor:
        return 1

    def denom_slope(self, label: torch.Tensor) -> torch.Tensor:
        return 0.

Let’s quickly try it out…

[2]:
import torch
torch.manual_seed(0)

feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
label = torch.tensor([[0.], [1.], [0.], [1.]])

from fairret.loss import NormLoss

statistic = CustomCost()
norm_loss = NormLoss(statistic)

h_layer_dim = 16
lr = 1e-3
batch_size = 1024

def build_model():
    model = torch.nn.Sequential(
        torch.nn.Linear(feat.shape[1], h_layer_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_layer_dim, 1)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)
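
Before training anything, we can sanity-check CustomCost on the toy data. With a constant prediction of 0.5, every individual with \(Y = 0\) incurs a cost of \(5 \cdot 0.5 = 2.5\) and every individual with \(Y = 1\) incurs \(1 \cdot 0.5 = 0.5\), so both groups should average to 1.5. The manual computation below is only an illustrative sketch of the canonical form, next to the library call used throughout this notebook.

# Constant prediction of 0.5 for all four individuals.
dummy_pred = torch.full((4, 1), 0.5)
print(statistic(dummy_pred, sens, label))

# The same quantity, computed directly from the canonical form per group.
num = (sens * (label + (5 - 6 * label) * dummy_pred)).sum(dim=0)
denom = sens.sum(dim=0)
print(num / denom)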

Without fairret…

[3]:
import numpy as np

nb_epochs = 100
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

pred = torch.sigmoid(model(feat))
stat_per_group = statistic(pred, sens, label)
absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])

print(f"The {statistic.__class__.__name__} for group 0 is {stat_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group 1 is {stat_per_group[1]}")
print(f"The absolute difference is {torch.abs(stat_per_group[0] - stat_per_group[1])}")
Epoch: 0, loss: 0.7091795206069946
Epoch: 1, loss: 0.7061765193939209
Epoch: 2, loss: 0.7033581733703613
Epoch: 3, loss: 0.7007156610488892
Epoch: 4, loss: 0.6982340812683105
Epoch: 5, loss: 0.6959078907966614
Epoch: 6, loss: 0.6937355995178223
Epoch: 7, loss: 0.6917158365249634
Epoch: 8, loss: 0.6898466944694519
Epoch: 9, loss: 0.6881252527236938
Epoch: 10, loss: 0.6865478754043579
Epoch: 11, loss: 0.6851094961166382
Epoch: 12, loss: 0.6838041543960571
Epoch: 13, loss: 0.6826250553131104
Epoch: 14, loss: 0.6815641522407532
Epoch: 15, loss: 0.6806124448776245
Epoch: 16, loss: 0.6797604560852051
Epoch: 17, loss: 0.6789975762367249
Epoch: 18, loss: 0.6783132553100586
Epoch: 19, loss: 0.6776963472366333
Epoch: 20, loss: 0.6771360039710999
Epoch: 21, loss: 0.6766215562820435
Epoch: 22, loss: 0.6761429309844971
Epoch: 23, loss: 0.6756909489631653
Epoch: 24, loss: 0.6752569675445557
Epoch: 25, loss: 0.6748337745666504
Epoch: 26, loss: 0.674415111541748
Epoch: 27, loss: 0.673996090888977
Epoch: 28, loss: 0.6735726594924927
Epoch: 29, loss: 0.6731564998626709
Epoch: 30, loss: 0.6727579236030579
Epoch: 31, loss: 0.672345757484436
Epoch: 32, loss: 0.6719199419021606
Epoch: 33, loss: 0.6714813709259033
Epoch: 34, loss: 0.6710319519042969
Epoch: 35, loss: 0.6705741882324219
Epoch: 36, loss: 0.6701083779335022
Epoch: 37, loss: 0.669636607170105
Epoch: 38, loss: 0.6691610217094421
Epoch: 39, loss: 0.6686834096908569
Epoch: 40, loss: 0.6682056188583374
Epoch: 41, loss: 0.6677289009094238
Epoch: 42, loss: 0.667254626750946
Epoch: 43, loss: 0.6667835712432861
Epoch: 44, loss: 0.6663164496421814
Epoch: 45, loss: 0.6658533811569214
Epoch: 46, loss: 0.6653945446014404
Epoch: 47, loss: 0.6649397015571594
Epoch: 48, loss: 0.6644884347915649
Epoch: 49, loss: 0.6640403270721436
Epoch: 50, loss: 0.6635947227478027
Epoch: 51, loss: 0.6631510257720947
Epoch: 52, loss: 0.6628453135490417
Epoch: 53, loss: 0.6625917553901672
Epoch: 54, loss: 0.6623181104660034
Epoch: 55, loss: 0.6620256900787354
Epoch: 56, loss: 0.6617173552513123
Epoch: 57, loss: 0.6614043116569519
Epoch: 58, loss: 0.6610796451568604
Epoch: 59, loss: 0.6607442498207092
Epoch: 60, loss: 0.6603990793228149
Epoch: 61, loss: 0.6600450277328491
Epoch: 62, loss: 0.6596829295158386
Epoch: 63, loss: 0.6593135595321655
Epoch: 64, loss: 0.6589376330375671
Epoch: 65, loss: 0.6585558652877808
Epoch: 66, loss: 0.6581688523292542
Epoch: 67, loss: 0.6577771306037903
Epoch: 68, loss: 0.6574320793151855
Epoch: 69, loss: 0.6571431756019592
Epoch: 70, loss: 0.6568371653556824
Epoch: 71, loss: 0.6565203666687012
Epoch: 72, loss: 0.6561905145645142
Epoch: 73, loss: 0.6558488607406616
Epoch: 74, loss: 0.65549635887146
Epoch: 75, loss: 0.6551340818405151
Epoch: 76, loss: 0.6547629237174988
Epoch: 77, loss: 0.6544535160064697
Epoch: 78, loss: 0.6541627645492554
Epoch: 79, loss: 0.6538523435592651
Epoch: 80, loss: 0.6535260677337646
Epoch: 81, loss: 0.6531944274902344
Epoch: 82, loss: 0.6528521776199341
Epoch: 83, loss: 0.6525000333786011
Epoch: 84, loss: 0.652138888835907
Epoch: 85, loss: 0.6518597602844238
Epoch: 86, loss: 0.6515651345252991
Epoch: 87, loss: 0.6512539982795715
Epoch: 88, loss: 0.6509299874305725
Epoch: 89, loss: 0.650594174861908
Epoch: 90, loss: 0.6502466797828674
Epoch: 91, loss: 0.6498894691467285
Epoch: 92, loss: 0.6495950818061829
Epoch: 93, loss: 0.6493034362792969
Epoch: 94, loss: 0.6489962339401245
Epoch: 95, loss: 0.6486777067184448
Epoch: 96, loss: 0.6483432650566101
Epoch: 97, loss: 0.6479994058609009
Epoch: 98, loss: 0.6476455330848694
Epoch: 99, loss: 0.6473514437675476
The CustomCost for group 0 is 1.4111454486846924
The CustomCost for group 1 is 1.690650224685669
The absolute difference is 0.27950477600097656

With fairret…

[4]:
import numpy as np

nb_epochs = 100
fairness_strength = 1
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

pred = torch.sigmoid(model(feat))
stat_per_group = statistic(pred, sens, label)
absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])

print(f"The {statistic.__class__.__name__} for group 0 is {stat_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group 1 is {stat_per_group[1]}")
print(f"The absolute difference is {torch.abs(stat_per_group[0] - stat_per_group[1])}")
Epoch: 0, loss: 0.7234874963760376
Epoch: 1, loss: 0.7193881869316101
Epoch: 2, loss: 0.7153821587562561
Epoch: 3, loss: 0.7114719748497009
Epoch: 4, loss: 0.7076587677001953
Epoch: 5, loss: 0.703943133354187
Epoch: 6, loss: 0.7003254294395447
Epoch: 7, loss: 0.6968045234680176
Epoch: 8, loss: 0.6933779120445251
Epoch: 9, loss: 0.7009283900260925
Epoch: 10, loss: 0.70442134141922
Epoch: 11, loss: 0.7051329016685486
Epoch: 12, loss: 0.7039238810539246
Epoch: 13, loss: 0.7013260126113892
Epoch: 14, loss: 0.6976962685585022
Epoch: 15, loss: 0.693289577960968
Epoch: 16, loss: 0.6954131722450256
Epoch: 17, loss: 0.6971543431282043
Epoch: 18, loss: 0.6984840035438538
Epoch: 19, loss: 0.6994330883026123
Epoch: 20, loss: 0.7000323534011841
Epoch: 21, loss: 0.700312077999115
Epoch: 22, loss: 0.7003009915351868
Epoch: 23, loss: 0.7000272870063782
Epoch: 24, loss: 0.699517011642456
Epoch: 25, loss: 0.6987953782081604
Epoch: 26, loss: 0.6978861689567566
Epoch: 27, loss: 0.6968110203742981
Epoch: 28, loss: 0.6955899596214294
Epoch: 29, loss: 0.6942419409751892
Epoch: 30, loss: 0.6940961480140686
Epoch: 31, loss: 0.6957410573959351
Epoch: 32, loss: 0.6959663033485413
Epoch: 33, loss: 0.6949725151062012
Epoch: 34, loss: 0.6932772397994995
Epoch: 35, loss: 0.6938610076904297
Epoch: 36, loss: 0.6941598057746887
Epoch: 37, loss: 0.6941966414451599
Epoch: 38, loss: 0.6939942836761475
Epoch: 39, loss: 0.6935751438140869
Epoch: 40, loss: 0.6936072707176208
Epoch: 41, loss: 0.6936568021774292
Epoch: 42, loss: 0.6934249401092529
Epoch: 43, loss: 0.6936383843421936
Epoch: 44, loss: 0.6935983896255493
Epoch: 45, loss: 0.6933280825614929
Epoch: 46, loss: 0.6938256025314331
Epoch: 47, loss: 0.693622350692749
Epoch: 48, loss: 0.69352787733078
Epoch: 49, loss: 0.693831205368042
Epoch: 50, loss: 0.6938722133636475
Epoch: 51, loss: 0.693673849105835
Epoch: 52, loss: 0.6932585835456848
Epoch: 53, loss: 0.694254457950592
Epoch: 54, loss: 0.6943134665489197
Epoch: 55, loss: 0.6932371258735657
Epoch: 56, loss: 0.6940481066703796
Epoch: 57, loss: 0.6946710348129272
Epoch: 58, loss: 0.695000171661377
Epoch: 59, loss: 0.6950598955154419
Epoch: 60, loss: 0.6948744654655457
Epoch: 61, loss: 0.6944674253463745
Epoch: 62, loss: 0.6938610672950745
Epoch: 63, loss: 0.693301260471344
Epoch: 64, loss: 0.6937004923820496
Epoch: 65, loss: 0.6932454705238342
Epoch: 66, loss: 0.6933281421661377
Epoch: 67, loss: 0.6931682229042053
Epoch: 68, loss: 0.6939371824264526
Epoch: 69, loss: 0.6935523152351379
Epoch: 70, loss: 0.6936286687850952
Epoch: 71, loss: 0.693997859954834
Epoch: 72, loss: 0.6940965056419373
Epoch: 73, loss: 0.6939487457275391
Epoch: 74, loss: 0.6935782432556152
Epoch: 75, loss: 0.6934525370597839
Epoch: 76, loss: 0.6934391856193542
Epoch: 77, loss: 0.6935307383537292
Epoch: 78, loss: 0.6937640905380249
Epoch: 79, loss: 0.6937397718429565
Epoch: 80, loss: 0.6934816837310791
Epoch: 81, loss: 0.6934418678283691
Epoch: 82, loss: 0.6932307481765747
Epoch: 83, loss: 0.6937059760093689
Epoch: 84, loss: 0.6940118670463562
Epoch: 85, loss: 0.6940526366233826
Epoch: 86, loss: 0.693852961063385
Epoch: 87, loss: 0.6934364438056946
Epoch: 88, loss: 0.6938560009002686
Epoch: 89, loss: 0.6939221024513245
Epoch: 90, loss: 0.6932809352874756
Epoch: 91, loss: 0.6934866309165955
Epoch: 92, loss: 0.6934378147125244
Epoch: 93, loss: 0.6931588053703308
Epoch: 94, loss: 0.6941953897476196
Epoch: 95, loss: 0.6940174698829651
Epoch: 96, loss: 0.6933366656303406
Epoch: 97, loss: 0.6936314105987549
Epoch: 98, loss: 0.6936633586883545
Epoch: 99, loss: 0.6934562921524048
The CustomCost for group 0 is 1.4996007680892944
The CustomCost for group 1 is 1.4993300437927246
The absolute difference is 0.0002707242965698242