Customized models and datasets for CV

In this tutorial we will:

- Show how to add a model for image classification
- Show how to add a dataloader with image preprocessing

We will use the colored MNIST dataset; please see the original demo for more details.

Installation

Again, the first step is to install the library.

[1]:
!pip install -q fairlib
[2]:
import fairlib

Explore Datasets

[3]:
import os

import numpy as np
from PIL import Image

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torchvision import transforms
from torchvision import datasets
import torchvision.datasets.utils as dataset_utils
[4]:
def color_grayscale_arr(arr, red=True):
    """Converts grayscale image to either red or green"""
    assert arr.ndim == 2
    dtype = arr.dtype
    h, w = arr.shape
    arr = np.reshape(arr, [h, w, 1])
    if red:
        arr = np.concatenate([
            arr,
            np.zeros((h, w, 2), dtype=dtype)], axis=2)
    else:
        arr = np.concatenate([
            np.zeros((h, w, 1), dtype=dtype),
            arr,
            np.zeros((h, w, 1), dtype=dtype)], axis=2)
    return arr
[5]:
root = "./data"
[6]:
train_mnist = datasets.mnist.MNIST(root=root, train=True, download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

[7]:
fig, ax = plt.subplots(1,3)
ax[0].imshow(train_mnist[0][0]) # Original
ax[1].imshow(color_grayscale_arr(np.array(train_mnist[0][0]))) # Red
ax[2].imshow(color_grayscale_arr(np.array(train_mnist[0][0]), False)) # Green
[7]:
[Figure: the first MNIST training digit shown in its original grayscale, red, and green versions]
[8]:
# Note: this rebinds the name `datasets`, shadowing the torchvision import above
from fairlib import datasets
[9]:
datasets.prepare_dataset("coloredmnist", "coloredmnist")
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to coloredmnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting coloredmnist/MNIST/raw/train-images-idx3-ubyte.gz to coloredmnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting coloredmnist/MNIST/raw/train-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to coloredmnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting coloredmnist/MNIST/raw/t10k-images-idx3-ubyte.gz to coloredmnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting coloredmnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw

Converting image 0/50000
Converting image 5000/50000
Converting image 10000/50000
Converting image 15000/50000
Converting image 20000/50000
Converting image 25000/50000
Converting image 30000/50000
Converting image 35000/50000
Converting image 40000/50000
Converting image 45000/50000
Converting image 0/10000
Converting image 5000/10000
Converting image 0/10000
Converting image 5000/10000
[10]:
from fairlib import networks, BaseOptions, dataloaders
[11]:
Shared_options = {
    # The name of the dataset; the corresponding dataloader will be used
    "dataset":  "MNIST",

    # Specify the path to the input data
    "data_dir": "./coloredmnist",

    # Device for computing, -1 is the cpu
    "device_id": -1,

    # The default path for saving experimental results
    "results_dir":  r"results",

    # The same as the dataset
    "project_dir":  r"dev",

    # We focus on the TPR GAP, which corresponds to Equalized Odds for binary classification.
    "GAP_metric_name":  "TPR_GAP",

    # The overall performance will be measured as accuracy
    "Performance_metric_name":  "accuracy",

    # Model selection is based on DTO (distance to the optimum)
    "selection_criterion":  "DTO",

    # Default dirs for saving checkpoints
    "checkpoint_dir":   "models",
    "checkpoint_name":  "checkpoint_epoch",


    "n_jobs":   1,
}
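
To make the `TPR_GAP` setting more concrete, here is a minimal sketch of a true-positive-rate gap between groups. It is not fairlib's implementation: the function name `tpr_gap` and the root-mean-square aggregation over classes are assumptions made for illustration only.

```python
from collections import defaultdict
from math import sqrt

def tpr_gap(y_true, y_pred, groups):
    """RMS over classes of the per-class TPR difference between groups (illustrative)."""
    # hits[c][g]   = correctly predicted instances of class c in group g
    # totals[c][g] = all instances of class c in group g
    hits = defaultdict(lambda: defaultdict(int))
    totals = defaultdict(lambda: defaultdict(int))
    for y, p, g in zip(y_true, y_pred, groups):
        totals[y][g] += 1
        hits[y][g] += int(y == p)

    squared_gaps = []
    for c in totals:
        # Skip classes that appear in fewer than two groups
        if len(totals[c]) < 2:
            continue
        tprs = [hits[c][g] / totals[c][g] for g in totals[c]]
        squared_gaps.append((max(tprs) - min(tprs)) ** 2)
    return sqrt(sum(squared_gaps) / len(squared_gaps)) if squared_gaps else 0.0

# Toy example: class 1 is always recognized in group 0 but only half the time in group 1
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0]
groups = [0, 0, 1, 1, 0, 0, 1, 1]
gap = tpr_gap(y_true, y_pred, groups)
print(f"TPR gap: {gap:.4f}")
```

A gap of zero would mean every class is recognized equally well in both groups; the training logs in this tutorial report the library's own value as a percentage.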

Customizing NN Architectures

[12]:
class ConvNet(networks.utils.BaseModel):

    def __init__(self, args):
        super(ConvNet, self).__init__()
        self.args = args

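        # Two 5x5 convolutions (stride 1): 3 RGB channels -> 20 -> 50 feature maps
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)

        # MLP classifier head operating on the flattened 4*4*50 features
        self.classifier = networks.classifier.MLP(args)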

        self.init_for_training()

    def forward(self, input_data, group_label = None):
        x = input_data
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)

        return self.classifier(x, group_label)

    def hidden(self, input_data, group_label = None):
        x = input_data
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)

        return self.classifier.hidden(x, group_label)
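
The flattened size `4 * 4 * 50` used in `forward` and `hidden` follows from the layer shapes. The short check below walks a 28x28 MNIST image through the two convolution and pooling steps using the standard output-size formula; `conv_out` is a hypothetical helper written for this illustration, not part of fairlib or PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # MNIST images are 28x28
size = conv_out(size, kernel=5)             # conv1: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # max_pool2d(x, 2, 2): 24 -> 12
size = conv_out(size, kernel=5)             # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)   # max_pool2d(x, 2, 2): 8 -> 4

channels = 50                               # output channels of conv2
flat_features = size * size * channels
print(size, flat_features)                  # 4 800
```

The result, 800, is the value the tutorial passes as `emb_size`, and it matches the `in_features=800` of the first linear layer in the model summary printed when the model is built.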
[13]:
args = {
    "dataset":Shared_options["dataset"],
    "data_dir":Shared_options["data_dir"],
    "device_id":Shared_options["device_id"],

    # Name of the experiment, which will be used in the output path
    "exp_id":"vanilla",

    "emb_size": 4*4*50,
    "num_classes": 10,
    "num_groups": 2,
}

# Initialize the options
options = BaseOptions()
state = options.get_state(args=args, silence=True)
INFO:root:Unexpected args: ['-f', '/root/.local/share/jupyter/runtime/kernel-a7e24054-f7e0-4930-9df2-bfcd40fedbad.json']
INFO:root:Logging to ./results/dev/MNIST/vanilla/output.log
2022-07-21 07:19:08 [INFO ]  ======================================== 2022-07-21 07:19:08 ========================================
2022-07-21 07:19:08 [INFO ]  Base directory is ./results/dev/MNIST/vanilla
2022-07-21 07:19:09 [INFO ]  Exception type : AssertionError
2022-07-21 07:19:09 [INFO ]  Exception message : Not implemented
2022-07-21 07:19:09 [INFO ]  Stack trace : ['File : /usr/local/lib/python3.7/dist-packages/fairlib/src/base_options.py , Line : 486, Func.Name : set_state, Message : train_iterator, dev_iterator, test_iterator = dataloaders.get_dataloaders(state)', 'File : /usr/local/lib/python3.7/dist-packages/fairlib/src/dataloaders/__init__.py , Line : 40, Func.Name : get_dataloaders, Message : ], "Not implemented"']
2022-07-21 07:19:09 [INFO ]  dataloaders need to be initialized!

Customizing Dataloader

[14]:
class CustomizedDataset(dataloaders.utils.BaseDataset):

    # Convert each image to a tensor, then normalize per channel
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))])

    def load_data(self):

        self.data_dir = os.path.join(self.args.data_dir, "colored_MNIST_{}.pt".format(self.split))

        data = torch.load(self.data_dir)

        self.X = [self.transform(_img) for _img in data[0]]
        self.y = data[1]
        self.protected_label = data[2]
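
For readers unfamiliar with the two torchvision transforms, the sketch below reproduces their effect on a single 8-bit image with NumPy: `ToTensor()` scales pixel values to [0, 1] and moves channels first, and `Normalize(mean, std)` then standardizes each channel. The function `to_tensor_and_normalize` and the all-red test image are hypothetical stand-ins for illustration.

```python
import numpy as np

# Per-channel statistics copied from the transform above
mean = np.array([0.1307, 0.1307, 0.0], dtype=np.float32)
std = np.array([0.3081, 0.3081, 0.3081], dtype=np.float32)

def to_tensor_and_normalize(img_hwc_uint8):
    """NumPy stand-in for ToTensor() followed by Normalize(mean, std)."""
    x = img_hwc_uint8.astype(np.float32) / 255.0    # scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))                  # HWC -> CHW
    return (x - mean[:, None, None]) / std[:, None, None]

# A hypothetical 28x28 "red" image: full intensity in channel 0, zeros elsewhere
img = np.zeros((28, 28, 3), dtype=np.uint8)
img[..., 0] = 255
out = to_tensor_and_normalize(img)
print(out.shape, out[:, 0, 0])
```

The third channel has mean 0 because `color_grayscale_arr` never writes to the blue channel, so it stays at zero after normalization.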
[15]:
customized_train_data = CustomizedDataset(args=state, split="train")
customized_dev_data = CustomizedDataset(args=state, split="dev")
customized_test_data = CustomizedDataset(args=state, split="test")

# DataLoader parameters
train_dataloader_params = {
        'batch_size': state.batch_size,
        'shuffle': True,
        'num_workers': state.num_workers}

eval_dataloader_params = {
        'batch_size': state.test_batch_size,
        'shuffle': False,
        'num_workers': state.num_workers}

# Initialize the dataloaders
customized_training_generator = torch.utils.data.DataLoader(customized_train_data, **train_dataloader_params)
customized_validation_generator = torch.utils.data.DataLoader(customized_dev_data, **eval_dataloader_params)
customized_test_generator = torch.utils.data.DataLoader(customized_test_data, **eval_dataloader_params)
Loaded data shapes: (50000,), (50000,), (50000,)
Loaded data shapes: (10000,), (10000,), (10000,)
Loaded data shapes: (10000,), (10000,), (10000,)

Training a vanilla model without debiasing

[16]:
model = ConvNet(state)
2022-07-21 07:19:20 [INFO ]  MLP(
2022-07-21 07:19:20 [INFO ]    (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:19:20 [INFO ]    (AF): Tanh()
2022-07-21 07:19:20 [INFO ]    (hidden_layers): ModuleList(
2022-07-21 07:19:20 [INFO ]      (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ]      (1): Tanh()
2022-07-21 07:19:20 [INFO ]      (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ]      (3): Tanh()
2022-07-21 07:19:20 [INFO ]    )
2022-07-21 07:19:20 [INFO ]    (criterion): CrossEntropyLoss()
2022-07-21 07:19:20 [INFO ]  )
2022-07-21 07:19:20 [INFO ]  Total number of parameters: 333610

2022-07-21 07:19:20 [INFO ]  ConvNet(
2022-07-21 07:19:20 [INFO ]    (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:19:20 [INFO ]    (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:19:20 [INFO ]    (classifier): MLP(
2022-07-21 07:19:20 [INFO ]      (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:19:20 [INFO ]      (AF): Tanh()
2022-07-21 07:19:20 [INFO ]      (hidden_layers): ModuleList(
2022-07-21 07:19:20 [INFO ]        (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ]        (1): Tanh()
2022-07-21 07:19:20 [INFO ]        (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ]        (3): Tanh()
2022-07-21 07:19:20 [INFO ]      )
2022-07-21 07:19:20 [INFO ]      (criterion): CrossEntropyLoss()
2022-07-21 07:19:20 [INFO ]    )
2022-07-21 07:19:20 [INFO ]    (criterion): CrossEntropyLoss()
2022-07-21 07:19:20 [INFO ]  )
2022-07-21 07:19:20 [INFO ]  Total number of parameters: 360180

[17]:
model.train_self(
    train_generator = customized_training_generator,
    dev_generator = customized_validation_generator,
    test_generator = customized_test_generator,
)
2022-07-21 07:19:20 [INFO ]  Epoch:    0 [      0/  50000 ( 0%)]        Loss: 2.3093     Data Time: 0.01s       Train Time: 0.77s
2022-07-21 07:19:51 [INFO ]  Evaluation at Epoch 0
2022-07-21 07:19:53 [INFO ]  Validation accuracy: 95.53 macro_fscore: 95.52     micro_fscore: 95.53     TPR_GAP: 4.69   FPR_GAP: 0.40   PPR_GAP: 0.74
2022-07-21 07:19:53 [INFO ]  Test accuracy: 96.45       macro_fscore: 96.42     micro_fscore: 96.45     TPR_GAP: 3.58   FPR_GAP: 0.37   PPR_GAP: 0.72
2022-07-21 07:19:54 [INFO ]  Epoch:    1 [      0/  50000 ( 0%)]        Loss: 0.1037     Data Time: 0.01s       Train Time: 0.59s
2022-07-21 07:20:25 [INFO ]  Evaluation at Epoch 1
2022-07-21 07:20:28 [INFO ]  Validation accuracy: 97.29 macro_fscore: 97.29     micro_fscore: 97.29     TPR_GAP: 2.54   FPR_GAP: 0.27   PPR_GAP: 0.67
2022-07-21 07:20:28 [INFO ]  Test accuracy: 97.42       macro_fscore: 97.40     micro_fscore: 97.42     TPR_GAP: 3.33   FPR_GAP: 0.35   PPR_GAP: 0.73
2022-07-21 07:20:28 [INFO ]  Epoch:    2 [      0/  50000 ( 0%)]        Loss: 0.0530     Data Time: 0.01s       Train Time: 0.58s
2022-07-21 07:21:01 [INFO ]  Evaluation at Epoch 2
2022-07-21 07:21:04 [INFO ]  Validation accuracy: 97.48 macro_fscore: 97.48     micro_fscore: 97.48     TPR_GAP: 3.68   FPR_GAP: 0.39   PPR_GAP: 0.71
2022-07-21 07:21:04 [INFO ]  Test accuracy: 97.97       macro_fscore: 97.96     micro_fscore: 97.97     TPR_GAP: 3.30   FPR_GAP: 0.31   PPR_GAP: 0.98
2022-07-21 07:21:05 [INFO ]  Epoch:    3 [      0/  50000 ( 0%)]        Loss: 0.0368     Data Time: 0.01s       Train Time: 0.58s
2022-07-21 07:21:36 [INFO ]  Evaluation at Epoch 3
2022-07-21 07:21:38 [INFO ]  Validation accuracy: 98.31 macro_fscore: 98.31     micro_fscore: 98.31     TPR_GAP: 2.50   FPR_GAP: 0.30   PPR_GAP: 0.71
2022-07-21 07:21:38 [INFO ]  Test accuracy: 98.45       macro_fscore: 98.44     micro_fscore: 98.45     TPR_GAP: 2.05   FPR_GAP: 0.23   PPR_GAP: 0.86
2022-07-21 07:21:39 [INFO ]  Epoch:    4 [      0/  50000 ( 0%)]        Loss: 0.0100     Data Time: 0.01s       Train Time: 0.57s
2022-07-21 07:22:09 [INFO ]  Evaluation at Epoch 4
2022-07-21 07:22:12 [INFO ]  Validation accuracy: 98.27 macro_fscore: 98.27     micro_fscore: 98.27     TPR_GAP: 2.00   FPR_GAP: 0.18   PPR_GAP: 0.65
2022-07-21 07:22:12 [INFO ]  Test accuracy: 98.44       macro_fscore: 98.43     micro_fscore: 98.44     TPR_GAP: 1.79   FPR_GAP: 0.20   PPR_GAP: 0.81
2022-07-21 07:22:12 [INFO ]  Epoch:    5 [      0/  50000 ( 0%)]        Loss: 0.0130     Data Time: 0.01s       Train Time: 0.61s
2022-07-21 07:22:42 [INFO ]  Evaluation at Epoch 5
2022-07-21 07:22:45 [INFO ]  Validation accuracy: 98.46 macro_fscore: 98.46     micro_fscore: 98.46     TPR_GAP: 1.47   FPR_GAP: 0.22   PPR_GAP: 0.61
2022-07-21 07:22:45 [INFO ]  Test accuracy: 98.52       macro_fscore: 98.51     micro_fscore: 98.52     TPR_GAP: 1.47   FPR_GAP: 0.16   PPR_GAP: 0.83
2022-07-21 07:22:45 [INFO ]  Epoch:    6 [      0/  50000 ( 0%)]        Loss: 0.0104     Data Time: 0.01s       Train Time: 0.56s
2022-07-21 07:23:16 [INFO ]  Evaluation at Epoch 6
2022-07-21 07:23:18 [INFO ]  Validation accuracy: 98.57 macro_fscore: 98.57     micro_fscore: 98.57     TPR_GAP: 1.61   FPR_GAP: 0.19   PPR_GAP: 0.72
2022-07-21 07:23:18 [INFO ]  Test accuracy: 98.74       macro_fscore: 98.73     micro_fscore: 98.74     TPR_GAP: 1.45   FPR_GAP: 0.15   PPR_GAP: 0.80
2022-07-21 07:23:19 [INFO ]  Epoch:    7 [      0/  50000 ( 0%)]        Loss: 0.0070     Data Time: 0.01s       Train Time: 0.60s
2022-07-21 07:23:49 [INFO ]  Epochs since last improvement: 1
2022-07-21 07:23:49 [INFO ]  Evaluation at Epoch 7
2022-07-21 07:23:51 [INFO ]  Validation accuracy: 98.47 macro_fscore: 98.46     micro_fscore: 98.47     TPR_GAP: 1.83   FPR_GAP: 0.16   PPR_GAP: 0.68
2022-07-21 07:23:51 [INFO ]  Test accuracy: 98.55       macro_fscore: 98.53     micro_fscore: 98.55     TPR_GAP: 1.66   FPR_GAP: 0.16   PPR_GAP: 0.77
2022-07-21 07:23:52 [INFO ]  Epoch:    8 [      0/  50000 ( 0%)]        Loss: 0.0069     Data Time: 0.01s       Train Time: 0.57s
2022-07-21 07:24:22 [INFO ]  Epochs since last improvement: 2
2022-07-21 07:24:22 [INFO ]  Evaluation at Epoch 8
2022-07-21 07:24:24 [INFO ]  Validation accuracy: 98.56 macro_fscore: 98.55     micro_fscore: 98.56     TPR_GAP: 1.83   FPR_GAP: 0.14   PPR_GAP: 0.58
2022-07-21 07:24:24 [INFO ]  Test accuracy: 98.65       macro_fscore: 98.64     micro_fscore: 98.65     TPR_GAP: 1.80   FPR_GAP: 0.18   PPR_GAP: 0.82
2022-07-21 07:24:25 [INFO ]  Epoch:    9 [      0/  50000 ( 0%)]        Loss: 0.0015     Data Time: 0.01s       Train Time: 0.57s
2022-07-21 07:24:55 [INFO ]  Epochs since last improvement: 3
2022-07-21 07:24:55 [INFO ]  Evaluation at Epoch 9
2022-07-21 07:24:57 [INFO ]  Validation accuracy: 98.55 macro_fscore: 98.55     micro_fscore: 98.55     TPR_GAP: 1.65   FPR_GAP: 0.20   PPR_GAP: 0.66
2022-07-21 07:24:57 [INFO ]  Test accuracy: 98.67       macro_fscore: 98.66     micro_fscore: 98.67     TPR_GAP: 1.58   FPR_GAP: 0.17   PPR_GAP: 0.82
2022-07-21 07:24:58 [INFO ]  Epoch:   10 [      0/  50000 ( 0%)]        Loss: 0.0157     Data Time: 0.01s       Train Time: 0.56s
2022-07-21 07:25:28 [INFO ]  Epochs since last improvement: 4
2022-07-21 07:25:28 [INFO ]  Evaluation at Epoch 10
2022-07-21 07:25:30 [INFO ]  Validation accuracy: 98.41 macro_fscore: 98.39     micro_fscore: 98.41     TPR_GAP: 2.19   FPR_GAP: 0.24   PPR_GAP: 0.70
2022-07-21 07:25:30 [INFO ]  Test accuracy: 98.55       macro_fscore: 98.53     micro_fscore: 98.55     TPR_GAP: 1.47   FPR_GAP: 0.18   PPR_GAP: 0.82
2022-07-21 07:25:31 [INFO ]  Epoch:   11 [      0/  50000 ( 0%)]        Loss: 0.0038     Data Time: 0.01s       Train Time: 0.56s
2022-07-21 07:26:00 [INFO ]  Epochs since last improvement: 5
2022-07-21 07:26:00 [INFO ]  Evaluation at Epoch 11
2022-07-21 07:26:03 [INFO ]  Validation accuracy: 98.69 macro_fscore: 98.69     micro_fscore: 98.69     TPR_GAP: 1.32   FPR_GAP: 0.14   PPR_GAP: 0.64
2022-07-21 07:26:03 [INFO ]  Test accuracy: 98.84       macro_fscore: 98.83     micro_fscore: 98.84     TPR_GAP: 1.21   FPR_GAP: 0.13   PPR_GAP: 0.87

Improving Fairness
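
The cell in this section combines adversarial training with balanced training through resampling. As a rough intuition for the resampling part only, the sketch below draws the same number of instances from each group within every class. This is not fairlib's implementation, and reading the `"EO"` objective as "balance groups within each class" is an assumption; `resample_balanced` is a hypothetical helper.

```python
import random
from collections import defaultdict

def resample_balanced(labels, groups, seed=0):
    """Return indices such that, within each class, every group is drawn equally often."""
    rng = random.Random(seed)
    by_class_group = defaultdict(lambda: defaultdict(list))
    for idx, (y, g) in enumerate(zip(labels, groups)):
        by_class_group[y][g].append(idx)

    selected = []
    for per_group in by_class_group.values():
        class_size = sum(len(ids) for ids in per_group.values())
        # Keep the class size roughly unchanged, split evenly across its groups
        per_group_target = class_size // len(per_group)
        for ids in per_group.values():
            # Sample with replacement so that small groups can be up-sampled
            selected.extend(rng.choices(ids, k=per_group_target))
    return selected

# Toy data: class 0 is dominated by group 0, class 1 by group 1
labels = [0] * 10 + [1] * 10
groups = [0] * 8 + [1] * 2 + [0] * 3 + [1] * 7
indices = resample_balanced(labels, groups)
print(len(indices))
```

Because of the integer division, a resampled set built this way can be slightly smaller than the original one.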

[18]:
debiasing_args = {
    "dataset":Shared_options["dataset"],
    "data_dir":Shared_options["data_dir"],
    "device_id":Shared_options["device_id"],

    # Name of the experiment, which will be used in the output path
    "exp_id":"BT_Adv",

    "emb_size": 4*4*50,
    "num_classes": 10,
    "num_groups": 2,

    # Perform adversarial training if True
    "adv_debiasing":True,

    # Specify the hyperparameters for Balanced Training
    "BT":"Resampling",
    "BTObj":"EO",
}

# Initialize the options
debias_options = BaseOptions()
debias_state = debias_options.get_state(args=debiasing_args, silence=True)

customized_train_data = CustomizedDataset(args=debias_state, split="train")
customized_dev_data = CustomizedDataset(args=debias_state, split="dev")
customized_test_data = CustomizedDataset(args=debias_state, split="test")

# DataLoader parameters (taken from the debiasing options)
train_dataloader_params = {
        'batch_size': debias_state.batch_size,
        'shuffle': True,
        'num_workers': debias_state.num_workers}

eval_dataloader_params = {
        'batch_size': debias_state.test_batch_size,
        'shuffle': False,
        'num_workers': debias_state.num_workers}

# Initialize the dataloaders
customized_training_generator = torch.utils.data.DataLoader(customized_train_data, **train_dataloader_params)
customized_validation_generator = torch.utils.data.DataLoader(customized_dev_data, **eval_dataloader_params)
customized_test_generator = torch.utils.data.DataLoader(customized_test_data, **eval_dataloader_params)

debias_model = ConvNet(debias_state)
2022-07-21 07:26:03 [INFO ]  Unexpected args: ['-f', '/root/.local/share/jupyter/runtime/kernel-a7e24054-f7e0-4930-9df2-bfcd40fedbad.json']
2022-07-21 07:26:03 [INFO ]  Logging to ./results/dev/MNIST/BT_Adv/output.log
2022-07-21 07:26:03 [INFO ]  ======================================== 2022-07-21 07:26:03 ========================================
2022-07-21 07:26:03 [INFO ]  Base directory is ./results/dev/MNIST/BT_Adv
2022-07-21 07:26:03 [INFO ]  Exception type : AssertionError
2022-07-21 07:26:03 [INFO ]  Exception message : Not implemented
2022-07-21 07:26:03 [INFO ]  Stack trace : ['File : /usr/local/lib/python3.7/dist-packages/fairlib/src/base_options.py , Line : 486, Func.Name : set_state, Message : train_iterator, dev_iterator, test_iterator = dataloaders.get_dataloaders(state)', 'File : /usr/local/lib/python3.7/dist-packages/fairlib/src/dataloaders/__init__.py , Line : 40, Func.Name : get_dataloaders, Message : ], "Not implemented"']
2022-07-21 07:26:03 [INFO ]  dataloaders need to be initialized!
2022-07-21 07:26:03 [INFO ]  SubDiscriminator(
2022-07-21 07:26:03 [INFO ]    (grad_rev): GradientReversal()
2022-07-21 07:26:03 [INFO ]    (output_layer): Linear(in_features=300, out_features=2, bias=True)
2022-07-21 07:26:03 [INFO ]    (AF): ReLU()
2022-07-21 07:26:03 [INFO ]    (hidden_layers): ModuleList(
2022-07-21 07:26:03 [INFO ]      (0): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:03 [INFO ]      (1): ReLU()
2022-07-21 07:26:03 [INFO ]      (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:03 [INFO ]      (3): ReLU()
2022-07-21 07:26:03 [INFO ]    )
2022-07-21 07:26:03 [INFO ]    (criterion): CrossEntropyLoss()
2022-07-21 07:26:03 [INFO ]  )
2022-07-21 07:26:03 [INFO ]  Total number of parameters: 181202

2022-07-21 07:26:03 [INFO ]  Discriminator built!
Loaded data shapes: (49998,), (49998,), (49998,)
Loaded data shapes: (10000,), (10000,), (10000,)
Loaded data shapes: (10000,), (10000,), (10000,)
2022-07-21 07:26:14 [INFO ]  MLP(
2022-07-21 07:26:14 [INFO ]    (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:26:14 [INFO ]    (AF): Tanh()
2022-07-21 07:26:14 [INFO ]    (hidden_layers): ModuleList(
2022-07-21 07:26:14 [INFO ]      (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ]      (1): Tanh()
2022-07-21 07:26:14 [INFO ]      (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ]      (3): Tanh()
2022-07-21 07:26:14 [INFO ]    )
2022-07-21 07:26:14 [INFO ]    (criterion): CrossEntropyLoss()
2022-07-21 07:26:14 [INFO ]  )
2022-07-21 07:26:14 [INFO ]  Total number of parameters: 333610

2022-07-21 07:26:14 [INFO ]  ConvNet(
2022-07-21 07:26:14 [INFO ]    (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:26:14 [INFO ]    (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:26:14 [INFO ]    (classifier): MLP(
2022-07-21 07:26:14 [INFO ]      (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:26:14 [INFO ]      (AF): Tanh()
2022-07-21 07:26:14 [INFO ]      (hidden_layers): ModuleList(
2022-07-21 07:26:14 [INFO ]        (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ]        (1): Tanh()
2022-07-21 07:26:14 [INFO ]        (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ]        (3): Tanh()
2022-07-21 07:26:14 [INFO ]      )
2022-07-21 07:26:14 [INFO ]      (criterion): CrossEntropyLoss()
2022-07-21 07:26:14 [INFO ]    )
2022-07-21 07:26:14 [INFO ]    (criterion): CrossEntropyLoss()
2022-07-21 07:26:14 [INFO ]  )
2022-07-21 07:26:14 [INFO ]  Total number of parameters: 360180

[19]:
debias_model.train_self(
    train_generator = customized_training_generator,
    dev_generator = customized_validation_generator,
    test_generator = customized_test_generator,
)
2022-07-21 07:26:16 [INFO ]  Epoch:    0 [      0/  49998 ( 0%)]        Loss: 1.6374     Data Time: 0.01s       Train Time: 1.56s
2022-07-21 07:27:29 [INFO ]  Evaluation at Epoch 0
2022-07-21 07:27:31 [INFO ]  Validation accuracy: 96.09 macro_fscore: 96.09     micro_fscore: 96.09     TPR_GAP: 2.95   FPR_GAP: 0.33   PPR_GAP: 0.59
2022-07-21 07:27:31 [INFO ]  Test accuracy: 97.04       macro_fscore: 97.02     micro_fscore: 97.04     TPR_GAP: 2.98   FPR_GAP: 0.20   PPR_GAP: 1.09
2022-07-21 07:27:33 [INFO ]  Epoch:    1 [      0/  49998 ( 0%)]        Loss: -0.6122    Data Time: 0.01s       Train Time: 1.44s
2022-07-21 07:28:46 [INFO ]  Evaluation at Epoch 1
2022-07-21 07:28:48 [INFO ]  Validation accuracy: 97.32 macro_fscore: 97.31     micro_fscore: 97.32     TPR_GAP: 1.28   FPR_GAP: 0.10   PPR_GAP: 0.75
2022-07-21 07:28:48 [INFO ]  Test accuracy: 97.99       macro_fscore: 97.97     micro_fscore: 97.99     TPR_GAP: 0.97   FPR_GAP: 0.10   PPR_GAP: 0.97
2022-07-21 07:28:50 [INFO ]  Epoch:    2 [      0/  49998 ( 0%)]        Loss: -0.6411    Data Time: 0.01s       Train Time: 1.48s
2022-07-21 07:30:02 [INFO ]  Evaluation at Epoch 2
2022-07-21 07:30:05 [INFO ]  Validation accuracy: 97.92 macro_fscore: 97.91     micro_fscore: 97.92     TPR_GAP: 1.31   FPR_GAP: 0.12   PPR_GAP: 0.80
2022-07-21 07:30:05 [INFO ]  Test accuracy: 98.38       macro_fscore: 98.37     micro_fscore: 98.38     TPR_GAP: 0.96   FPR_GAP: 0.09   PPR_GAP: 1.00
2022-07-21 07:30:07 [INFO ]  Epoch:    3 [      0/  49998 ( 0%)]        Loss: -0.6523    Data Time: 0.01s       Train Time: 1.76s
2022-07-21 07:31:20 [INFO ]  Evaluation at Epoch 3
2022-07-21 07:31:23 [INFO ]  Validation accuracy: 98.17 macro_fscore: 98.17     micro_fscore: 98.17     TPR_GAP: 0.94   FPR_GAP: 0.10   PPR_GAP: 0.67
2022-07-21 07:31:23 [INFO ]  Test accuracy: 98.52       macro_fscore: 98.51     micro_fscore: 98.52     TPR_GAP: 1.25   FPR_GAP: 0.09   PPR_GAP: 1.00
2022-07-21 07:31:25 [INFO ]  Epoch:    4 [      0/  49998 ( 0%)]        Loss: -0.6713    Data Time: 0.01s       Train Time: 1.46s
2022-07-21 07:32:38 [INFO ]  Evaluation at Epoch 4
2022-07-21 07:32:41 [INFO ]  Validation accuracy: 98.57 macro_fscore: 98.56     micro_fscore: 98.57     TPR_GAP: 0.74   FPR_GAP: 0.09   PPR_GAP: 0.75
2022-07-21 07:32:41 [INFO ]  Test accuracy: 98.78       macro_fscore: 98.77     micro_fscore: 98.78     TPR_GAP: 0.46   FPR_GAP: 0.08   PPR_GAP: 0.93
2022-07-21 07:32:42 [INFO ]  Epoch:    5 [      0/  49998 ( 0%)]        Loss: -0.6775    Data Time: 0.01s       Train Time: 1.45s
2022-07-21 07:33:58 [INFO ]  Epochs since last improvement: 1
2022-07-21 07:33:58 [INFO ]  Evaluation at Epoch 5
2022-07-21 07:34:01 [INFO ]  Validation accuracy: 98.54 macro_fscore: 98.53     micro_fscore: 98.54     TPR_GAP: 0.78   FPR_GAP: 0.06   PPR_GAP: 0.73
2022-07-21 07:34:01 [INFO ]  Test accuracy: 98.76       macro_fscore: 98.76     micro_fscore: 98.76     TPR_GAP: 0.74   FPR_GAP: 0.09   PPR_GAP: 0.88
2022-07-21 07:34:02 [INFO ]  Epoch:    6 [      0/  49998 ( 0%)]        Loss: -0.6868    Data Time: 0.01s       Train Time: 1.49s
2022-07-21 07:35:16 [INFO ]  Epochs since last improvement: 2
2022-07-21 07:35:16 [INFO ]  Evaluation at Epoch 6
2022-07-21 07:35:18 [INFO ]  Validation accuracy: 98.60 macro_fscore: 98.59     micro_fscore: 98.60     TPR_GAP: 0.54   FPR_GAP: 0.06   PPR_GAP: 0.73
2022-07-21 07:35:18 [INFO ]  Test accuracy: 98.74       macro_fscore: 98.73     micro_fscore: 98.74     TPR_GAP: 0.80   FPR_GAP: 0.07   PPR_GAP: 0.97
2022-07-21 07:35:20 [INFO ]  Epoch:    7 [      0/  49998 ( 0%)]        Loss: -0.6850    Data Time: 0.01s       Train Time: 1.48s
2022-07-21 07:36:32 [INFO ]  Evaluation at Epoch 7
2022-07-21 07:36:35 [INFO ]  Validation accuracy: 98.65 macro_fscore: 98.64     micro_fscore: 98.65     TPR_GAP: 0.64   FPR_GAP: 0.08   PPR_GAP: 0.75
2022-07-21 07:36:35 [INFO ]  Test accuracy: 98.82       macro_fscore: 98.81     micro_fscore: 98.82     TPR_GAP: 0.39   FPR_GAP: 0.06   PPR_GAP: 0.95
2022-07-21 07:36:36 [INFO ]  Epoch:    8 [      0/  49998 ( 0%)]        Loss: -0.6916    Data Time: 0.01s       Train Time: 1.47s
2022-07-21 07:37:49 [INFO ]  Epochs since last improvement: 1
2022-07-21 07:37:49 [INFO ]  Evaluation at Epoch 8
2022-07-21 07:37:52 [INFO ]  Validation accuracy: 98.56 macro_fscore: 98.55     micro_fscore: 98.56     TPR_GAP: 0.66   FPR_GAP: 0.08   PPR_GAP: 0.76
2022-07-21 07:37:52 [INFO ]  Test accuracy: 98.79       macro_fscore: 98.78     micro_fscore: 98.79     TPR_GAP: 1.01   FPR_GAP: 0.08   PPR_GAP: 0.96
2022-07-21 07:37:53 [INFO ]  Epoch:    9 [      0/  49998 ( 0%)]        Loss: -0.6902    Data Time: 0.01s       Train Time: 1.46s
2022-07-21 07:39:05 [INFO ]  Epochs since last improvement: 2
2022-07-21 07:39:05 [INFO ]  Evaluation at Epoch 9
2022-07-21 07:39:08 [INFO ]  Validation accuracy: 98.38 macro_fscore: 98.37     micro_fscore: 98.38     TPR_GAP: 0.74   FPR_GAP: 0.11   PPR_GAP: 0.70
2022-07-21 07:39:08 [INFO ]  Test accuracy: 98.72       macro_fscore: 98.71     micro_fscore: 98.72     TPR_GAP: 0.68   FPR_GAP: 0.07   PPR_GAP: 0.94
2022-07-21 07:39:09 [INFO ]  Epoch:   10 [      0/  49998 ( 0%)]        Loss: -0.6885    Data Time: 0.01s       Train Time: 1.44s
2022-07-21 07:40:22 [INFO ]  Epochs since last improvement: 3
2022-07-21 07:40:22 [INFO ]  Evaluation at Epoch 10
2022-07-21 07:40:24 [INFO ]  Validation accuracy: 98.33 macro_fscore: 98.32     micro_fscore: 98.33     TPR_GAP: 0.89   FPR_GAP: 0.08   PPR_GAP: 0.66
2022-07-21 07:40:24 [INFO ]  Test accuracy: 98.71       macro_fscore: 98.70     micro_fscore: 98.71     TPR_GAP: 0.89   FPR_GAP: 0.08   PPR_GAP: 0.96
2022-07-21 07:40:26 [INFO ]  Epoch:   11 [      0/  49998 ( 0%)]        Loss: -0.6860    Data Time: 0.01s       Train Time: 1.46s
2022-07-21 07:41:39 [INFO ]  Epochs since last improvement: 4
2022-07-21 07:41:39 [INFO ]  Evaluation at Epoch 11
2022-07-21 07:41:41 [INFO ]  Validation accuracy: 98.67 macro_fscore: 98.66     micro_fscore: 98.67     TPR_GAP: 0.81   FPR_GAP: 0.11   PPR_GAP: 0.74
2022-07-21 07:41:41 [INFO ]  Test accuracy: 98.83       macro_fscore: 98.82     micro_fscore: 98.83     TPR_GAP: 0.92   FPR_GAP: 0.10   PPR_GAP: 0.90
2022-07-21 07:41:43 [INFO ]  Epoch:   12 [      0/  49998 ( 0%)]        Loss: -0.6909    Data Time: 0.01s       Train Time: 1.46s
2022-07-21 07:42:54 [INFO ]  Epochs since last improvement: 5
2022-07-21 07:42:54 [INFO ]  Evaluation at Epoch 12
2022-07-21 07:42:57 [INFO ]  Validation accuracy: 98.61 macro_fscore: 98.60     micro_fscore: 98.61     TPR_GAP: 0.74   FPR_GAP: 0.09   PPR_GAP: 0.73
2022-07-21 07:42:57 [INFO ]  Test accuracy: 98.90       macro_fscore: 98.89     micro_fscore: 98.90     TPR_GAP: 0.64   FPR_GAP: 0.06   PPR_GAP: 0.96
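
To compare the two runs with a single number, one option is a distance-to-optimum (DTO) score, the selection criterion named in `Shared_options`. The sketch below assumes the optimum is 100% accuracy with a 0% TPR gap and uses plain Euclidean distance; fairlib's own DTO may normalize or choose the optimum differently, so treat `dto` as a hypothetical helper.

```python
from math import sqrt

def dto(accuracy_pct, gap_pct):
    """Euclidean distance to the assumed optimum (100% accuracy, 0% gap); lower is better."""
    performance = accuracy_pct / 100.0
    fairness = 1.0 - gap_pct / 100.0
    return sqrt((1.0 - performance) ** 2 + (1.0 - fairness) ** 2)

# Final-epoch validation accuracy and TPR_GAP copied from the two training logs
vanilla = dto(accuracy_pct=98.69, gap_pct=1.32)
debiased = dto(accuracy_pct=98.61, gap_pct=0.74)
print(f"vanilla DTO:  {vanilla:.4f}")
print(f"debiased DTO: {debiased:.4f}")
```

Under these assumptions the debiased model ends up closer to the optimum: its validation accuracy is almost unchanged while its TPR gap is noticeably smaller.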