Customized models and datasets for CV
In this tutorial we will:
- Show how to add a model for image classification
- Show how to add a dataloader with image preprocessing
We will be using the colored MNIST dataset; please see the original demo for more details.
Installation
Again, the first step is to install our library.
[1]:
!pip install -q fairlib
[2]:
import fairlib
Explore Datasets
[3]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torchvision import transforms
from torchvision import datasets
import torchvision.datasets.utils as dataset_utils
[4]:
def color_grayscale_arr(arr, red=True):
    """Converts grayscale image to either red or green"""
    assert arr.ndim == 2
    dtype = arr.dtype
    h, w = arr.shape
    arr = np.reshape(arr, [h, w, 1])
    if red:
        # Put the intensities in the red channel; green and blue stay zero
        arr = np.concatenate([
            arr,
            np.zeros((h, w, 2), dtype=dtype)], axis=2)
    else:
        # Put the intensities in the green channel; red and blue stay zero
        arr = np.concatenate([
            np.zeros((h, w, 1), dtype=dtype),
            arr,
            np.zeros((h, w, 1), dtype=dtype)], axis=2)
    return arr
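As a quick check of what the helper produces, the following sketch applies it to a random array that stands in for one MNIST digit (the random input is a hypothetical placeholder, not real data). The function is repeated in compact form only so that the snippet runs on its own.

```python
import numpy as np

def color_grayscale_arr(arr, red=True):
    """Same logic as the cell above, repeated so this snippet runs on its own."""
    assert arr.ndim == 2
    h, w = arr.shape
    arr = np.reshape(arr, [h, w, 1])
    zeros = np.zeros((h, w, 1), dtype=arr.dtype)
    if red:
        return np.concatenate([arr, zeros, zeros], axis=2)
    return np.concatenate([zeros, arr, zeros], axis=2)

# Hypothetical stand-in for one 28x28 grayscale MNIST digit.
gray = np.random.default_rng(0).integers(0, 256, size=(28, 28), dtype=np.uint8)

red_img = color_grayscale_arr(gray, red=True)
green_img = color_grayscale_arr(gray, red=False)

# Both results are height x width x 3 arrays with the input dtype.
print(red_img.shape, red_img.dtype)
```

The grayscale intensities end up in channel 0 for the red version and in channel 1 for the green version; the remaining channels are all zeros.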
[5]:
root = "./data"
[6]:
train_mnist = datasets.mnist.MNIST(root=root, train=True, download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
[7]:
fig, ax = plt.subplots(1,3)
ax[0].imshow(train_mnist[0][0]) # Original
ax[1].imshow(color_grayscale_arr(np.array(train_mnist[0][0]))) # Red
ax[2].imshow(color_grayscale_arr(np.array(train_mnist[0][0]), False)) # Green
[7]:
<matplotlib.image.AxesImage at 0x7fbdcef158d0>
[8]:
from fairlib import datasets
[9]:
datasets.prepare_dataset("coloredmnist", "coloredmnist")
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to coloredmnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting coloredmnist/MNIST/raw/train-images-idx3-ubyte.gz to coloredmnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting coloredmnist/MNIST/raw/train-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to coloredmnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting coloredmnist/MNIST/raw/t10k-images-idx3-ubyte.gz to coloredmnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting coloredmnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to coloredmnist/MNIST/raw
Converting image 0/50000
Converting image 5000/50000
Converting image 10000/50000
Converting image 15000/50000
Converting image 20000/50000
Converting image 25000/50000
Converting image 30000/50000
Converting image 35000/50000
Converting image 40000/50000
Converting image 45000/50000
Converting image 0/10000
Converting image 5000/10000
Converting image 0/10000
Converting image 5000/10000
[10]:
from fairlib import networks, BaseOptions, dataloaders
[11]:
Shared_options = {
    # The name of the dataset; the corresponding dataloader will be used
    "dataset": "MNIST",
    # Specify the path to the input data
    "data_dir": "./coloredmnist",
    # Device for computing, -1 is the cpu
    "device_id": -1,
    # The default path for saving experimental results
    "results_dir": r"results",
    # Project directory name, used in the results path
    "project_dir": r"dev",
    # We will focus on the TPR GAP, implying Equalized Odds for binary classification.
    "GAP_metric_name": "TPR_GAP",
    # The overall performance will be measured as accuracy
    "Performance_metric_name": "accuracy",
    # Model selection is based on DTO
    "selection_criterion": "DTO",
    # Default dirs for saving checkpoints
    "checkpoint_dir": "models",
    "checkpoint_name": "checkpoint_epoch",
    "n_jobs": 1,
}
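The options above select checkpoints by DTO. As a rough illustration, the sketch below treats DTO as the Euclidean distance from a (performance, fairness) pair to an optimum point. The choice of fairness = 1 - gap, the optimum at (1, 1), the function name `dto`, and the candidate values are all assumptions made for this example, not taken from the library.

```python
import math

def dto(performance, fairness, optimum=(1.0, 1.0)):
    """Euclidean distance from (performance, fairness) to an optimum point."""
    return math.hypot(optimum[0] - performance, optimum[1] - fairness)

# Hypothetical candidate checkpoints: (accuracy, TPR gap), both on a 0-1 scale.
candidates = {
    "epoch_a": (0.955, 0.047),
    "epoch_b": (0.985, 0.015),
    "epoch_c": (0.986, 0.022),
}

# Assumption: fairness is measured as 1 - gap, so a smaller gap is better.
scores = {name: dto(acc, 1.0 - gap) for name, (acc, gap) in candidates.items()}

# The checkpoint with the smallest distance is the preferred one.
best = min(scores, key=scores.get)
print(best, round(scores[best], 4))
```

Under this reading, a checkpoint with slightly lower accuracy can still be selected if its gap is small enough to bring it closer to the optimum.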
Customizing NN Architectures
[12]:
class ConvNet(networks.utils.BaseModel):
    def __init__(self, args):
        super(ConvNet, self).__init__()
        self.args = args
        # Two convolutional layers extract features from the 3-channel images
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # The MLP classifier takes the flattened features as input
        self.classifier = networks.classifier.MLP(args)
        self.init_for_training()

    def forward(self, input_data, group_label = None):
        x = input_data
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        return self.classifier(x, group_label)

    def hidden(self, input_data, group_label = None):
        x = input_data
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        return self.classifier.hidden(x, group_label)
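The flattened size `4 * 4 * 50` used in `x.view` follows from the layer shapes. The sketch below traces the spatial size of a 28x28 MNIST image through the two convolutions and poolings using the standard output-size formula; `conv_out` is a helper name introduced here for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28                                    # MNIST images are 28x28
side = conv_out(side, kernel=5)              # conv1 (5x5, stride 1): 28 -> 24
side = conv_out(side, kernel=2, stride=2)    # max_pool2d (2x2, stride 2): 24 -> 12
side = conv_out(side, kernel=5)              # conv2 (5x5, stride 1): 12 -> 8
side = conv_out(side, kernel=2, stride=2)    # max_pool2d (2x2, stride 2): 8 -> 4

channels = 50                                # output channels of conv2
flat_features = side * side * channels
print(side, flat_features)
```

The resulting number of flattened features is the value that has to be passed as the input size of the classifier.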
[13]:
args = {
    "dataset":Shared_options["dataset"],
    "data_dir":Shared_options["data_dir"],
    "device_id":Shared_options["device_id"],
    # Give a name to the exp, which will be used in the path
    "exp_id":"vanilla",
    # Size of the flattened ConvNet features fed to the MLP classifier
    "emb_size": 4*4*50,
    "num_classes": 10,
    "num_groups": 2,
}
# Init the argument
options = BaseOptions()
state = options.get_state(args=args, silence=True)
INFO:root:Unexpected args: ['-f', '/root/.local/share/jupyter/runtime/kernel-a7e24054-f7e0-4930-9df2-bfcd40fedbad.json']
INFO:root:Logging to ./results/dev/MNIST/vanilla/output.log
2022-07-21 07:19:08 [INFO ] ======================================== 2022-07-21 07:19:08 ========================================
2022-07-21 07:19:08 [INFO ] Base directory is ./results/dev/MNIST/vanilla
2022-07-21 07:19:09 [INFO ] Exception type : AssertionError
2022-07-21 07:19:09 [INFO ] Exception message : Not implemented
2022-07-21 07:19:09 [INFO ] Stack trace : ['File : /usr/local/lib/python3.7/dist-packages/fairlib/src/base_options.py , Line : 486, Func.Name : set_state, Message : train_iterator, dev_iterator, test_iterator = dataloaders.get_dataloaders(state)', 'File : /usr/local/lib/python3.7/dist-packages/fairlib/src/dataloaders/__init__.py , Line : 40, Func.Name : get_dataloaders, Message : ], "Not implemented"']
2022-07-21 07:19:09 [INFO ] dataloaders need to be initialized!
Customizing Dataloader
[14]:
class CustomizedDataset(dataloaders.utils.BaseDataset):
    # Image preprocessing: convert to tensor, then normalize each channel
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))])

    def load_data(self):
        # Each split is stored as a tuple of (images, labels, protected labels)
        self.data_dir = os.path.join(self.args.data_dir, "colored_MNIST_{}.pt".format(self.split))
        data = torch.load(self.data_dir)
        self.X = [self.transform(_img) for _img in data[0]]
        self.y = data[1]
        self.protected_label = data[2]
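To make the preprocessing concrete, here is a NumPy imitation of what `ToTensor` followed by `Normalize` computes for an 8-bit image: scale to [0, 1], move channels first, then apply `(x - mean) / std` per channel. This is only an illustrative re-implementation under those assumptions, not the torchvision code itself, and the all-black and all-white inputs are hypothetical.

```python
import numpy as np

# Per-channel statistics copied from the transform above.
MEAN = np.array([0.1307, 0.1307, 0.0], dtype=np.float32)
STD = np.array([0.3081, 0.3081, 0.3081], dtype=np.float32)

def to_tensor_and_normalize(img_hwc_uint8):
    """NumPy imitation of ToTensor followed by Normalize (illustration only)."""
    # ToTensor: scale uint8 [0, 255] to float [0, 1] and move channels first.
    chw = img_hwc_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize: subtract the per-channel mean and divide by the per-channel std.
    return (chw - MEAN[:, None, None]) / STD[:, None, None]

# Hypothetical all-black and all-white RGB images.
black = np.zeros((28, 28, 3), dtype=np.uint8)
white = np.full((28, 28, 3), 255, dtype=np.uint8)

out_black = to_tensor_and_normalize(black)
out_white = to_tensor_and_normalize(white)
print(out_black.shape, out_black[0, 0, 0], out_white[0, 0, 0])
```

Because the third channel uses a mean of 0, a black pixel in that channel stays at exactly 0 after normalization, while the first two channels are shifted below zero.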
[15]:
customized_train_data = CustomizedDataset(args=state, split="train")
customized_dev_data = CustomizedDataset(args=state, split="dev")
customized_test_data = CustomizedDataset(args=state, split="test")
# DataLoader Parameters
tran_dataloader_params = {
    'batch_size': state.batch_size,
    'shuffle': True,
    'num_workers': state.num_workers}
eval_dataloader_params = {
    'batch_size': state.test_batch_size,
    'shuffle': False,
    'num_workers': state.num_workers}
# init dataloader
customized_training_generator = torch.utils.data.DataLoader(customized_train_data, **tran_dataloader_params)
customized_validation_generator = torch.utils.data.DataLoader(customized_dev_data, **eval_dataloader_params)
customized_test_generator = torch.utils.data.DataLoader(customized_test_data, **eval_dataloader_params)
Loaded data shapes: (50000,), (50000,), (50000,)
Loaded data shapes: (10000,), (10000,), (10000,)
Loaded data shapes: (10000,), (10000,), (10000,)
Training vanilla model without debiasing
[16]:
model = ConvNet(state)
2022-07-21 07:19:20 [INFO ] MLP(
2022-07-21 07:19:20 [INFO ] (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:19:20 [INFO ] (AF): Tanh()
2022-07-21 07:19:20 [INFO ] (hidden_layers): ModuleList(
2022-07-21 07:19:20 [INFO ] (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ] (1): Tanh()
2022-07-21 07:19:20 [INFO ] (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ] (3): Tanh()
2022-07-21 07:19:20 [INFO ] )
2022-07-21 07:19:20 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:19:20 [INFO ] )
2022-07-21 07:19:20 [INFO ] Total number of parameters: 333610
2022-07-21 07:19:20 [INFO ] ConvNet(
2022-07-21 07:19:20 [INFO ] (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:19:20 [INFO ] (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:19:20 [INFO ] (classifier): MLP(
2022-07-21 07:19:20 [INFO ] (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:19:20 [INFO ] (AF): Tanh()
2022-07-21 07:19:20 [INFO ] (hidden_layers): ModuleList(
2022-07-21 07:19:20 [INFO ] (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ] (1): Tanh()
2022-07-21 07:19:20 [INFO ] (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:19:20 [INFO ] (3): Tanh()
2022-07-21 07:19:20 [INFO ] )
2022-07-21 07:19:20 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:19:20 [INFO ] )
2022-07-21 07:19:20 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:19:20 [INFO ] )
2022-07-21 07:19:20 [INFO ] Total number of parameters: 360180
[17]:
model.train_self(
    train_generator = customized_training_generator,
    dev_generator = customized_validation_generator,
    test_generator = customized_test_generator,
)
2022-07-21 07:19:20 [INFO ] Epoch: 0 [ 0/ 50000 ( 0%)] Loss: 2.3093 Data Time: 0.01s Train Time: 0.77s
2022-07-21 07:19:51 [INFO ] Evaluation at Epoch 0
2022-07-21 07:19:53 [INFO ] Validation accuracy: 95.53 macro_fscore: 95.52 micro_fscore: 95.53 TPR_GAP: 4.69 FPR_GAP: 0.40 PPR_GAP: 0.74
2022-07-21 07:19:53 [INFO ] Test accuracy: 96.45 macro_fscore: 96.42 micro_fscore: 96.45 TPR_GAP: 3.58 FPR_GAP: 0.37 PPR_GAP: 0.72
2022-07-21 07:19:54 [INFO ] Epoch: 1 [ 0/ 50000 ( 0%)] Loss: 0.1037 Data Time: 0.01s Train Time: 0.59s
2022-07-21 07:20:25 [INFO ] Evaluation at Epoch 1
2022-07-21 07:20:28 [INFO ] Validation accuracy: 97.29 macro_fscore: 97.29 micro_fscore: 97.29 TPR_GAP: 2.54 FPR_GAP: 0.27 PPR_GAP: 0.67
2022-07-21 07:20:28 [INFO ] Test accuracy: 97.42 macro_fscore: 97.40 micro_fscore: 97.42 TPR_GAP: 3.33 FPR_GAP: 0.35 PPR_GAP: 0.73
2022-07-21 07:20:28 [INFO ] Epoch: 2 [ 0/ 50000 ( 0%)] Loss: 0.0530 Data Time: 0.01s Train Time: 0.58s
2022-07-21 07:21:01 [INFO ] Evaluation at Epoch 2
2022-07-21 07:21:04 [INFO ] Validation accuracy: 97.48 macro_fscore: 97.48 micro_fscore: 97.48 TPR_GAP: 3.68 FPR_GAP: 0.39 PPR_GAP: 0.71
2022-07-21 07:21:04 [INFO ] Test accuracy: 97.97 macro_fscore: 97.96 micro_fscore: 97.97 TPR_GAP: 3.30 FPR_GAP: 0.31 PPR_GAP: 0.98
2022-07-21 07:21:05 [INFO ] Epoch: 3 [ 0/ 50000 ( 0%)] Loss: 0.0368 Data Time: 0.01s Train Time: 0.58s
2022-07-21 07:21:36 [INFO ] Evaluation at Epoch 3
2022-07-21 07:21:38 [INFO ] Validation accuracy: 98.31 macro_fscore: 98.31 micro_fscore: 98.31 TPR_GAP: 2.50 FPR_GAP: 0.30 PPR_GAP: 0.71
2022-07-21 07:21:38 [INFO ] Test accuracy: 98.45 macro_fscore: 98.44 micro_fscore: 98.45 TPR_GAP: 2.05 FPR_GAP: 0.23 PPR_GAP: 0.86
2022-07-21 07:21:39 [INFO ] Epoch: 4 [ 0/ 50000 ( 0%)] Loss: 0.0100 Data Time: 0.01s Train Time: 0.57s
2022-07-21 07:22:09 [INFO ] Evaluation at Epoch 4
2022-07-21 07:22:12 [INFO ] Validation accuracy: 98.27 macro_fscore: 98.27 micro_fscore: 98.27 TPR_GAP: 2.00 FPR_GAP: 0.18 PPR_GAP: 0.65
2022-07-21 07:22:12 [INFO ] Test accuracy: 98.44 macro_fscore: 98.43 micro_fscore: 98.44 TPR_GAP: 1.79 FPR_GAP: 0.20 PPR_GAP: 0.81
2022-07-21 07:22:12 [INFO ] Epoch: 5 [ 0/ 50000 ( 0%)] Loss: 0.0130 Data Time: 0.01s Train Time: 0.61s
2022-07-21 07:22:42 [INFO ] Evaluation at Epoch 5
2022-07-21 07:22:45 [INFO ] Validation accuracy: 98.46 macro_fscore: 98.46 micro_fscore: 98.46 TPR_GAP: 1.47 FPR_GAP: 0.22 PPR_GAP: 0.61
2022-07-21 07:22:45 [INFO ] Test accuracy: 98.52 macro_fscore: 98.51 micro_fscore: 98.52 TPR_GAP: 1.47 FPR_GAP: 0.16 PPR_GAP: 0.83
2022-07-21 07:22:45 [INFO ] Epoch: 6 [ 0/ 50000 ( 0%)] Loss: 0.0104 Data Time: 0.01s Train Time: 0.56s
2022-07-21 07:23:16 [INFO ] Evaluation at Epoch 6
2022-07-21 07:23:18 [INFO ] Validation accuracy: 98.57 macro_fscore: 98.57 micro_fscore: 98.57 TPR_GAP: 1.61 FPR_GAP: 0.19 PPR_GAP: 0.72
2022-07-21 07:23:18 [INFO ] Test accuracy: 98.74 macro_fscore: 98.73 micro_fscore: 98.74 TPR_GAP: 1.45 FPR_GAP: 0.15 PPR_GAP: 0.80
2022-07-21 07:23:19 [INFO ] Epoch: 7 [ 0/ 50000 ( 0%)] Loss: 0.0070 Data Time: 0.01s Train Time: 0.60s
2022-07-21 07:23:49 [INFO ] Epochs since last improvement: 1
2022-07-21 07:23:49 [INFO ] Evaluation at Epoch 7
2022-07-21 07:23:51 [INFO ] Validation accuracy: 98.47 macro_fscore: 98.46 micro_fscore: 98.47 TPR_GAP: 1.83 FPR_GAP: 0.16 PPR_GAP: 0.68
2022-07-21 07:23:51 [INFO ] Test accuracy: 98.55 macro_fscore: 98.53 micro_fscore: 98.55 TPR_GAP: 1.66 FPR_GAP: 0.16 PPR_GAP: 0.77
2022-07-21 07:23:52 [INFO ] Epoch: 8 [ 0/ 50000 ( 0%)] Loss: 0.0069 Data Time: 0.01s Train Time: 0.57s
2022-07-21 07:24:22 [INFO ] Epochs since last improvement: 2
2022-07-21 07:24:22 [INFO ] Evaluation at Epoch 8
2022-07-21 07:24:24 [INFO ] Validation accuracy: 98.56 macro_fscore: 98.55 micro_fscore: 98.56 TPR_GAP: 1.83 FPR_GAP: 0.14 PPR_GAP: 0.58
2022-07-21 07:24:24 [INFO ] Test accuracy: 98.65 macro_fscore: 98.64 micro_fscore: 98.65 TPR_GAP: 1.80 FPR_GAP: 0.18 PPR_GAP: 0.82
2022-07-21 07:24:25 [INFO ] Epoch: 9 [ 0/ 50000 ( 0%)] Loss: 0.0015 Data Time: 0.01s Train Time: 0.57s
2022-07-21 07:24:55 [INFO ] Epochs since last improvement: 3
2022-07-21 07:24:55 [INFO ] Evaluation at Epoch 9
2022-07-21 07:24:57 [INFO ] Validation accuracy: 98.55 macro_fscore: 98.55 micro_fscore: 98.55 TPR_GAP: 1.65 FPR_GAP: 0.20 PPR_GAP: 0.66
2022-07-21 07:24:57 [INFO ] Test accuracy: 98.67 macro_fscore: 98.66 micro_fscore: 98.67 TPR_GAP: 1.58 FPR_GAP: 0.17 PPR_GAP: 0.82
2022-07-21 07:24:58 [INFO ] Epoch: 10 [ 0/ 50000 ( 0%)] Loss: 0.0157 Data Time: 0.01s Train Time: 0.56s
2022-07-21 07:25:28 [INFO ] Epochs since last improvement: 4
2022-07-21 07:25:28 [INFO ] Evaluation at Epoch 10
2022-07-21 07:25:30 [INFO ] Validation accuracy: 98.41 macro_fscore: 98.39 micro_fscore: 98.41 TPR_GAP: 2.19 FPR_GAP: 0.24 PPR_GAP: 0.70
2022-07-21 07:25:30 [INFO ] Test accuracy: 98.55 macro_fscore: 98.53 micro_fscore: 98.55 TPR_GAP: 1.47 FPR_GAP: 0.18 PPR_GAP: 0.82
2022-07-21 07:25:31 [INFO ] Epoch: 11 [ 0/ 50000 ( 0%)] Loss: 0.0038 Data Time: 0.01s Train Time: 0.56s
2022-07-21 07:26:00 [INFO ] Epochs since last improvement: 5
2022-07-21 07:26:00 [INFO ] Evaluation at Epoch 11
2022-07-21 07:26:03 [INFO ] Validation accuracy: 98.69 macro_fscore: 98.69 micro_fscore: 98.69 TPR_GAP: 1.32 FPR_GAP: 0.14 PPR_GAP: 0.64
2022-07-21 07:26:03 [INFO ] Test accuracy: 98.84 macro_fscore: 98.83 micro_fscore: 98.84 TPR_GAP: 1.21 FPR_GAP: 0.13 PPR_GAP: 0.87
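The logs above report a TPR_GAP value at every evaluation. As a rough intuition for what such a number measures, the sketch below computes, for each class, the difference in true positive rate between two groups and aggregates the per-class differences with a root mean square. The function name `tpr_gap`, the RMS aggregation, and the toy data are assumptions for this example; the library's exact definition may differ.

```python
import numpy as np

def tpr_gap(y_true, y_pred, groups):
    """RMS over classes of the TPR difference between two groups (sketch)."""
    y_true, y_pred, groups = map(np.asarray, (y_true, y_pred, groups))
    gaps = []
    for c in np.unique(y_true):
        tprs = []
        for g in (0, 1):
            mask = (y_true == c) & (groups == g)
            if mask.sum() == 0:
                break  # skip classes that one of the groups never contains
            # TPR for class c in group g: share of its instances predicted as c.
            tprs.append((y_pred[mask] == c).mean())
        if len(tprs) == 2:
            gaps.append(abs(tprs[0] - tprs[1]))
    return float(np.sqrt(np.mean(np.square(gaps)))) if gaps else 0.0

# Hypothetical toy predictions for two classes and two groups.
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
groups = [0, 0, 1, 1, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 1]

gap = tpr_gap(y_true, y_pred, groups)
print(round(gap, 4))
```

In the toy data, class 0 is recognized perfectly for group 0 but only half of the time for group 1, so the gap is non-zero even though class 1 is treated identically for both groups.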
Improving Fairness
[18]:
debiasing_args = {
    "dataset":Shared_options["dataset"],
    "data_dir":Shared_options["data_dir"],
    "device_id":Shared_options["device_id"],
    # Give a name to the exp, which will be used in the path
    "exp_id":"BT_Adv",
    "emb_size": 4*4*50,
    "num_classes": 10,
    "num_groups": 2,
    # Perform adversarial training if True
    "adv_debiasing":True,
    # Specify the hyperparameters for Balanced Training
    "BT":"Resampling",
    "BTObj":"EO",
}
# Init the argument
debias_options = BaseOptions()
debias_state = debias_options.get_state(args=debiasing_args, silence=True)
customized_train_data = CustomizedDataset(args=debias_state, split="train")
customized_dev_data = CustomizedDataset(args=debias_state, split="dev")
customized_test_data = CustomizedDataset(args=debias_state, split="test")
# DataLoader Parameters, read from the debiasing state
tran_dataloader_params = {
    'batch_size': debias_state.batch_size,
    'shuffle': True,
    'num_workers': debias_state.num_workers}
eval_dataloader_params = {
    'batch_size': debias_state.test_batch_size,
    'shuffle': False,
    'num_workers': debias_state.num_workers}
# init dataloader
customized_training_generator = torch.utils.data.DataLoader(customized_train_data, **tran_dataloader_params)
customized_validation_generator = torch.utils.data.DataLoader(customized_dev_data, **eval_dataloader_params)
customized_test_generator = torch.utils.data.DataLoader(customized_test_data, **eval_dataloader_params)
debias_model = ConvNet(debias_state)
2022-07-21 07:26:03 [INFO ] Unexpected args: ['-f', '/root/.local/share/jupyter/runtime/kernel-a7e24054-f7e0-4930-9df2-bfcd40fedbad.json']
2022-07-21 07:26:03 [INFO ] Logging to ./results/dev/MNIST/BT_Adv/output.log
2022-07-21 07:26:03 [INFO ] ======================================== 2022-07-21 07:26:03 ========================================
2022-07-21 07:26:03 [INFO ] Base directory is ./results/dev/MNIST/BT_Adv
2022-07-21 07:26:03 [INFO ] Exception type : AssertionError
2022-07-21 07:26:03 [INFO ] Exception message : Not implemented
2022-07-21 07:26:03 [INFO ] Stack trace : ['File : /usr/local/lib/python3.7/dist-packages/fairlib/src/base_options.py , Line : 486, Func.Name : set_state, Message : train_iterator, dev_iterator, test_iterator = dataloaders.get_dataloaders(state)', 'File : /usr/local/lib/python3.7/dist-packages/fairlib/src/dataloaders/__init__.py , Line : 40, Func.Name : get_dataloaders, Message : ], "Not implemented"']
2022-07-21 07:26:03 [INFO ] dataloaders need to be initialized!
2022-07-21 07:26:03 [INFO ] SubDiscriminator(
2022-07-21 07:26:03 [INFO ] (grad_rev): GradientReversal()
2022-07-21 07:26:03 [INFO ] (output_layer): Linear(in_features=300, out_features=2, bias=True)
2022-07-21 07:26:03 [INFO ] (AF): ReLU()
2022-07-21 07:26:03 [INFO ] (hidden_layers): ModuleList(
2022-07-21 07:26:03 [INFO ] (0): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:03 [INFO ] (1): ReLU()
2022-07-21 07:26:03 [INFO ] (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:03 [INFO ] (3): ReLU()
2022-07-21 07:26:03 [INFO ] )
2022-07-21 07:26:03 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:26:03 [INFO ] )
2022-07-21 07:26:03 [INFO ] Total number of parameters: 181202
2022-07-21 07:26:03 [INFO ] Discriminator built!
Loaded data shapes: (49998,), (49998,), (49998,)
Loaded data shapes: (10000,), (10000,), (10000,)
Loaded data shapes: (10000,), (10000,), (10000,)
2022-07-21 07:26:14 [INFO ] MLP(
2022-07-21 07:26:14 [INFO ] (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:26:14 [INFO ] (AF): Tanh()
2022-07-21 07:26:14 [INFO ] (hidden_layers): ModuleList(
2022-07-21 07:26:14 [INFO ] (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ] (1): Tanh()
2022-07-21 07:26:14 [INFO ] (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ] (3): Tanh()
2022-07-21 07:26:14 [INFO ] )
2022-07-21 07:26:14 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:26:14 [INFO ] )
2022-07-21 07:26:14 [INFO ] Total number of parameters: 333610
2022-07-21 07:26:14 [INFO ] ConvNet(
2022-07-21 07:26:14 [INFO ] (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:26:14 [INFO ] (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
2022-07-21 07:26:14 [INFO ] (classifier): MLP(
2022-07-21 07:26:14 [INFO ] (output_layer): Linear(in_features=300, out_features=10, bias=True)
2022-07-21 07:26:14 [INFO ] (AF): Tanh()
2022-07-21 07:26:14 [INFO ] (hidden_layers): ModuleList(
2022-07-21 07:26:14 [INFO ] (0): Linear(in_features=800, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ] (1): Tanh()
2022-07-21 07:26:14 [INFO ] (2): Linear(in_features=300, out_features=300, bias=True)
2022-07-21 07:26:14 [INFO ] (3): Tanh()
2022-07-21 07:26:14 [INFO ] )
2022-07-21 07:26:14 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:26:14 [INFO ] )
2022-07-21 07:26:14 [INFO ] (criterion): CrossEntropyLoss()
2022-07-21 07:26:14 [INFO ] )
2022-07-21 07:26:14 [INFO ] Total number of parameters: 360180
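With balanced training enabled, the training split above is reported with 49998 instances rather than 50000. One way such a resampling step could work is sketched below: within each class, every protected group is resampled to the same target count. The function name, the integer-division target, and the toy data are assumptions for illustration and are not taken from the library; under this scheme, rounding the per-group target down can shrink the dataset by a few instances.

```python
import numpy as np

def resample_balanced_within_class(y, g, seed=0):
    """Return indices such that, within each class, groups are equally frequent."""
    rng = np.random.default_rng(seed)
    y, g = np.asarray(y), np.asarray(g)
    group_ids = np.unique(g)
    selected = []
    for c in np.unique(y):
        class_size = int((y == c).sum())
        per_group = class_size // len(group_ids)  # target count per group
        for grp in group_ids:
            pool = np.flatnonzero((y == c) & (g == grp))
            if len(pool) == 0:
                continue  # nothing to draw from for this (class, group) pair
            # Sample with replacement only when the pool is too small.
            replace = len(pool) < per_group
            selected.append(rng.choice(pool, size=per_group, replace=replace))
    return np.concatenate(selected)

# Hypothetical skewed toy data: class 0 is mostly group 0, class 1 mostly group 1.
y = np.array([0] * 10 + [1] * 9)
g = np.array([0] * 8 + [1] * 2 + [0] * 2 + [1] * 7)

idx = resample_balanced_within_class(y, g)
print(len(y), len(idx))
```

In the toy data, the class with an odd number of instances loses one instance to rounding, so the resampled set is slightly smaller than the original.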
[19]:
debias_model.train_self(
    train_generator = customized_training_generator,
    dev_generator = customized_validation_generator,
    test_generator = customized_test_generator,
)
2022-07-21 07:26:16 [INFO ] Epoch: 0 [ 0/ 49998 ( 0%)] Loss: 1.6374 Data Time: 0.01s Train Time: 1.56s
2022-07-21 07:27:29 [INFO ] Evaluation at Epoch 0
2022-07-21 07:27:31 [INFO ] Validation accuracy: 96.09 macro_fscore: 96.09 micro_fscore: 96.09 TPR_GAP: 2.95 FPR_GAP: 0.33 PPR_GAP: 0.59
2022-07-21 07:27:31 [INFO ] Test accuracy: 97.04 macro_fscore: 97.02 micro_fscore: 97.04 TPR_GAP: 2.98 FPR_GAP: 0.20 PPR_GAP: 1.09
2022-07-21 07:27:33 [INFO ] Epoch: 1 [ 0/ 49998 ( 0%)] Loss: -0.6122 Data Time: 0.01s Train Time: 1.44s
2022-07-21 07:28:46 [INFO ] Evaluation at Epoch 1
2022-07-21 07:28:48 [INFO ] Validation accuracy: 97.32 macro_fscore: 97.31 micro_fscore: 97.32 TPR_GAP: 1.28 FPR_GAP: 0.10 PPR_GAP: 0.75
2022-07-21 07:28:48 [INFO ] Test accuracy: 97.99 macro_fscore: 97.97 micro_fscore: 97.99 TPR_GAP: 0.97 FPR_GAP: 0.10 PPR_GAP: 0.97
2022-07-21 07:28:50 [INFO ] Epoch: 2 [ 0/ 49998 ( 0%)] Loss: -0.6411 Data Time: 0.01s Train Time: 1.48s
2022-07-21 07:30:02 [INFO ] Evaluation at Epoch 2
2022-07-21 07:30:05 [INFO ] Validation accuracy: 97.92 macro_fscore: 97.91 micro_fscore: 97.92 TPR_GAP: 1.31 FPR_GAP: 0.12 PPR_GAP: 0.80
2022-07-21 07:30:05 [INFO ] Test accuracy: 98.38 macro_fscore: 98.37 micro_fscore: 98.38 TPR_GAP: 0.96 FPR_GAP: 0.09 PPR_GAP: 1.00
2022-07-21 07:30:07 [INFO ] Epoch: 3 [ 0/ 49998 ( 0%)] Loss: -0.6523 Data Time: 0.01s Train Time: 1.76s
2022-07-21 07:31:20 [INFO ] Evaluation at Epoch 3
2022-07-21 07:31:23 [INFO ] Validation accuracy: 98.17 macro_fscore: 98.17 micro_fscore: 98.17 TPR_GAP: 0.94 FPR_GAP: 0.10 PPR_GAP: 0.67
2022-07-21 07:31:23 [INFO ] Test accuracy: 98.52 macro_fscore: 98.51 micro_fscore: 98.52 TPR_GAP: 1.25 FPR_GAP: 0.09 PPR_GAP: 1.00
2022-07-21 07:31:25 [INFO ] Epoch: 4 [ 0/ 49998 ( 0%)] Loss: -0.6713 Data Time: 0.01s Train Time: 1.46s
2022-07-21 07:32:38 [INFO ] Evaluation at Epoch 4
2022-07-21 07:32:41 [INFO ] Validation accuracy: 98.57 macro_fscore: 98.56 micro_fscore: 98.57 TPR_GAP: 0.74 FPR_GAP: 0.09 PPR_GAP: 0.75
2022-07-21 07:32:41 [INFO ] Test accuracy: 98.78 macro_fscore: 98.77 micro_fscore: 98.78 TPR_GAP: 0.46 FPR_GAP: 0.08 PPR_GAP: 0.93
2022-07-21 07:32:42 [INFO ] Epoch: 5 [ 0/ 49998 ( 0%)] Loss: -0.6775 Data Time: 0.01s Train Time: 1.45s
2022-07-21 07:33:58 [INFO ] Epochs since last improvement: 1
2022-07-21 07:33:58 [INFO ] Evaluation at Epoch 5
2022-07-21 07:34:01 [INFO ] Validation accuracy: 98.54 macro_fscore: 98.53 micro_fscore: 98.54 TPR_GAP: 0.78 FPR_GAP: 0.06 PPR_GAP: 0.73
2022-07-21 07:34:01 [INFO ] Test accuracy: 98.76 macro_fscore: 98.76 micro_fscore: 98.76 TPR_GAP: 0.74 FPR_GAP: 0.09 PPR_GAP: 0.88
2022-07-21 07:34:02 [INFO ] Epoch: 6 [ 0/ 49998 ( 0%)] Loss: -0.6868 Data Time: 0.01s Train Time: 1.49s
2022-07-21 07:35:16 [INFO ] Epochs since last improvement: 2
2022-07-21 07:35:16 [INFO ] Evaluation at Epoch 6
2022-07-21 07:35:18 [INFO ] Validation accuracy: 98.60 macro_fscore: 98.59 micro_fscore: 98.60 TPR_GAP: 0.54 FPR_GAP: 0.06 PPR_GAP: 0.73
2022-07-21 07:35:18 [INFO ] Test accuracy: 98.74 macro_fscore: 98.73 micro_fscore: 98.74 TPR_GAP: 0.80 FPR_GAP: 0.07 PPR_GAP: 0.97
2022-07-21 07:35:20 [INFO ] Epoch: 7 [ 0/ 49998 ( 0%)] Loss: -0.6850 Data Time: 0.01s Train Time: 1.48s
2022-07-21 07:36:32 [INFO ] Evaluation at Epoch 7
2022-07-21 07:36:35 [INFO ] Validation accuracy: 98.65 macro_fscore: 98.64 micro_fscore: 98.65 TPR_GAP: 0.64 FPR_GAP: 0.08 PPR_GAP: 0.75
2022-07-21 07:36:35 [INFO ] Test accuracy: 98.82 macro_fscore: 98.81 micro_fscore: 98.82 TPR_GAP: 0.39 FPR_GAP: 0.06 PPR_GAP: 0.95
2022-07-21 07:36:36 [INFO ] Epoch: 8 [ 0/ 49998 ( 0%)] Loss: -0.6916 Data Time: 0.01s Train Time: 1.47s
2022-07-21 07:37:49 [INFO ] Epochs since last improvement: 1
2022-07-21 07:37:49 [INFO ] Evaluation at Epoch 8
2022-07-21 07:37:52 [INFO ] Validation accuracy: 98.56 macro_fscore: 98.55 micro_fscore: 98.56 TPR_GAP: 0.66 FPR_GAP: 0.08 PPR_GAP: 0.76
2022-07-21 07:37:52 [INFO ] Test accuracy: 98.79 macro_fscore: 98.78 micro_fscore: 98.79 TPR_GAP: 1.01 FPR_GAP: 0.08 PPR_GAP: 0.96
2022-07-21 07:37:53 [INFO ] Epoch: 9 [ 0/ 49998 ( 0%)] Loss: -0.6902 Data Time: 0.01s Train Time: 1.46s
2022-07-21 07:39:05 [INFO ] Epochs since last improvement: 2
2022-07-21 07:39:05 [INFO ] Evaluation at Epoch 9
2022-07-21 07:39:08 [INFO ] Validation accuracy: 98.38 macro_fscore: 98.37 micro_fscore: 98.38 TPR_GAP: 0.74 FPR_GAP: 0.11 PPR_GAP: 0.70
2022-07-21 07:39:08 [INFO ] Test accuracy: 98.72 macro_fscore: 98.71 micro_fscore: 98.72 TPR_GAP: 0.68 FPR_GAP: 0.07 PPR_GAP: 0.94
2022-07-21 07:39:09 [INFO ] Epoch: 10 [ 0/ 49998 ( 0%)] Loss: -0.6885 Data Time: 0.01s Train Time: 1.44s
2022-07-21 07:40:22 [INFO ] Epochs since last improvement: 3
2022-07-21 07:40:22 [INFO ] Evaluation at Epoch 10
2022-07-21 07:40:24 [INFO ] Validation accuracy: 98.33 macro_fscore: 98.32 micro_fscore: 98.33 TPR_GAP: 0.89 FPR_GAP: 0.08 PPR_GAP: 0.66
2022-07-21 07:40:24 [INFO ] Test accuracy: 98.71 macro_fscore: 98.70 micro_fscore: 98.71 TPR_GAP: 0.89 FPR_GAP: 0.08 PPR_GAP: 0.96
2022-07-21 07:40:26 [INFO ] Epoch: 11 [ 0/ 49998 ( 0%)] Loss: -0.6860 Data Time: 0.01s Train Time: 1.46s
2022-07-21 07:41:39 [INFO ] Epochs since last improvement: 4
2022-07-21 07:41:39 [INFO ] Evaluation at Epoch 11
2022-07-21 07:41:41 [INFO ] Validation accuracy: 98.67 macro_fscore: 98.66 micro_fscore: 98.67 TPR_GAP: 0.81 FPR_GAP: 0.11 PPR_GAP: 0.74
2022-07-21 07:41:41 [INFO ] Test accuracy: 98.83 macro_fscore: 98.82 micro_fscore: 98.83 TPR_GAP: 0.92 FPR_GAP: 0.10 PPR_GAP: 0.90
2022-07-21 07:41:43 [INFO ] Epoch: 12 [ 0/ 49998 ( 0%)] Loss: -0.6909 Data Time: 0.01s Train Time: 1.46s
2022-07-21 07:42:54 [INFO ] Epochs since last improvement: 5
2022-07-21 07:42:54 [INFO ] Evaluation at Epoch 12
2022-07-21 07:42:57 [INFO ] Validation accuracy: 98.61 macro_fscore: 98.60 micro_fscore: 98.61 TPR_GAP: 0.74 FPR_GAP: 0.09 PPR_GAP: 0.73
2022-07-21 07:42:57 [INFO ] Test accuracy: 98.90 macro_fscore: 98.89 micro_fscore: 98.90 TPR_GAP: 0.64 FPR_GAP: 0.06 PPR_GAP: 0.96
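The training loss in the debiasing run settles at around -0.69, whereas the vanilla run stays positive. One possible reading, offered here only as an assumption, is that the logged loss is the task loss minus a weighted discriminator loss: a two-group discriminator that is reduced to chance has a cross-entropy of ln 2, which is about 0.693. The weight and the probabilities below are hypothetical.

```python
import math

def cross_entropy(probs, target):
    """Cross-entropy of a single prediction given class probabilities."""
    return -math.log(probs[target])

# Assumption: total loss = task loss - adv_weight * discriminator loss.
adv_weight = 1.0  # hypothetical weight of the adversarial term

# A confident and correct 10-way digit prediction gives a task loss near zero.
task_loss = cross_entropy([0.001] * 9 + [0.991], target=9)

# A two-group discriminator at chance level predicts 0.5 for each group.
adv_loss = cross_entropy([0.5, 0.5], target=0)

total = task_loss - adv_weight * adv_loss
print(round(adv_loss, 4), round(total, 4))
```

If this reading holds, a loss close to -ln 2 would indicate that the classifier fits the digits well while the discriminator can no longer tell the two groups apart from the hidden representation.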