
Training

We use a transformer-based model. We start from a pre-trained BERT model and fine-tune it on our task. The underlying model was modified to accommodate the particularities of our task. Among other changes, we added some special tags to the vocabulary and two new types of token_type_id (we differentiate between the first tree, the second tree, the highlighted part of the sentence, and the unhighlighted part of the sentence). We provide weights for 5 models (2_128, 4_256, 4_512, 8_512, bert-base).
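As a rough illustration, the kind of modifications described above can be sketched as follows. The special-tag strings, the choice of checkpoint, and the total of four token types are assumptions for illustration; the actual values live in the training code.

```python
import torch
from transformers import BertModel, BertTokenizerFast

# Hypothetical special tags for the tree / sentence segments (illustrative names only).
SPECIAL_TAGS = ["[TREE1]", "[TREE2]", "[HL]", "[UNHL]"]

tokenizer = BertTokenizerFast.from_pretrained("google/bert_uncased_L-4_H-256_A-4")
model = BertModel.from_pretrained("google/bert_uncased_L-4_H-256_A-4")

# 1) Add the special tags to the vocabulary and grow the word embeddings to match.
tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TAGS})
model.resize_token_embeddings(len(tokenizer))

# 2) Grow the token_type (segment) embeddings from BERT's default 2 to 4, so the model
#    can distinguish first tree / second tree / highlighted span / unhighlighted span.
old = model.embeddings.token_type_embeddings
new = torch.nn.Embedding(4, old.embedding_dim)
new.weight.data[: old.num_embeddings] = old.weight.data  # keep the pretrained rows
model.embeddings.token_type_embeddings = new
model.config.type_vocab_size = 4
```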

There are 3 ways to train the underlying transformer model:

  • Binary classification
    • In this case we predict whether the next node of the current node is "correct" or not. Correctness is determined by the oracle: if the candidate next node is part of the sequence leading to the solution, we train to predict 1; otherwise we train to predict 0. Note that we only work with one solution, even though a problem can have multiple potential solutions (see the loss sketch after this list)
  • Margin loss
    • In this case we aim to predict a score for the "correct" next node that is higher, by at least a margin, than the scores of all other candidate nodes
  • Reinforcement Learning
    • In this case we train the model with a reinforcement-learning approach: we use a DQN and do not use the oracle
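
A minimal sketch of the first two objectives, assuming the model produces one score per (current node, candidate next node, sentence) input; the function names are illustrative, not taken from the training code.

```python
import torch
import torch.nn.functional as F

def binary_classification_loss(logit: torch.Tensor, is_correct: torch.Tensor) -> torch.Tensor:
    # logit: score for one candidate next node; is_correct: 1.0 if the oracle
    # says this candidate is on the path to the solution, else 0.0.
    return F.binary_cross_entropy_with_logits(logit, is_correct)

def margin_loss(correct_score: torch.Tensor,
                other_scores: torch.Tensor,
                margin: float = 1.0) -> torch.Tensor:
    # Encourage the correct next node to score at least `margin` higher
    # than every other candidate expansion of the current node.
    return torch.clamp(margin - (correct_score - other_scores), min=0.0).mean()
```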

When we predict, we feed the model a concatenation of [current_node, next_node, sentence], where current_node and next_node are obtained by linearizing the underlying trees. In the case of multiple sentences, we average the scores at the end.
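As a sketch of how such an input could be assembled (the `[SEP]` separator and the `model_score` callable are assumptions; the real code linearizes and scores the trees itself):

```python
from statistics import mean

def build_input(current_node: str, next_node: str, sentence: str) -> str:
    # current_node / next_node are the linearized trees; the concatenation
    # mirrors [current_node, next_node, sentence] from the text above.
    return f"{current_node} [SEP] {next_node} [SEP] {sentence}"

def score_candidate(model_score, current_node: str, next_node: str, sentences: list[str]) -> float:
    # With multiple sentences, score each one and average at the end.
    return mean(model_score(build_input(current_node, next_node, s)) for s in sentences)
```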

Training procedure

Inside the python folder, you can run:

python train.py --config-file configs/config1.json --model-name <model_name> --max-epochs 3 --load-from-arrow-dir <data_path> -nsvs 1000 --checkpoint-prefix <prefix> --train-batch-size 256 -agb 1 --val-batch-size 256 --save-best-in last1.ckpt

Note that:

1- `<model_name>` has to work with the huggingface transformers library. If the value of this parameter works with `BertModel.from_pretrained(model_name)`, everything will work. Example of a model_name value: `"google/bert_uncased_L-12_H-768_A-12"`
2- `<data_path>` has to work with the huggingface datasets library. It is used as `datasets.load_from_disk(<data_path>)`. The reason for this is to handle big datasets. Note that this implies that the data was already preprocessed and loaded into a datasets dataset before being saved (see the snippet after these notes)
3- `<prefix>` is used for logging purposes only
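
For example, the first two requirements can be checked in isolation; the data path below is a placeholder:

```python
from transformers import BertModel
import datasets

# 1) The model name must load with BertModel.from_pretrained.
model = BertModel.from_pretrained("google/bert_uncased_L-12_H-768_A-12")

# 2) The data path must load with datasets.load_from_disk, i.e. it must point to a
#    dataset that was preprocessed and then saved with Dataset.save_to_disk.
dataset = datasets.load_from_disk("<data_path>")  # placeholder path
print(dataset)
```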

Notes on what the data looks like. Format: <words_in_sentences_separated_by_space>\t<highlighted_start_(inclusive)>\t<highlighted_end_(noninclusive)>\t<current_pattern>\t<next_potential_pattern>\t<is_it_correct_or_not>

An example of two lines in the unrolled data:

He then uses that trocar to initiate the step system of placement , using the same incision and smaller initial trocar to place the larger one in a step-like fashion . 9 12 □ □ □ 1
He then uses that trocar to initiate the step system of placement , using the same incision and smaller initial trocar to place the larger one in a step-like fashion . 9 12 □ [□] 0
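
A minimal sketch of reading one such line, assuming the tab-separated format described above (field names are illustrative):

```python
def parse_unrolled_line(line: str) -> dict:
    # Fields follow the format above, separated by tabs.
    sentence, start, end, current_pattern, next_pattern, label = line.rstrip("\n").split("\t")
    return {
        "words": sentence.split(" "),
        "highlighted_start": int(start),   # inclusive
        "highlighted_end": int(end),       # non-inclusive
        "current_pattern": current_pattern,
        "next_pattern": next_pattern,
        "is_correct": int(label),
    }
```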

We use this type of "unrolled" data because of the way we view this task: as a binary classification one.