Adding Customized Evaluation Metrics

Installation

!pip install fairlib
Collecting fairlib
  Downloading fairlib-0.0.3-py3-none-any.whl (63 kB)
     |████████████████████████████████| 63 kB 2.0 MB/s 
[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.21.6)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.0.2)
Requirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from fairlib) (0.11.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from fairlib) (4.64.0)
Requirement already satisfied: docopt in /usr/local/lib/python3.7/dist-packages (from fairlib) (0.6.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.3.5)
Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from fairlib) (3.13)
Collecting pickle5
  Downloading pickle5-0.0.12-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (256 kB)
     |████████████████████████████████| 256 kB 29.1 MB/s 
[?25hRequirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from fairlib) (3.2.2)
Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
     |████████████████████████████████| 4.0 MB 55.2 MB/s 
[?25hRequirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from fairlib) (1.11.0+cu113)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (2.8.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (1.4.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (0.11.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->fairlib) (3.0.8)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib->fairlib) (4.2.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib->fairlib) (1.15.0)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->fairlib) (2022.1)
Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->fairlib) (1.4.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->fairlib) (3.1.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->fairlib) (1.1.0)
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
     |████████████████████████████████| 880 kB 51.9 MB/s 
[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (2.23.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (3.6.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (4.11.3)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
     |████████████████████████████████| 77 kB 6.9 MB/s 
[?25hCollecting PyYAML
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
     |████████████████████████████████| 596 kB 60.7 MB/s 
[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (2019.12.20)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
     |████████████████████████████████| 6.6 MB 54.8 MB/s 
[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers->fairlib) (21.3)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers->fairlib) (3.8.0)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers->fairlib) (3.0.4)
Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers->fairlib) (7.1.2)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=c5c48dec8f73ec0d60ebce7e2289c19af0e22809669612f2359f236a871759ad
  Stored in directory: /root/.cache/pip/wheels/87/39/dd/a83eeef36d0bf98e7a4d1933a4ad2d660295a40613079bafc9
Successfully built sacremoses
Installing collected packages: PyYAML, tokenizers, sacremoses, huggingface-hub, transformers, pickle5, fairlib
  Attempting uninstall: PyYAML
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed PyYAML-6.0 fairlib-0.0.3 huggingface-hub-0.5.1 pickle5-0.0.12 sacremoses-0.0.53 tokenizers-0.12.1 transformers-4.18.0
import fairlib
!mkdir -p data/deepmoji
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_pos.npy' -P 'data/deepmoji'
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_neg.npy' -P 'data/deepmoji'
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_pos.npy' -P 'data/deepmoji'
!wget 'https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_neg.npy' -P 'data/deepmoji'
--2022-05-07 15:30:11--  https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_pos.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.127.128, 172.217.218.128, 142.251.18.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.127.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405494864 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/pos_pos.npy’

pos_pos.npy         100%[===================>] 386.71M   139MB/s    in 2.8s    

2022-05-07 15:30:14 (139 MB/s) - ‘data/deepmoji/pos_pos.npy’ saved [405494864/405494864]

--2022-05-07 15:30:14--  https://storage.googleapis.com/ai2i/nullspace/deepmoji/pos_neg.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.69.128, 108.177.127.128, 172.217.218.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.69.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405504080 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/pos_neg.npy’

pos_neg.npy         100%[===================>] 386.72M   173MB/s    in 2.2s    

2022-05-07 15:30:16 (173 MB/s) - ‘data/deepmoji/pos_neg.npy’ saved [405504080/405504080]

--2022-05-07 15:30:16--  https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_pos.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.18.128, 142.250.153.128, 74.125.128.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.18.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405494864 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/neg_pos.npy’

neg_pos.npy         100%[===================>] 386.71M   136MB/s    in 2.9s    

2022-05-07 15:30:19 (136 MB/s) - ‘data/deepmoji/neg_pos.npy’ saved [405494864/405494864]

--2022-05-07 15:30:19--  https://storage.googleapis.com/ai2i/nullspace/deepmoji/neg_neg.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.128.128, 173.194.69.128, 108.177.119.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.128.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405504080 (387M) [application/octet-stream]
Saving to: ‘data/deepmoji/neg_neg.npy’

neg_neg.npy         100%[===================>] 386.72M   142MB/s    in 2.7s    

2022-05-07 15:30:22 (142 MB/s) - ‘data/deepmoji/neg_neg.npy’ saved [405504080/405504080]
fairlib.utils.seed_everything(2022)

import numpy as np
import os

def read_data_file(input_file: str):
    vecs = np.load(input_file)

    np.random.shuffle(vecs)

    return vecs[:40000], vecs[40000:42000], vecs[42000:44000]

in_dir = "data/deepmoji"
out_dir = "data/deepmoji"

os.makedirs(out_dir, exist_ok=True)

for split in ['pos_pos', 'pos_neg', 'neg_pos', 'neg_neg']:
    train, dev, test = read_data_file(in_dir + '/' + split + '.npy')
    for split_dir, data in zip(['train', 'dev', 'test'], [train, dev, test]):
        os.makedirs(out_dir + '/' + split_dir, exist_ok=True)
        np.save(out_dir + '/' + split_dir + '/' + split + '.npy', data)

Train a Model

Shared_options = {
    # The name of the dataset, corresponding dataloader will be used,
    "dataset":  "Moji",

    # Specifiy the path to the input data
    "data_dir": "data/deepmoji",

    # Device for computing, -1 is the cpu
    "device_id":    -1,

    # The default path for saving experimental results
    "results_dir":  r"results",

    # Will be used for saving experimental results
    "project_dir":  r"dev",

    # We will focusing on TPR GAP, implying the Equalized Odds for binary classification.
    "GAP_metric_name":  "TPR_GAP",

    # The overall performance will be measured as accuracy
    "Performance_metric_name":  "accuracy",

    # Model selections are based on DTO
    "selection_criterion":  "DTO",

    # Default dirs for saving checkpoints
    "checkpoint_dir":   "models",
    "checkpoint_name":  "checkpoint_epoch",

    # Loading experimental results
    "n_jobs":   1,
}
args = {
    "dataset":Shared_options["dataset"], 
    "data_dir":Shared_options["data_dir"],
    "device_id":Shared_options["device_id"],

    # Give a name to the exp, which will be used in the path
    "exp_id":"vanilla",
}

# Init the argument
options = fairlib.BaseOptions()
state = options.get_state(args=args, silence=True)

fairlib.utils.seed_everything(2022)

# Init Model
model = fairlib.networks.get_main_model(state)
INFO:root:Unexpected args: ['-f', '/root/.local/share/jupyter/runtime/kernel-3dad1bfd-dd35-40d7-b985-35feaaff967a.json']
INFO:root:Logging to ./results/dev/Moji/vanilla/output.log


2022-05-07 15:30:29 [INFO ]  ======================================== 2022-05-07 15:30:29 ========================================
2022-05-07 15:30:29 [INFO ]  Base directory is ./results/dev/Moji/vanilla
Loaded data shapes: (99998, 2304), (99998,), (99998,)
Loaded data shapes: (8000, 2304), (8000,), (8000,)
Loaded data shapes: (7998, 2304), (7998,), (7998,)
2022-05-07 15:30:30 [INFO ]  MLP( 
2022-05-07 15:30:30 [INFO ]    (output_layer): Linear(in_features=300, out_features=2, bias=True)
2022-05-07 15:30:30 [INFO ]    (AF): Tanh()
2022-05-07 15:30:30 [INFO ]    (hidden_layers): ModuleList(
2022-05-07 15:30:30 [INFO ]      (0): Linear(in_features=2304, out_features=300, bias=True)
2022-05-07 15:30:30 [INFO ]      (1): Tanh()
2022-05-07 15:30:30 [INFO ]      (2): Linear(in_features=300, out_features=300, bias=True)
2022-05-07 15:30:30 [INFO ]      (3): Tanh()
2022-05-07 15:30:30 [INFO ]    )
2022-05-07 15:30:30 [INFO ]    (criterion): CrossEntropyLoss()
2022-05-07 15:30:30 [INFO ]  )
2022-05-07 15:30:30 [INFO ]  Total number of parameters: 782402 
model.train_self()
2022-05-07 15:30:30 [INFO ]  Epoch:    0 [      0/  99998 ( 0%)]	Loss: 0.6906	 Data Time: 0.02s	Train Time: 0.19s
2022-05-07 15:30:34 [INFO ]  Epoch:    0 [  51200/  99998 (51%)]	Loss: 0.3926	 Data Time: 0.36s	Train Time: 3.30s
2022-05-07 15:30:38 [INFO ]  Evaluation at Epoch 0
2022-05-07 15:30:38 [INFO ]  Validation accuracy: 72.55	macro_fscore: 72.44	micro_fscore: 72.55	TPR_GAP: 40.07	FPR_GAP: 40.07	PPR_GAP: 39.10	
2022-05-07 15:30:38 [INFO ]  Test accuracy: 71.41	macro_fscore: 71.30	micro_fscore: 71.41	TPR_GAP: 39.01	FPR_GAP: 39.01	PPR_GAP: 37.84	
2022-05-07 15:30:38 [INFO ]  Epoch:    1 [      0/  99998 ( 0%)]	Loss: 0.4105	 Data Time: 0.01s	Train Time: 0.07s
2022-05-07 15:30:42 [INFO ]  Epoch:    1 [  51200/  99998 (51%)]	Loss: 0.4156	 Data Time: 0.37s	Train Time: 3.24s
2022-05-07 15:30:46 [INFO ]  Evaluation at Epoch 1
2022-05-07 15:30:46 [INFO ]  Validation accuracy: 72.36	macro_fscore: 72.32	micro_fscore: 72.36	TPR_GAP: 39.81	FPR_GAP: 39.81	PPR_GAP: 39.27	
2022-05-07 15:30:46 [INFO ]  Test accuracy: 71.01	macro_fscore: 70.98	micro_fscore: 71.01	TPR_GAP: 39.40	FPR_GAP: 39.40	PPR_GAP: 38.64	
2022-05-07 15:30:46 [INFO ]  Epoch:    2 [      0/  99998 ( 0%)]	Loss: 0.3433	 Data Time: 0.01s	Train Time: 0.07s
2022-05-07 15:30:49 [INFO ]  Epoch:    2 [  51200/  99998 (51%)]	Loss: 0.3734	 Data Time: 0.38s	Train Time: 3.25s
2022-05-07 15:30:53 [INFO ]  Epochs since last improvement: 1
2022-05-07 15:30:53 [INFO ]  Evaluation at Epoch 2
2022-05-07 15:30:53 [INFO ]  Validation accuracy: 72.42	macro_fscore: 72.37	micro_fscore: 72.42	TPR_GAP: 40.91	FPR_GAP: 40.91	PPR_GAP: 40.20	
2022-05-07 15:30:53 [INFO ]  Test accuracy: 70.98	macro_fscore: 70.93	micro_fscore: 70.98	TPR_GAP: 40.21	FPR_GAP: 40.21	PPR_GAP: 39.39	
2022-05-07 15:30:53 [INFO ]  Epoch:    3 [      0/  99998 ( 0%)]	Loss: 0.3773	 Data Time: 0.01s	Train Time: 0.06s
2022-05-07 15:30:57 [INFO ]  Epoch:    3 [  51200/  99998 (51%)]	Loss: 0.3479	 Data Time: 0.38s	Train Time: 3.25s
2022-05-07 15:31:01 [INFO ]  Epochs since last improvement: 2
2022-05-07 15:31:01 [INFO ]  Evaluation at Epoch 3
2022-05-07 15:31:01 [INFO ]  Validation accuracy: 72.09	macro_fscore: 71.92	micro_fscore: 72.09	TPR_GAP: 41.54	FPR_GAP: 41.54	PPR_GAP: 40.17	
2022-05-07 15:31:01 [INFO ]  Test accuracy: 71.17	macro_fscore: 71.02	micro_fscore: 71.17	TPR_GAP: 40.32	FPR_GAP: 40.32	PPR_GAP: 38.96	
2022-05-07 15:31:01 [INFO ]  Epoch:    4 [      0/  99998 ( 0%)]	Loss: 0.3839	 Data Time: 0.02s	Train Time: 0.06s
2022-05-07 15:31:05 [INFO ]  Epoch:    4 [  51200/  99998 (51%)]	Loss: 0.3499	 Data Time: 0.38s	Train Time: 3.28s
2022-05-07 15:31:11 [INFO ]  Epochs since last improvement: 3
2022-05-07 15:31:11 [INFO ]  Evaluation at Epoch 4
2022-05-07 15:31:11 [INFO ]  Validation accuracy: 71.50	macro_fscore: 71.43	micro_fscore: 71.50	TPR_GAP: 42.76	FPR_GAP: 42.76	PPR_GAP: 42.00	
2022-05-07 15:31:11 [INFO ]  Test accuracy: 70.49	macro_fscore: 70.43	micro_fscore: 70.49	TPR_GAP: 41.37	FPR_GAP: 41.37	PPR_GAP: 40.51	
2022-05-07 15:31:11 [INFO ]  Epoch:    5 [      0/  99998 ( 0%)]	Loss: 0.3746	 Data Time: 0.03s	Train Time: 0.28s
2022-05-07 15:31:15 [INFO ]  Epoch:    5 [  51200/  99998 (51%)]	Loss: 0.3748	 Data Time: 0.37s	Train Time: 3.28s
2022-05-07 15:31:19 [INFO ]  Epochs since last improvement: 4
2022-05-07 15:31:19 [INFO ]  Evaluation at Epoch 5
2022-05-07 15:31:19 [INFO ]  Validation accuracy: 72.67	macro_fscore: 72.60	micro_fscore: 72.67	TPR_GAP: 39.17	FPR_GAP: 39.17	PPR_GAP: 38.35	
2022-05-07 15:31:19 [INFO ]  Test accuracy: 71.69	macro_fscore: 71.62	micro_fscore: 71.69	TPR_GAP: 37.97	FPR_GAP: 37.97	PPR_GAP: 36.91	
2022-05-07 15:31:19 [INFO ]  Epoch:    6 [      0/  99998 ( 0%)]	Loss: 0.3624	 Data Time: 0.01s	Train Time: 0.06s
2022-05-07 15:31:23 [INFO ]  Epoch:    6 [  51200/  99998 (51%)]	Loss: 0.3529	 Data Time: 0.38s	Train Time: 3.24s
2022-05-07 15:31:26 [INFO ]  Epochs since last improvement: 5
2022-05-07 15:31:27 [INFO ]  Evaluation at Epoch 6
2022-05-07 15:31:27 [INFO ]  Validation accuracy: 72.70	macro_fscore: 72.62	micro_fscore: 72.70	TPR_GAP: 38.29	FPR_GAP: 38.29	PPR_GAP: 37.50	
2022-05-07 15:31:27 [INFO ]  Test accuracy: 71.76	macro_fscore: 71.70	micro_fscore: 71.76	TPR_GAP: 37.59	FPR_GAP: 37.59	PPR_GAP: 36.79	

By default, fairlib print and save 6 metrics:

  • accuracy, macro F1 score, and micro F1 score, which are most commenly used evaluation metrics for performance evaluation.

  • rms aggregated TPR, FPR, and PPR GAP scores for fairness assesment.

Confusion Matrix Based Metrics

import torch

path = "{results_dir}/{project_dir}/{dataset}/{exp_id}/{checkpoint_dir}/{checkpoint_name}{epoch}.pth.tar"

# Path to the first epoch
path_vanilla_epoch0 = path.format(
    exp_id = "vanilla",
    epoch = "0",
    results_dir=Shared_options["results_dir"],
    project_dir=Shared_options["project_dir"],
    dataset=Shared_options["dataset"],
    checkpoint_dir=Shared_options["checkpoint_dir"],
    checkpoint_name=Shared_options["checkpoint_name"],
)

epoch_results = torch.load(path_vanilla_epoch0)
# The keys for saved items
print(epoch_results.keys())
dict_keys(['epoch', 'epochs_since_improvement', 'loss', 'valid_confusion_matrices', 'test_confusion_matrices', 'dev_evaluations', 'test_evaluations'])

fairlib saves confusion matrices for each protected groups as well as the overall confusion matrix. These matrices are stored in a dictionary, indexed with the group id.

epoch_results["valid_confusion_matrices"].keys()
dict_keys(['overall', 0, 1])
epoch_results["valid_confusion_matrices"]["overall"]
array([[2655, 1345],
       [ 851, 3149]])
from fairlib.src.evaluators.evaluator import confusion_matrix_based_scores
confusion_matrix_based_scores(epoch_results["valid_confusion_matrices"]["overall"])
{'ACC': array([0.7255, 0.7255]),
 'FDR': array([0.24272676, 0.29928794]),
 'FNR': array([0.33625, 0.21275]),
 'FPR': array([0.21275, 0.33625]),
 'NPV': array([0.70071206, 0.75727324]),
 'PPR': array([0.43825, 0.56175]),
 'PPV': array([0.75727324, 0.70071206]),
 'TNR': array([0.78725, 0.66375]),
 'TPR': array([0.66375, 0.78725])}
from fairlib.src.evaluators.evaluator import power_mean
numbers = np.array([1,2,3,4,5])
# generalized mean aggregation
[
 power_mean(numbers, p=100), # Max
 power_mean(numbers, p=2), # Root Mean Square
 power_mean(numbers, p=1), # Arithmetic Mean
 power_mean(numbers, p=-100), # Min
]
[5, 3.3166247903554, 3.0, 1]
  • Max Violation

  • RMS GAP

  • Max Min Fairness

from fairlib.src.evaluators.evaluator import Aggregation_GAP
# Confusion matrices of the vanilla model's first epoch over the vailidation dataset
confusion_matrices = epoch_results["valid_confusion_matrices"]

# all_scores = dict()
all_scores = {}

# Overall evaluation
all_scores["overall"] = confusion_matrix_based_scores(confusion_matrices["overall"])

# Group scores
distinct_groups = [0,1] # binary protected groups, AAE verse SAE
for gid in distinct_groups:
    group_confusion_matrix = confusion_matrices[gid]
    all_scores[gid] = confusion_matrix_based_scores(group_confusion_matrix)
Aggregation_GAP(
    distinct_groups=[0,1], 
    all_scores=all_scores, 

    # Take the absolute different if None, 
    # using generalized mean aggregation if not None.
    group_agg_power = -10,
    
    # RMS aggregation by default
    class_agg_power=2, 
    metric="TPR")
0.2003354978462722