Permutation Language Modeling

The Permutation Language Modeling (PLM) pretraining objective was introduced in the paper “XLNet: Generalized Autoregressive Pretraining for Language Understanding”, published in June 2019. It addresses the independence and input-noise issues of BERT's masked language modeling (MLM) pretraining objective, while retaining MLM's advantage of using bidirectional context. The core idea is to sample many permutations (factorization orders) of the same input sequence and train on them autoregressively. In expectation, each token is then predicted from every other token in the input context, without placing any $[\text{MASK}]$ symbols in the input.

Masked Language Modeling (MLM) of BERT

BERT tries to predict the masked tokens from everything that is not masked in the context:

  • $\mathbf{x}$: input text sequence $[x_1, \ldots, x_T]$
  • $\mathbf{\bar{x}}$: the masked tokens.
  • $\mathbf{\hat{x}}$: the corrupted version of $\mathbf{x}$ via masking.

The pretraining objective is to reconstruct the masked tokens $\mathbf{\bar{x}}$ from the corrupted sequence $\mathbf{\hat{x}}$:

\(\max_{\theta} \log p_{\theta}(\mathbf{\bar{x}} \mid \mathbf{\hat{x}}) \approx \sum_{t=1}^{T} m_t \log p_{\theta}(x_t \mid \mathbf{\hat{x}})\)

where $m_t = 1$ indicates that $x_t$ is masked. Although BERT has access to bidirectional context, it suffers from the following issues:

  • Independence assumption: BERT factorizes the joint probability $p(\mathbf{\bar{x}}|\mathbf{\hat{x}})$ under the assumption that the masked tokens $\mathbf{\bar{x}}$ are reconstructed independently of one another, given the corrupted input.
  • Input noise from masking: The input to BERT contains the artificial symbol $[\text{MASK}]$, which does not occur in downstream tasks. In contrast, autoregressive modeling does not corrupt its input and thus does not suffer from this issue.
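
To make the independence assumption concrete, here is a minimal sketch of the approximate MLM objective above. The probabilities are made-up numbers standing in for $p_{\theta}(x_t | \mathbf{\hat{x}})$, and the function name is only illustrative. Each masked position is scored separately against the same corrupted input, so no masked token ever conditions on another masked token.

```python
import math

def mlm_log_likelihood(token_probs, mask):
    """Sum of log p(x_t | x_hat) over masked positions only (m_t = 1)."""
    return sum(m * math.log(p) for p, m in zip(token_probs, mask))

# Hypothetical probabilities the model assigns to the true token at each position.
token_probs = [0.9, 0.5, 0.8, 0.25]
# m_t = 1 marks a masked position; here positions 2 and 4 are masked.
mask = [0, 1, 0, 1]

score = mlm_log_likelihood(token_probs, mask)
# Only the two masked positions contribute: log(0.5) + log(0.25).
```

Unmasked positions contribute nothing to the sum, which is why only a fraction of the tokens in each sequence produce a training signal under MLM.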

Permutation Language Modeling

For a sequence $\mathbf{x}$ of length $T$, there are $T!$ different permutations, or factorization orders. If we sample an order and train the model autoregressively along it, then in expectation over orders the model, when predicting $x_t$, will have conditioned on every other token $x_i$ with $i \ne t$, hence capturing bidirectional context.

That is, by sampling multiple permutations of the same input sequence, each token is predicted from many different parts of the context. In expectation, it will have seen all the other tokens in the sequence.
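
For reference, the XLNet paper writes this objective as an expectation over factorization orders. The symbols $\mathcal{Z}_T$ (the set of all permutations of $[1, \ldots, T]$), $z_t$ (the $t$-th element of a permutation $\mathbf{z}$), and $\mathbf{z}_{<t}$ (its first $t-1$ elements) follow the paper's notation and are not used elsewhere in this post:

$$
\max_{\theta} \; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T} \left[ \sum_{t=1}^{T} \log p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) \right]
$$

Unlike the MLM objective, this is an exact autoregressive factorization for each sampled order, so no independence assumption is needed.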

Note that the original sequence order is kept; only the factorization order is permuted:

  • The positional encodings corresponding to the original (unpermuted) sequence are used.
  • The permutation of the factorization order is achieved through a suitable attention mask in the Transformer.
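
As a rough sketch of how an attention mask can encode a factorization order, the hypothetical helper below builds a boolean matrix indexed by original positions, so the inputs and their positional encodings never move. It is a simplification for illustration and not XLNet's actual implementation, which uses a more involved two-stream attention scheme.

```python
def permutation_attention_mask(order):
    """Return mask where mask[i][j] is True if position i may attend to position j.

    `order` lists 0-based original positions in the sampled factorization
    order. A position may attend only to positions that come strictly
    earlier in that order.
    """
    T = len(order)
    rank = [0] * T
    for step, pos in enumerate(order):
        rank[pos] = step  # step at which this position is predicted
    return [[rank[j] < rank[i] for j in range(T)] for i in range(T)]

# Factorization order 2 -> 4 -> 3 -> 1 over four tokens, written 0-based.
mask = permutation_attention_mask([1, 3, 2, 0])

# Row for token 3 (index 2): it may attend to tokens 2 and 4 only.
# mask[2] == [False, True, False, True]
```

Because rows and columns stay in the original order, changing the permutation only changes which entries of the mask are True.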

As an example, consider an input sequence of four tokens, $[1, 2, 3, 4]$, and the following sampled permutations. When learning the representation of token “3” under each permutation:

  • $3 \to 2 \to 4 \to 1$: does not attend to any token.
  • $2 \to 4 \to 3 \to 1$: attends to tokens 2 and 4.
  • $1 \to 4 \to 2 \to 3$: attends to all other tokens 1, 4, 2.
  • $4 \to 3 \to 1 \to 2$: attends to token 4.
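
The four cases above can be checked with a few lines of Python. The helper name `context_of` is made up for this illustration; it simply returns the tokens that precede a given token in a factorization order.

```python
def context_of(token, order):
    """Tokens that precede `token` in the given factorization order."""
    return order[:order.index(token)]

orders = [
    [3, 2, 4, 1],
    [2, 4, 3, 1],
    [1, 4, 2, 3],
    [4, 3, 1, 2],
]

contexts = [context_of(3, order) for order in orders]
# contexts == [[], [2, 4], [1, 4, 2], [4]]

# Union of everything token 3 has attended to across the sampled orders.
seen = set().union(*contexts)
# seen == {1, 2, 4}
```

Across these four permutations, token 3 has attended to every other token at least once, which is the bidirectional effect the objective relies on.
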
Written on February 21, 2023