fairlib Evaluation Tutorial
1. Installation
[1]:
!pip install fairlib
Collecting fairlib
Downloading fairlib-0.0.3-py3-none-any.whl (63 kB)
|████████████████████████████████| 63 kB 2.0 MB/s
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.21.6)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.0.2)
Requirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from fairlib) (0.11.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from fairlib) (4.64.0)
Requirement already satisfied: docopt in /usr/local/lib/python3.7/dist-packages (from fairlib) (0.6.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.3.5)
Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from fairlib) (3.13)
Collecting pickle5
Downloading pickle5-0.0.12-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (256 kB)
|████████████████████████████████| 256 kB 29.1 MB/s
Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from fairlib) (3.2.2)
Collecting transformers
Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
|████████████████████████████████| 4.0 MB 55.2 MB/s
Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.11.0+cu113)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (2.8.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (1.4.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (0.11.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (3.0.8)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib->fairlib) (4.2.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib->fairlib) (1.15.0)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->fairlib) (2022.1)
Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->fairlib) (1.4.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->fairlib) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->fairlib) (1.1.0)
Collecting sacremoses
Downloading sacremoses-0.0.53.tar.gz (880 kB)
|████████████████████████████████| 880 kB 51.9 MB/s
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (2.23.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (3.6.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (4.11.3)
Collecting huggingface-hub<1.0,>=0.1.0
Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
|████████████████████████████████| 77 kB 6.9 MB/s
Collecting PyYAML
Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
|████████████████████████████████| 596 kB 60.7 MB/s
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (2019.12.20)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
|████████████████████████████████| 6.6 MB 54.8 MB/s
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (21.3)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers->fairlib) (3.8.0)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (3.0.4)
Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers->fairlib) (7.1.2)
Building wheels for collected packages: sacremoses
Building wheel for sacremoses (setup.py) ... done
Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=c5c48dec8f73ec0d60ebce7e2289c19af0e22809669612f2359f236a871759ad
Stored in directory: /root/.cache/pip/wheels/87/39/dd/a83eeef36d0bf98e7a4d1933a4ad2d660295a40613079bafc9
Successfully built sacremoses
Installing collected packages: PyYAML, tokenizers, sacremoses, huggingface-hub, transformers, pickle5, fairlib
Attempting uninstall: PyYAML
Found existing installation: PyYAML 3.13
Uninstalling PyYAML-3.13:
Successfully uninstalled PyYAML-3.13
Successfully installed PyYAML-6.0 fairlib-0.0.3 huggingface-hub-0.5.1 pickle5-0.0.12 sacremoses-0.0.53 tokenizers-0.12.1 transformers-4.18.0
[2]:
import fairlib
[3]:
!mkdir -p data/deepmoji
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_pos.npy' -P 'data/deepmoji'
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_neg.npy' -P 'data/deepmoji'
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_pos.npy' -P 'data/deepmoji'
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_neg.npy' -P 'data/deepmoji'
--2022-05-07 15:30:11-- https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_pos.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.127.128, 172.217.218.128, 142.251.18.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.127.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405494864 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/pos_pos.npy’
pos_pos.npy 100%[===================>] 386.71M 139MB/s in 2.8s
2022-05-07 15:30:14 (139 MB/s) - ‘data/deepmoji/pos_pos.npy’ saved [405494864/405494864]
--2022-05-07 15:30:14-- https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_neg.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.69.128, 108.177.127.128, 172.217.218.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.69.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405504080 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/pos_neg.npy’
pos_neg.npy 100%[===================>] 386.72M 173MB/s in 2.2s
2022-05-07 15:30:16 (173 MB/s) - ‘data/deepmoji/pos_neg.npy’ saved [405504080/405504080]
--2022-05-07 15:30:16-- https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_pos.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.18.128, 142.250.153.128, 74.125.128.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.18.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405494864 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/neg_pos.npy’
neg_pos.npy 100%[===================>] 386.71M 136MB/s in 2.9s
2022-05-07 15:30:19 (136 MB/s) - ‘data/deepmoji/neg_pos.npy’ saved [405494864/405494864]
--2022-05-07 15:30:19-- https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_neg.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.128.128, 173.194.69.128, 108.177.119.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.128.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405504080 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/neg_neg.npy’
neg_neg.npy 100%[===================>] 386.72M 142MB/s in 2.7s
2022-05-07 15:30:22 (142 MB/s) - ‘data/deepmoji/neg_neg.npy’ saved [405504080/405504080]
[4]:
fairlib.utils.seed_everything(2022)
import numpy as np
import os
def read_data_file(input_file: str):
    # Load one encoded file, shuffle its rows, and cut train/dev/test slices
    vecs = np.load(input_file)
    np.random.shuffle(vecs)
    return vecs[:40000], vecs[40000:42000], vecs[42000:44000]
in_dir = "data/deepmoji"
out_dir = "data/deepmoji"
os.makedirs(out_dir, exist_ok=True)
# Write each of the four files into train/, dev/, and test/ subdirectories
for split in ['pos_pos', 'pos_neg', 'neg_pos', 'neg_neg']:
    train, dev, test = read_data_file(in_dir + '/' + split + '.npy')
    for split_dir, data in zip(['train', 'dev', 'test'], [train, dev, test]):
        os.makedirs(out_dir + '/' + split_dir, exist_ok=True)
        np.save(out_dir + '/' + split_dir + '/' + split + '.npy', data)
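The split above is a plain shuffle-then-slice with fixed sizes. As a minimal sketch of the same idea on a toy list (the helper name `split_rows` and its parameters are hypothetical, and it uses only the standard library rather than NumPy), the following also checks that enough rows are available, which the cell above leaves implicit:

```python
import random

def split_rows(rows, n_train, n_dev, n_test, seed=2022):
    """Shuffle a copy of rows and cut it into train/dev/test slices."""
    needed = n_train + n_dev + n_test
    if len(rows) < needed:
        raise ValueError(f"need at least {needed} rows, got {len(rows)}")
    shuffled = list(rows)
    # A dedicated Random instance keeps the shuffle reproducible
    random.Random(seed).shuffle(shuffled)
    train = shuffled[:n_train]
    dev = shuffled[n_train:n_train + n_dev]
    test = shuffled[n_train + n_dev:needed]
    return train, dev, test

# Toy stand-in for one .npy file: 50 rows instead of 44000 or more
train, dev, test = split_rows(range(50), n_train=40, n_dev=5, n_test=5)
print(len(train), len(dev), len(test))
```

Because the slices come from one shuffled copy, no row can land in more than one split.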
Train a Model
[5]:
Shared_options = {
    # Name of the dataset; the corresponding dataloader will be used
    "dataset": "Moji",
    # Path to the input data
    "data_dir": "data/deepmoji",
    # Device for computation; -1 is the CPU
    "device_id": -1,
    # Default path for saving experimental results
    "results_dir": r"results",
    # Project directory, used in the path for saving experimental results
    "project_dir": r"dev",
    # We focus on the TPR GAP, which corresponds to Equalized Odds for binary classification
    "GAP_metric_name": "TPR_GAP",
    # Overall performance is measured as accuracy
    "Performance_metric_name": "accuracy",
    # Model selection is based on DTO
    "selection_criterion": "DTO",
    # Default directory and file-name prefix for saving checkpoints
    "checkpoint_dir": "models",
    "checkpoint_name": "checkpoint_epoch",
    # Number of jobs used when loading experimental results
    "n_jobs": 1,
}
[6]:
args = {
    "dataset": Shared_options["dataset"],
    "data_dir": Shared_options["data_dir"],
    "device_id": Shared_options["device_id"],
    # Name of the experiment, used in the output path
    "exp_id": "vanilla",
}
# Initialize the options and build the experiment state
options = fairlib.BaseOptions()
state = options.get_state(args=args, silence=True)
fairlib.utils.seed_everything(2022)
# Initialize the model
model = fairlib.networks.get_main_model(state)
INFO:root:Unexpected args: ['-f', '/root/.local/share/jupyter/runtime/kernel-3dad1bfd-dd35-40d7-b985-35feaaff967a.json']
INFO:root:Logging to ./results/dev/Moji/vanilla/output.log
2022-05-07 15:30:29 [INFO ] ======================================== 2022-05-07 15:30:29 ========================================
2022-05-07 15:30:29 [INFO ] Base directory is ./results/dev/Moji/vanilla
Loaded data shapes: (99998, 2304), (99998,), (99998,)
Loaded data shapes: (8000, 2304), (8000,), (8000,)
Loaded data shapes: (7998, 2304), (7998,), (7998,)
2022-05-07 15:30:30 [INFO ] MLP(
2022-05-07 15:30:30 [INFO ] (output_layer): Linear(in_features=300, out_features=2, bias=True)
2022-05-07 15:30:30 [INFO ] (AF): Tanh()
2022-05-07 15:30:30 [INFO ] (hidden_layers): ModuleList(
2022-05-07 15:30:30 [INFO ] (0): Linear(in_features=2304, out_features=300, bias=True)
2022-05-07 15:30:30 [INFO ] (1): Tanh()
2022-05-07 15:30:30 [INFO ] (2): Linear(in_features=300, out_features=300, bias=True)
2022-05-07 15:30:30 [INFO ] (3): Tanh()
2022-05-07 15:30:30 [INFO ] )
2022-05-07 15:30:30 [INFO ] (criterion): CrossEntropyLoss()
2022-05-07 15:30:30 [INFO ] )
2022-05-07 15:30:30 [INFO ] Total number of parameters: 782402
[7]:
model.train_self()
2022-05-07 15:30:30 [INFO ] Epoch: 0 [ 0/ 99998 ( 0%)] Loss: 0.6906 Data Time: 0.02s Train Time: 0.19s
2022-05-07 15:30:34 [INFO ] Epoch: 0 [ 51200/ 99998 (51%)] Loss: 0.3926 Data Time: 0.36s Train Time: 3.30s
2022-05-07 15:30:38 [INFO ] Evaluation at Epoch 0
2022-05-07 15:30:38 [INFO ] Validation accuracy: 72.55 macro_fscore: 72.44 micro_fscore: 72.55 TPR_GAP: 40.07 FPR_GAP: 40.07 PPR_GAP: 39.10
2022-05-07 15:30:38 [INFO ] Test accuracy: 71.41 macro_fscore: 71.30 micro_fscore: 71.41 TPR_GAP: 39.01 FPR_GAP: 39.01 PPR_GAP: 37.84
2022-05-07 15:30:38 [INFO ] Epoch: 1 [ 0/ 99998 ( 0%)] Loss: 0.4105 Data Time: 0.01s Train Time: 0.07s
2022-05-07 15:30:42 [INFO ] Epoch: 1 [ 51200/ 99998 (51%)] Loss: 0.4156 Data Time: 0.37s Train Time: 3.24s
2022-05-07 15:30:46 [INFO ] Evaluation at Epoch 1
2022-05-07 15:30:46 [INFO ] Validation accuracy: 72.36 macro_fscore: 72.32 micro_fscore: 72.36 TPR_GAP: 39.81 FPR_GAP: 39.81 PPR_GAP: 39.27
2022-05-07 15:30:46 [INFO ] Test accuracy: 71.01 macro_fscore: 70.98 micro_fscore: 71.01 TPR_GAP: 39.40 FPR_GAP: 39.40 PPR_GAP: 38.64
2022-05-07 15:30:46 [INFO ] Epoch: 2 [ 0/ 99998 ( 0%)] Loss: 0.3433 Data Time: 0.01s Train Time: 0.07s
2022-05-07 15:30:49 [INFO ] Epoch: 2 [ 51200/ 99998 (51%)] Loss: 0.3734 Data Time: 0.38s Train Time: 3.25s
2022-05-07 15:30:53 [INFO ] Epochs since last improvement: 1
2022-05-07 15:30:53 [INFO ] Evaluation at Epoch 2
2022-05-07 15:30:53 [INFO ] Validation accuracy: 72.42 macro_fscore: 72.37 micro_fscore: 72.42 TPR_GAP: 40.91 FPR_GAP: 40.91 PPR_GAP: 40.20
2022-05-07 15:30:53 [INFO ] Test accuracy: 70.98 macro_fscore: 70.93 micro_fscore: 70.98 TPR_GAP: 40.21 FPR_GAP: 40.21 PPR_GAP: 39.39
2022-05-07 15:30:53 [INFO ] Epoch: 3 [ 0/ 99998 ( 0%)] Loss: 0.3773 Data Time: 0.01s Train Time: 0.06s
2022-05-07 15:30:57 [INFO ] Epoch: 3 [ 51200/ 99998 (51%)] Loss: 0.3479 Data Time: 0.38s Train Time: 3.25s
2022-05-07 15:31:01 [INFO ] Epochs since last improvement: 2
2022-05-07 15:31:01 [INFO ] Evaluation at Epoch 3
2022-05-07 15:31:01 [INFO ] Validation accuracy: 72.09 macro_fscore: 71.92 micro_fscore: 72.09 TPR_GAP: 41.54 FPR_GAP: 41.54 PPR_GAP: 40.17
2022-05-07 15:31:01 [INFO ] Test accuracy: 71.17 macro_fscore: 71.02 micro_fscore: 71.17 TPR_GAP: 40.32 FPR_GAP: 40.32 PPR_GAP: 38.96
2022-05-07 15:31:01 [INFO ] Epoch: 4 [ 0/ 99998 ( 0%)] Loss: 0.3839 Data Time: 0.02s Train Time: 0.06s
2022-05-07 15:31:05 [INFO ] Epoch: 4 [ 51200/ 99998 (51%)] Loss: 0.3499 Data Time: 0.38s Train Time: 3.28s
2022-05-07 15:31:11 [INFO ] Epochs since last improvement: 3
2022-05-07 15:31:11 [INFO ] Evaluation at Epoch 4
2022-05-07 15:31:11 [INFO ] Validation accuracy: 71.50 macro_fscore: 71.43 micro_fscore: 71.50 TPR_GAP: 42.76 FPR_GAP: 42.76 PPR_GAP: 42.00
2022-05-07 15:31:11 [INFO ] Test accuracy: 70.49 macro_fscore: 70.43 micro_fscore: 70.49 TPR_GAP: 41.37 FPR_GAP: 41.37 PPR_GAP: 40.51
2022-05-07 15:31:11 [INFO ] Epoch: 5 [ 0/ 99998 ( 0%)] Loss: 0.3746 Data Time: 0.03s Train Time: 0.28s
2022-05-07 15:31:15 [INFO ] Epoch: 5 [ 51200/ 99998 (51%)] Loss: 0.3748 Data Time: 0.37s Train Time: 3.28s
2022-05-07 15:31:19 [INFO ] Epochs since last improvement: 4
2022-05-07 15:31:19 [INFO ] Evaluation at Epoch 5
2022-05-07 15:31:19 [INFO ] Validation accuracy: 72.67 macro_fscore: 72.60 micro_fscore: 72.67 TPR_GAP: 39.17 FPR_GAP: 39.17 PPR_GAP: 38.35
2022-05-07 15:31:19 [INFO ] Test accuracy: 71.69 macro_fscore: 71.62 micro_fscore: 71.69 TPR_GAP: 37.97 FPR_GAP: 37.97 PPR_GAP: 36.91
2022-05-07 15:31:19 [INFO ] Epoch: 6 [ 0/ 99998 ( 0%)] Loss: 0.3624 Data Time: 0.01s Train Time: 0.06s
2022-05-07 15:31:23 [INFO ] Epoch: 6 [ 51200/ 99998 (51%)] Loss: 0.3529 Data Time: 0.38s Train Time: 3.24s
2022-05-07 15:31:26 [INFO ] Epochs since last improvement: 5
2022-05-07 15:31:27 [INFO ] Evaluation at Epoch 6
2022-05-07 15:31:27 [INFO ] Validation accuracy: 72.70 macro_fscore: 72.62 micro_fscore: 72.70 TPR_GAP: 38.29 FPR_GAP: 38.29 PPR_GAP: 37.50
2022-05-07 15:31:27 [INFO ] Test accuracy: 71.76 macro_fscore: 71.70 micro_fscore: 71.76 TPR_GAP: 37.59 FPR_GAP: 37.59 PPR_GAP: 36.79
By default, fairlib prints and saves six metrics:
- accuracy, macro F1 score, and micro F1 score, the most commonly used metrics for performance evaluation;
- RMS-aggregated TPR, FPR, and PPR GAP scores for fairness assessment.
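The options earlier set "selection_criterion" to "DTO". As a rough sketch of how such a criterion could trade off the two kinds of metric, the block below treats DTO as the Euclidean distance from the point (accuracy, 1 - TPR_GAP) to an optimum at (1, 1). That definition and the helper name `dto` are assumptions made for illustration; fairlib's own computation may normalize or choose the optimum differently. The numbers are the validation values from the training log above, rescaled to fractions.

```python
import math

# Validation (accuracy, TPR_GAP) per epoch, copied from the training log above
# and rescaled from percentages to fractions.
validation_log = {
    0: (0.7255, 0.4007),
    1: (0.7236, 0.3981),
    2: (0.7242, 0.4091),
    3: (0.7209, 0.4154),
    4: (0.7150, 0.4276),
    5: (0.7267, 0.3917),
    6: (0.7270, 0.3829),
}

def dto(accuracy, gap, optimum=(1.0, 1.0)):
    """Distance from (accuracy, 1 - gap) to an assumed optimum point."""
    fairness = 1.0 - gap
    return math.hypot(optimum[0] - accuracy, optimum[1] - fairness)

scores = {epoch: dto(acc, gap) for epoch, (acc, gap) in validation_log.items()}
best_epoch = min(scores, key=scores.get)  # smaller distance is better

for epoch, score in scores.items():
    print(f"epoch {epoch}: DTO = {score:.4f}")
print("best epoch under this sketch:", best_epoch)
```

Under these assumptions the last epoch is selected, since it has both the highest validation accuracy and the smallest TPR GAP in the log.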
Scenario 1: Confusion Matrix Based Metrics
[8]:
import torch
path = "{results_dir}/{project_dir}/{dataset}/{exp_id}/{checkpoint_dir}/{checkpoint_name}{epoch}.pth.tar"
# Path to the first epoch
path_vanilla_epoch0 = path.format(
    exp_id="vanilla",
    epoch="0",
    results_dir=Shared_options["results_dir"],
    project_dir=Shared_options["project_dir"],
    dataset=Shared_options["dataset"],
    checkpoint_dir=Shared_options["checkpoint_dir"],
    checkpoint_name=Shared_options["checkpoint_name"],
)
epoch_results = torch.load(path_vanilla_epoch0)
# The keys for saved items
print(epoch_results.keys())
dict_keys(['epoch', 'epochs_since_improvement', 'loss', 'valid_confusion_matrices', 'test_confusion_matrices', 'dev_evaluations', 'test_evaluations'])
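To inspect more than one epoch, the path template from the cell above can be wrapped in a small helper. This is only a convenience sketch: the function name `checkpoint_path` is hypothetical, and the option values are copied from the Shared_options cell so that the block runs on its own.

```python
CHECKPOINT_TEMPLATE = (
    "{results_dir}/{project_dir}/{dataset}/{exp_id}/"
    "{checkpoint_dir}/{checkpoint_name}{epoch}.pth.tar"
)

def checkpoint_path(exp_id, epoch, options):
    """Build the checkpoint path for one experiment and epoch."""
    return CHECKPOINT_TEMPLATE.format(
        exp_id=exp_id,
        epoch=epoch,
        results_dir=options["results_dir"],
        project_dir=options["project_dir"],
        dataset=options["dataset"],
        checkpoint_dir=options["checkpoint_dir"],
        checkpoint_name=options["checkpoint_name"],
    )

# Values copied from the Shared_options cell
options = {
    "results_dir": "results",
    "project_dir": "dev",
    "dataset": "Moji",
    "checkpoint_dir": "models",
    "checkpoint_name": "checkpoint_epoch",
}

paths = [checkpoint_path("vanilla", epoch, options) for epoch in range(3)]
for p in paths:
    print(p)
```

Each resulting path could then be passed to `torch.load`, as in the cell above, provided a checkpoint was actually saved for that epoch.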
fairlib saves a confusion matrix for each protected group as well as the overall confusion matrix. These matrices are stored in a dictionary keyed by group ID, with the overall matrix under the key 'overall'.
[9]:
epoch_results["valid_confusion_matrices"].keys()
[9]:
dict_keys(['overall', 0, 1])
[10]:
epoch_results["valid_confusion_matrices"]["overall"]
[10]:
array([[2655, 1345],
[ 851, 3149]])
[11]:
from fairlib.src.evaluators.evaluator import confusion_matrix_based_scores
[12]:
confusion_matrix_based_scores(epoch_results["valid_confusion_matrices"]["overall"])
[12]:
{'ACC': array([0.7255, 0.7255]),
'FDR': array([0.24272676, 0.29928794]),
'FNR': array([0.33625, 0.21275]),
'FPR': array([0.21275, 0.33625]),
'NPV': array([0.70071206, 0.75727324]),
'PPR': array([0.43825, 0.56175]),
'PPV': array([0.75727324, 0.70071206]),
'TNR': array([0.78725, 0.66375]),
'TPR': array([0.66375, 0.78725])}
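The per-class scores above follow directly from the confusion matrix. The sketch below recomputes a few of them in plain Python, assuming that rows are true classes and columns are predicted classes; that layout reproduces the TPR, FPR, PPV, and PPR values printed above. The helper name `per_class_rates` is hypothetical and this is not fairlib's implementation.

```python
# Overall validation confusion matrix from the output above.
# Assumed layout: rows are true classes, columns are predicted classes.
matrix = [[2655, 1345],
          [851, 3149]]

def per_class_rates(cm):
    """One-vs-rest rates for every class of a square confusion matrix."""
    n_classes = len(cm)
    total = sum(sum(row) for row in cm)
    rates = {"TPR": [], "FPR": [], "PPV": [], "PPR": []}
    for k in range(n_classes):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                               # true k, predicted otherwise
        fp = sum(cm[i][k] for i in range(n_classes)) - tp  # predicted k, true otherwise
        tn = total - tp - fn - fp
        rates["TPR"].append(tp / (tp + fn))
        rates["FPR"].append(fp / (fp + tn))
        rates["PPV"].append(tp / (tp + fp))
        rates["PPR"].append((tp + fp) / total)
    return rates

rates = per_class_rates(matrix)
accuracy = sum(matrix[k][k] for k in range(len(matrix))) / sum(map(sum, matrix))

print(rates)
print(accuracy)
```

The accuracy computed here, 5804 / 8000, matches the validation accuracy of 72.55 reported for epoch 0 in the training log.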
[13]:
from fairlib.src.evaluators.evaluator import power_mean
[14]:
numbers = np.array([1, 2, 3, 4, 5])
# Generalized mean aggregation
[
    power_mean(numbers, p=100),   # Max
    power_mean(numbers, p=2),     # Root Mean Square
    power_mean(numbers, p=1),     # Arithmetic Mean
    power_mean(numbers, p=-100),  # Min
]
[14]:
[5, 3.3166247903554, 3.0, 1]
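For reference, the generalized (power) mean of values x_1..x_n with exponent p is (mean of x_i^p)^(1/p); it tends to the maximum as p grows and to the minimum as p goes to negative infinity. The sketch below is a standalone, standard-library version written for illustration (the name `generalized_mean` is hypothetical and this is not fairlib's `power_mean`). Note that a finite p such as 100 only approximates the maximum, so the exact 5 and 1 in the output above suggest that the library treats large |p| as a plain max or min; that reading is an inference from the output, not something confirmed here.

```python
import math

def generalized_mean(values, p):
    """Power mean of positive numbers; p = +/-inf gives max/min, p = 0 the geometric mean."""
    values = [float(v) for v in values]
    if p == math.inf:
        return max(values)
    if p == -math.inf:
        return min(values)
    if p == 0:
        # Limit of the power mean as p -> 0
        return math.exp(sum(math.log(v) for v in values) / len(values))
    return (sum(v ** p for v in values) / len(values)) ** (1.0 / p)

numbers = [1, 2, 3, 4, 5]
for p in (math.inf, 100, 2, 1, -100, -math.inf):
    print(p, generalized_mean(numbers, p))
```

With p = 2 this is the root mean square and with p = 1 the arithmetic mean, matching the middle two entries of the output above. Negative exponents require strictly positive inputs, since a zero would cause a division by zero.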
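Different values of p recover common ways of aggregating gaps:
- Max Violation: a large positive p approximates the maximum.
- RMS GAP: p = 2 gives the root mean square.
- Max Min Fairness: a large negative p approximates the minimum.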
[15]:
from fairlib.src.evaluators.evaluator import Aggregation_GAP
[43]:
# Confusion matrices of the vanilla model's first epoch on the validation set
confusion_matrices = epoch_results["valid_confusion_matrices"]
all_scores = {}
# Overall evaluation
all_scores["overall"] = confusion_matrix_based_scores(confusion_matrices["overall"])
# Group scores
distinct_groups = [0, 1]  # binary protected groups, AAE versus SAE
for gid in distinct_groups:
    group_confusion_matrix = confusion_matrices[gid]
    all_scores[gid] = confusion_matrix_based_scores(group_confusion_matrix)
[68]:
Aggregation_GAP(
    distinct_groups=[0, 1],
    all_scores=all_scores,
    # If None, take the absolute difference;
    # otherwise, use generalized mean aggregation with this power.
    group_agg_power=-10,
    # RMS aggregation by default
    class_agg_power=2,
    metric="TPR",
)
[68]:
0.2003354978462722
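To make the two aggregation steps concrete, here is a toy sketch of one plausible reading of the default behaviour (no group power, RMS over classes): for each class, sum the absolute differences between each group's score and the overall score, then take the root mean square across classes. Both the formula and the numbers are assumptions for illustration; the TPR values are hypothetical, not taken from the tutorial run, and `aggregate_gap` is not fairlib's `Aggregation_GAP`.

```python
import math

# Hypothetical per-class TPR values, for illustration only
toy_scores = {
    "overall": [0.70, 0.80],
    0: [0.55, 0.90],  # group 0
    1: [0.85, 0.70],  # group 1
}

def rms(values):
    """Root mean square of a list of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))

def aggregate_gap(scores, groups):
    """Sum absolute group-vs-overall gaps per class, then RMS across classes."""
    n_classes = len(scores["overall"])
    class_gaps = []
    for k in range(n_classes):
        gap = sum(abs(scores[g][k] - scores["overall"][k]) for g in groups)
        class_gaps.append(gap)
    return rms(class_gaps)

toy_gap = aggregate_gap(toy_scores, groups=[0, 1])
print(round(toy_gap, 4))
```

In this toy case the per-class gaps are about 0.30 and 0.20, so the RMS is the square root of 0.065. Passing a group power, as in the call above with -10, would replace the per-class sum with a generalized mean over groups.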