ELECTRA Encoder Language Model

The ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately) model was described in “ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators” published in March 2020.

It proposed a new encoder pre-training task called “replaced token detection” that is more sample-efficient than BERT-style masked language modeling (MLM). For every input token, the model makes a binary prediction: is this token the original, or a replacement generated by a co-trained generative model?

BERT can be viewed as a denoising autoencoder: it selects a small subset of tokens from the unlabeled input, masks out their identities, and trains the model to recover the original input.

  • While this is more effective than conventional language model pretraining (i.e. decoder-only Transformers) because it learns bidirectional representations, BERT is not sample-efficient: it only learns from the 15% of tokens that are masked in each example.
  • There is also a pretraining vs inference mismatch: the model sees artificial [MASK] tokens during pretraining, but not when fine-tuning on downstream tasks.
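To make the sample-efficiency gap concrete, the toy sketch below labels positions the way MLM does: only the masked ~15% of positions carry a training signal, whereas replaced token detection yields a real/fake label at every position. The helper name is hypothetical, and real BERT additionally uses an 80/10/10 replacement scheme that is omitted here.

```python
import random

def mlm_labels(tokens, mask_rate=0.15, mask_token="[MASK]"):
    """Illustrative BERT-style masking: only ~15% of positions
    receive a loss; the rest contribute no training signal."""
    inputs, targets = [], []
    for tok in tokens:
        if random.random() < mask_rate:
            inputs.append(mask_token)  # model sees [MASK]
            targets.append(tok)        # loss computed here
        else:
            inputs.append(tok)
            targets.append(None)       # no loss at this position
    return inputs, targets

inputs, targets = mlm_labels(["the", "chef", "cooked", "the", "meal"])
# Replaced token detection, by contrast, supervises every position
# with a binary original-vs-replaced label.
```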

Training Details

  • Trains a generator G and a discriminator D. Each is a Transformer encoder that maps input $\textbf{x} = [x_1, \ldots, x_n]$ into hidden representations $h(\textbf{x}) = [h_1, \ldots, h_n]$.
  • Let $\mathbf{x}^{\text{masked}}$ denote the input sequence after masking. For each [MASK] position $t$, the generator outputs the probability of generating token $x_t$ via a softmax over the vocabulary, where $e$ denotes token embeddings: \(p_{G}(x_t|\mathbf{x}^{\text{masked}}) = \frac{\exp(e(x_t)^T h_G(\mathbf{x}^{\text{masked}})_t)}{\sum_{x'} \exp(e(x')^T h_G(\mathbf{x}^{\text{masked}})_t)}\)
  • The discriminator $D$ predicts, for a given position $t$, whether the token $x_t$ is the ground truth or was sampled from the generator. That is, given a sequence containing masked tokens, we create a corrupted example $\mathbf{x}^{\text{corrupted}}$ by replacing the masked-out tokens with generator samples, and train the discriminator to predict, for each token, whether it is the original ground-truth token: \(D(\mathbf{x}^{\text{corrupted}}, t) = \text{sigmoid}(w^T h_D(\mathbf{x}^{\text{corrupted}})_t)\)
    • If the generator happens to generate the correct token, that token is considered “real” instead of “fake”.
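Putting the two objectives together, here is a minimal PyTorch sketch of one training step's losses, assuming the generator and discriminator encoder outputs have already been computed. The function name, argument shapes, and the λ = 50 discriminator weight (the value reported in the paper) are illustrative, not the authors' code.

```python
import torch
import torch.nn.functional as F

def electra_losses(gen_logits, disc_logits, original_ids, masked_positions):
    """Hedged sketch of the ELECTRA objective (assumed shapes):

      gen_logits:       (n, vocab) generator logits per position
      disc_logits:      (n,)       discriminator logit w^T h_D(x)_t
      original_ids:     (n,)       ground-truth token ids
      masked_positions: (n,)       bool mask of [MASK] positions
    """
    # Generator: standard MLM cross-entropy, only at masked positions.
    gen_loss = F.cross_entropy(gen_logits[masked_positions],
                               original_ids[masked_positions])

    # Sample replacement tokens from the generator (no gradient
    # flows through the sampling step).
    with torch.no_grad():
        sampled = torch.distributions.Categorical(logits=gen_logits).sample()
    corrupted = torch.where(masked_positions, sampled, original_ids)

    # A token counts as "real" whenever it equals the original --
    # even when the generator happened to sample the correct token.
    is_fake = (corrupted != original_ids).float()

    # Discriminator: binary cross-entropy at every position.
    disc_loss = F.binary_cross_entropy_with_logits(disc_logits, is_fake)

    # The paper weights the discriminator loss by lambda = 50.
    return gen_loss + 50.0 * disc_loss
```

Note that the discriminator loss covers all $n$ positions, which is the source of ELECTRA's sample-efficiency gain over MLM.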

Evaluation

  • On GLUE and SQuAD, the authors show that ELECTRA-Base performs better than BERT-Base (both 110M parameters, 85.1 vs 82.2).
  • However, at the larger model size (335M parameters), ELECTRA-Large is only slightly better than BERT-Large: 89.5 vs 87.2 on GLUE dev, and 89.6 vs 87.5 on SQuAD.
Written on February 13, 2023