Simple Pipeline

In this notebook, we show a full example of how the fairret library might be used to train a PyTorch model with a fairness cost.

Loading some data

To start, let’s load some data where fair binary classification is desirable. We’ll use the folktables library and their example data of the 2018 American Community Survey (ACS).

[1]:
from folktables import ACSDataSource

data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
data = data_source.get_data(states=["AL"], download=True)

We specifically address the ACSIncome task, where we predict whether an individual’s income is above $50,000.

[2]:
from folktables import ACSIncome, generate_categories

definition_df = data_source.get_definitions(download=True)
categories = generate_categories(features=ACSIncome.features, definition_df=definition_df)

df_feat, df_labels, _ = ACSIncome.df_to_pandas(data, categories=categories, dummies=True)
df_feat.head()
[2]:
(Preview of df_feat: 5 rows × 729 columns. The first columns are AGEP and WKHP, followed by one-hot dummy columns for the categorical features, e.g. COW_Federal government employee, SEX_Male, and RAC1P_White alone.)

To keep things simple for now, let’s only consider two sensitive groups: male and female.

[3]:
sens_cols = ['SEX_Female', 'SEX_Male']
feat = df_feat.drop(columns=sens_cols).to_numpy(dtype="float")
sens = df_feat[sens_cols].to_numpy(dtype="float")
label = df_labels.to_numpy(dtype="float")

print(sens.mean(axis=0))
[0.47808514 0.52191486]

A naive PyTorch pipeline

The fairret library treats sensitive features in the same way ‘normal’ features are treated in PyTorch: as (N x D) tensors, where N is the number of samples and D the dimensionality. In contrast to other fairness libraries you may have used, we can therefore just leave categorical sensitive features as one-hot encoded!

[4]:
import torch
torch.manual_seed(0)
feat, sens, label = torch.tensor(feat).float(), torch.tensor(sens).float(), torch.tensor(label).float()
print(f"Shape of the 'normal' features tensor: {feat.shape}")
print(f"Shape of the sensitive features tensor: {sens.shape}")
print(f"Shape of the labels tensor: {label.shape}")
Shape of the 'normal' features tensor: torch.Size([22268, 727])
Shape of the sensitive features tensor: torch.Size([22268, 2])
Shape of the labels tensor: torch.Size([22268, 1])
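
The two SEX columns suffice for this example, but a multi-valued sensitive attribute is handled in exactly the same way. As a minimal sketch (not used further in this notebook), the one-hot RAC1P columns could simply be stacked next to the SEX columns:

race_cols = [col for col in df_feat.columns if col.startswith('RAC1P_')]
sens_multi = torch.tensor(df_feat[sens_cols + race_cols].to_numpy(dtype="float")).float()
# Still an (N x D) tensor, now with D = 2 SEX categories plus the RAC1P categories.
print(sens_multi.shape)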

In typical PyTorch fashion, let’s now define a simple neural net with 1 hidden layer, an optimizer, and a DataLoader.

[5]:
h_layer_dim = 16
lr = 1e-3
batch_size = 1024

model = torch.nn.Sequential(
    torch.nn.Linear(feat.shape[1], h_layer_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_layer_dim, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)

Now, let’s train it without doing any fairness adjustment…

[6]:
import numpy as np

nb_epochs = 50

for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")
Epoch: 0, loss: 0.6495096764781259
Epoch: 1, loss: 0.631090676242655
Epoch: 2, loss: 0.6120786558498036
Epoch: 3, loss: 0.5901886902072213
Epoch: 4, loss: 0.5662552903998982
Epoch: 5, loss: 0.5412234474312175
Epoch: 6, loss: 0.5169245397502725
Epoch: 7, loss: 0.4951953955672004
Epoch: 8, loss: 0.4771566512909802
Epoch: 9, loss: 0.4624161679636348
Epoch: 10, loss: 0.45037723264910956
Epoch: 11, loss: 0.4405312815850431
Epoch: 12, loss: 0.43243050778454
Epoch: 13, loss: 0.42573171854019165
Epoch: 14, loss: 0.4201531301845204
Epoch: 15, loss: 0.4154460768808018
Epoch: 16, loss: 0.41148471154949884
Epoch: 17, loss: 0.40810865570198407
Epoch: 18, loss: 0.4051740806211125
Epoch: 19, loss: 0.4026364270936359
Epoch: 20, loss: 0.400410919026895
Epoch: 21, loss: 0.39844207533381204
Epoch: 22, loss: 0.3967087905515324
Epoch: 23, loss: 0.3951598744500767
Epoch: 24, loss: 0.3937697451223027
Epoch: 25, loss: 0.39251258156516333
Epoch: 26, loss: 0.3913724842396649
Epoch: 27, loss: 0.3903204453262416
Epoch: 28, loss: 0.3893602354960008
Epoch: 29, loss: 0.38850402154705743
Epoch: 30, loss: 0.3876958462325009
Epoch: 31, loss: 0.38694334572011774
Epoch: 32, loss: 0.3862581164999442
Epoch: 33, loss: 0.38562133434143936
Epoch: 34, loss: 0.3850080926309932
Epoch: 35, loss: 0.38443617725914175
Epoch: 36, loss: 0.38391522047194565
Epoch: 37, loss: 0.38343126529997046
Epoch: 38, loss: 0.3829596929929473
Epoch: 39, loss: 0.3825045864690434
Epoch: 40, loss: 0.3820722380822355
Epoch: 41, loss: 0.38168060102246026
Epoch: 42, loss: 0.38129559497941623
Epoch: 43, loss: 0.3809025511145592
Epoch: 44, loss: 0.3805498616261916
Epoch: 45, loss: 0.38022591309113934
Epoch: 46, loss: 0.3798875537785617
Epoch: 47, loss: 0.379513679580255
Epoch: 48, loss: 0.3791568949818611
Epoch: 49, loss: 0.37884155728600244

Bias analysis in fairret

Can we detect any statistical disparities (biases) in the naive model?

The fairret library assesses these biases by comparing a (linear-fractional) Statistic computed for each sensitive feature. In our example, these are the ‘SEX_Female’ and ‘SEX_Male’ features. For example, let’s look at the true positive rate (also known as the recall or sensitivity).

[7]:
from fairret.statistic import TruePositiveRate

statistic = TruePositiveRate()

pred = torch.sigmoid(model(feat))
stat_per_group = statistic(pred, sens, label)
absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])

print(f"The {statistic.__class__.__name__} for group {sens_cols[0]} is {stat_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group {sens_cols[1]} is {stat_per_group[1]}")
print(f"The absolute difference is {torch.abs(stat_per_group[0] - stat_per_group[1])}")
The TruePositiveRate for group SEX_Female is 0.5624983310699463
The TruePositiveRate for group SEX_Male is 0.6300471425056458
The absolute difference is 0.06754881143569946
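
For intuition, this per-group statistic can also be written out by hand: the true positive rate is a linear-fractional function of the predictions, with the sums restricted to each group via the one-hot sensitive features. The sketch below should reproduce the numbers above with plain PyTorch; it is only meant to illustrate what the Statistic computes:

# Numerator: (soft) predicted positives among the actual positives, per group.
# Denominator: actual positives, per group.
manual_tpr = (pred * label * sens).sum(dim=0) / (label * sens).sum(dim=0)
print(manual_tpr)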

Bias mitigation in fairret

To reduce the statistical disparity we found, we can use one of the fairrets implemented in the library. To quantify bias according to the correct statistic, we need to pass the statistic object to the fairret loss.

[8]:
from fairret.loss import NormLoss

norm_loss = NormLoss(statistic)
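
Roughly speaking, a norm-based fairret penalizes how far each group’s statistic lies from the overall statistic computed on the whole batch. The exact formulation is handled by NormLoss; the sketch below only conveys the idea and is not the library implementation:

def norm_penalty_sketch(logit, sens, label):
    # Soft true positive rate per group and over the whole batch.
    pred = torch.sigmoid(logit)
    tpr_per_group = (pred * label * sens).sum(dim=0) / (label * sens).sum(dim=0)
    tpr_overall = (pred * label).sum() / label.sum()
    # Penalize the gap between each group's rate and the overall rate.
    return (tpr_per_group - tpr_overall).abs().sum()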

Let’s train another model where we now add this loss term to the objective.

We only need to add one line of code to the standard PyTorch training loop!

[9]:
fairness_strength = 0.1
model = torch.nn.Sequential(
    torch.nn.Linear(feat.shape[1], h_layer_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_layer_dim, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")
Epoch: 0, loss: 0.6555492119355635
Epoch: 1, loss: 0.6289367729967291
Epoch: 2, loss: 0.6124052486636422
Epoch: 3, loss: 0.5947602743452246
Epoch: 4, loss: 0.5761971880089153
Epoch: 5, loss: 0.5571780936284498
Epoch: 6, loss: 0.538903527639129
Epoch: 7, loss: 0.5221295695413243
Epoch: 8, loss: 0.5069978806105527
Epoch: 9, loss: 0.4935957017270001
Epoch: 10, loss: 0.48187787559899414
Epoch: 11, loss: 0.4717836068435149
Epoch: 12, loss: 0.46291052482344885
Epoch: 13, loss: 0.4552074447274208
Epoch: 14, loss: 0.44845194775949826
Epoch: 15, loss: 0.4425306184725328
Epoch: 16, loss: 0.43727567114613275
Epoch: 17, loss: 0.4326496706767516
Epoch: 18, loss: 0.4285395748235963
Epoch: 19, loss: 0.42487049102783203
Epoch: 20, loss: 0.42168024250052194
Epoch: 21, loss: 0.418713163245808
Epoch: 22, loss: 0.4161644611846317
Epoch: 23, loss: 0.41383336958560074
Epoch: 24, loss: 0.4117328572002324
Epoch: 25, loss: 0.4098751870068637
Epoch: 26, loss: 0.4081542722203515
Epoch: 27, loss: 0.4065771651538936
Epoch: 28, loss: 0.40517762438817456
Epoch: 29, loss: 0.40386907011270523
Epoch: 30, loss: 0.40267594226382
Epoch: 31, loss: 0.4015895886854692
Epoch: 32, loss: 0.40053029832514847
Epoch: 33, loss: 0.3996298699216409
Epoch: 34, loss: 0.398726802657951
Epoch: 35, loss: 0.39790612052787433
Epoch: 36, loss: 0.3971890244971622
Epoch: 37, loss: 0.3964132788506421
Epoch: 38, loss: 0.3957923088561405
Epoch: 39, loss: 0.3951647024263035
Epoch: 40, loss: 0.39456393027847464
Epoch: 41, loss: 0.3940161994912408
Epoch: 42, loss: 0.39341962202028796
Epoch: 43, loss: 0.39296053688634525
Epoch: 44, loss: 0.3924828178503297
Epoch: 45, loss: 0.3919635293158618
Epoch: 46, loss: 0.3915071880275553
Epoch: 47, loss: 0.39117465845563193
Epoch: 48, loss: 0.390705829994245
Epoch: 49, loss: 0.3903096318244934

Let’s check the true positive rate per group again…

[10]:
pred = torch.sigmoid(model(feat))
stat_per_group = statistic(pred, sens, label)

print(f"The {statistic.__class__.__name__} for group {sens_cols[0]} is {stat_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group {sens_cols[1]} is {stat_per_group[1]}")
print(f"The absolute difference is {torch.abs(stat_per_group[0] - stat_per_group[1])}")
The TruePositiveRate for group SEX_Female is 0.5829195976257324
The TruePositiveRate for group SEX_Male is 0.6023395657539368
The absolute difference is 0.019419968128204346

With this small change, the absolute difference between the group statistics was reduced from 6.8% to 1.9%!

Though this was a simple example, it illustrates how powerful the fairret paradigm can be.

Feel free to go back and try out some other statistics to compare or fairret losses to minimize. Both are designed to be easily interchangeable and extensible.
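
For instance, swapping in a different statistic only changes the lines where it is constructed and passed to the loss; the training loop stays identical. The sketch below assumes the FalsePositiveRate statistic from fairret.statistic, which takes the same (pred, sens, label) arguments as TruePositiveRate:

from fairret.statistic import FalsePositiveRate

statistic = FalsePositiveRate()
norm_loss = NormLoss(statistic)
# Re-run the training loop and the per-group analysis above unchanged.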