Adding Customized NN Architecture

Classification Head

Our current MLP implementation (fairlib/src/networks/classifier) can be used as a classification head for different backbone nets. All our methods such as balanced training and adversarial training will be supported for the new model.

Customized model architecture

Take a look at the following example, we use the BERT model as the feature extracting network, i.e., extracting sentence representations from the BERT, and then use the extracted features as the input to the MLP classifier to make predictions.

We only need to define three functions: __init__, which is used to init the model with pretrained BERT parameters, MLP classifier, and optimizers; forward, which is the same to before where we extract sentence representations then use MLP to make predictions; and hidden, which is used to get hidden representations for adversarial training.

from transformers import BertModel
from fairlib.networks.classifier import MLP

class BERTClassifier(BaseModel):
    model_name = 'bert-base-cased'

    def __init__(self, args):
        super(BERTClassifier, self).__init__()
        self.args = args

        # Load pretrained model parameters.
        self.bert = BertModel.from_pretrained(self.model_name)

        # Init the classification head 
        self.classifier = MLP(args)

        # Init optimizers, criterions, etc.
        self.init_for_training()

    def forward(self, input_data, group_label = None):
        # Extract sentence representations from bert
        bert_output = self.bert(input_data)[1]

        # Make predictions
        return self.classifier(bert_output, group_label)
    
    def hidden(self, input_data, group_label = None):
        # Extract sentence representations from bert
        bert_output = self.bert(input_data)[1]

        # Make predictions
        return self.classifier.hidden(bert_output, group_label)

Register Model

the model architecture is indicated by --encoder_architecture, so we will need to handle different values of this argument. Specifically, we need to modify the get_main_model function in fairlib/src/networks/__init__.py to support new models.

def get_main_model(args):
    # Add the new model name here.
    assert args.encoder_architecture in ["Fixed", "BERT", "DeepMoji", "NEW_MODEL_NAME"], "Not implemented"

    if args.encoder_architecture == "Fixed":
        model = MLP(args)
    elif args.encoder_architecture == "BERT":
        model = BERTClassifier(args)
    # Init the model 
    elif args.encoder_architecture == "NEW_MODEL_NAME":
        model = MODEL(args)
    else:
        raise "not implemented yet"

Register the Dataloader

Since different models have their own mapping from tokens to numerical ids. We need to handle this in the dataloader.

Firstly, we need to init the tokenizer in fairlib/src/dataloaders/encoder.py, for example,

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

Next, we need to modify the corresponding dataloader to return the idx of input texts. Please take a look at the Bios loader in fairlib/src/dataloaders/loaders.py for detailed examples.

Noticing that, to avoid encoding text to idx repeatedly, we could pre-calculate the mapped idx for the desired model, and load from file to save time.

Extensions

class BiLSTMPOSTagger(nn.Module):
    def __init__(self, 
                input_dim, 
                embedding_dim, 
                hidden_dim, 
                output_dim, 
                n_layers, 
                bidirectional, 
                dropout, 
                pad_idx):
        
        super().__init__()
        
        
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx = pad_idx)
        
        self.lstm = nn.LSTM(embedding_dim, 
                            hidden_dim, 
                            num_layers = n_layers, 
                            bidirectional = bidirectional,
                            dropout = dropout if n_layers > 1 else 0)
        
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)

        # args.input_dim = hidden_dim * 2 if bidirectional else hidden_dim
        # args.output_dim = hidden_dim * 2 if bidirectional else output_dim
        # args.n_hidden = 0
        # self.fc = MLP(args)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text):

        #text = [sent len, batch size]
        
        #pass text through embedding layer
        embedded = self.embedding(text)
        
        #embedded = [sent len, batch size, emb dim]
        
        #pass embeddings into LSTM
        outputs, (hidden, cell) = self.lstm(embedded)
        
        #outputs holds the backward and forward hidden states in the final layer
        #hidden and cell are the backward and forward hidden and cell states at the final time-step
        
        #output = [sent len, batch size, hid dim * n directions]
        #hidden/cell = [n layers * n directions, batch size, hid dim]
        
        #we use our outputs to make a prediction of what the tag should be
        predictions = self.fc(self.dropout(outputs))
        
        #predictions = [sent len, batch size, output dim]
        
        return predictions