REALM - Augment Language Models with a Knowledge Retriever

The Retrieval-Augmented Language Model (REALM) is described in the paper “REALM: Retrieval-Augmented Language Model Pre-Training”, published by Google in February 2020. REALM augments a language model with a knowledge retriever, so that during pretraining, fine-tuning, and inference the model can retrieve and attend over text documents from an external corpus. This addresses two limitations of a purely parametric model: (i) without an external knowledge source, the parameters of the language model are the sole store of all learned knowledge, so ever larger models are needed to hold more knowledge; (ii) without an external knowledge source, the trained language model is inherently static, since its knowledge cannot be updated without retraining.

Model Architecture

The premise of REALM is the following:

  • Given an input $x$, learn a distribution $p(y|x)$ over possible outputs $y$. For pretraining, $x$ is a sentence with masked out tokens and the model must predict the identity $y$ of those tokens. For fine-tuning on a QA task, $x$ is a question, and $y$ is the answer.
  • To model $p(y|x)$, REALM decomposes this into 2 steps:
    • Given $x$, first retrieve helpful documents $z$. We model this as $p(z|x)$.
    • Then we generate $y$ using $p(y|z,x)$.
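
In the REALM paper, the two steps are combined by treating $z$ as a latent variable and marginalizing over it: $p(y|x) = \sum_{z} p(y|z,x)\, p(z|x)$. The following is a minimal sketch of that sum over a small candidate set. The function name `marginal_likelihood` and the toy probabilities are made up for illustration and do not come from the paper.

```python
def marginal_likelihood(p_z_given_x, p_y_given_zx):
    """Compute p(y|x) = sum_z p(z|x) * p(y|z,x) over candidate documents.

    p_z_given_x:  retrieval probabilities, one per candidate document z
    p_y_given_zx: probability of the output y given each document z and input x
    """
    if len(p_z_given_x) != len(p_y_given_zx):
        raise ValueError("need exactly one p(y|z,x) per candidate document")
    return sum(pz * py for pz, py in zip(p_z_given_x, p_y_given_zx))


# Toy numbers (hypothetical): three candidate documents.
p_z = [0.7, 0.2, 0.1]   # p(z|x) from the retriever
p_y = [0.9, 0.5, 0.0]   # p(y|z,x) from the encoder
p_y_given_x = marginal_likelihood(p_z, p_y)
```

A document contributes to $p(y|x)$ only to the extent that it is both likely to be retrieved and helpful for predicting $y$.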

Neural Knowledge Retriever

The retriever models $p(z|x)$ as follows:

\(p(z|x) = \frac{\exp f(x,z)}{\sum_{z'} \exp f(x,z')}\)

\(f(x,z) = \text{Embed}_{\text{input}}(x)^{T} \text{Embed}_{\text{doc}}(z)\)

  • Relevance score $f(x,z)$ is the inner product of the vector embeddings. The retrieval function $p(z|x)$ is then a softmax over the relevance scores.
  • \(\text{Embed}_{\text{input}}\) and \(\text{Embed}_{\text{doc}}\) are embedding functions that map input $x$ and document $z$ to $d$-dimensional vectors, defined as follows: \(\text{Embed}_{\text{input}}(x) = \mathbf{W}_{\text{input}} \text{ BERT}_{\text{CLS}}(\text{join}_{\text{BERT}}(x))\) \(\text{Embed}_{\text{doc}}(z) = \mathbf{W}_{\text{doc}} \text{ BERT}_{\text{CLS}}(\text{join}_{\text{BERT}} (z_{\text{title}}, z_{\text{body}}))\)
    • Where $\mathbf{W}_{\text{input}}$ and $\mathbf{W}_{\text{doc}}$ are linear projections that reduce the dimensionality of the BERT $[\text{CLS}]$ output vector, and \(\text{join}_{\text{BERT}}\) is the BERT-style concatenation: \(\text{join}_{\text{BERT}}(x) = [\text{CLS}] x [\text{SEP}]\) and \(\text{join}_{\text{BERT}}(x_1, x_2) = [\text{CLS}] x_1 [\text{SEP}] x_2 [\text{SEP}]\)
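
The retrieval distribution above can be sketched in a few lines of plain Python. This is only an illustration: the hand-written vectors stand in for the outputs of \(\text{Embed}_{\text{input}}\) and \(\text{Embed}_{\text{doc}}\), and the helper names `inner_product` and `retrieval_distribution` are hypothetical. The softmax is also taken over three candidates here rather than over a full corpus.

```python
import math


def inner_product(u, v):
    """Relevance score f(x, z): dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def retrieval_distribution(query_vec, doc_vecs):
    """Return p(z|x) as a softmax over the relevance scores f(x, z)."""
    scores = [inner_product(query_vec, d) for d in doc_vecs]
    top = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Toy embeddings (hypothetical), d = 2.
query = [1.0, 0.0]
docs = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
probs = retrieval_distribution(query, docs)
```

Subtracting the maximum score before exponentiating leaves the softmax unchanged and avoids overflow when scores are large.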

Knowledge Augmented Encoder

This models $p(y|z,x)$ as follows. We first join $x$ and $z$ into a single sequence that we feed into a Transformer (separate from the one used for the Neural knowledge retriever).

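For MLM pretraining, each masked token in $x$ is predicted from \(\text{BERT}_{\text{MASK}(j)}(\text{join}_{\text{BERT}}(x, z_{\text{body}}))\), the Transformer output vector at the position of the $j$-th masked token: i.e. the masked input $x$ is first joined with the retrieved document body, and the prediction is then read off at each masked position.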

For QA fine-tuning, the assumption is that the answer $y$ can be found as a contiguous sequence of tokens in some document $z$. Then $p(y|z,x)$ is defined as follows:

\(p(y|z,x) \propto \sum_{s \in S(z,y)} \exp (\text{MLP}([h_{\text{START}(s)} ; h_{\text{END}(s)}]))\)

\(h_{\text{START}(s)} = \text{BERT}_{\text{START}(s)}(\text{join}_{\text{BERT}}(x, z_{\text{body}}))\)

\(h_{\text{END}(s)} = \text{BERT}_{\text{END}(s)}(\text{join}_{\text{BERT}}(x, z_{\text{body}}))\)

  • Where \(\text{BERT}_{\text{START}(s)}\) and \(\text{BERT}_{\text{END}(s)}\) denote the Transformer output vectors at the start and end tokens of span $s$, MLP is a feed-forward network, and $S(z,y)$ is the set of spans matching $y$ in document $z$.
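
The span-based score can be sketched as follows, assuming the per-token output vectors are already available as plain lists. Everything named here is hypothetical: `find_spans` stands in for $S(z,y)$ using exact token matching, `toy_mlp` is a placeholder for the learned MLP, and the vectors in `hidden` are made up rather than produced by BERT.

```python
import math


def find_spans(doc_tokens, answer_tokens):
    """S(z, y): inclusive (start, end) indices where the answer occurs in the document."""
    n, m = len(doc_tokens), len(answer_tokens)
    if m == 0:
        return []
    return [(i, i + m - 1) for i in range(n - m + 1)
            if doc_tokens[i:i + m] == answer_tokens]


def unnormalized_answer_score(hidden, spans, span_scorer):
    """Sum over spans of exp(span_scorer([h_start ; h_end])).

    hidden[i] is the output vector (a list of floats) for token i;
    list concatenation with + plays the role of [h_start ; h_end].
    """
    return sum(math.exp(span_scorer(hidden[s] + hidden[e])) for s, e in spans)


doc = "the eiffel tower is in paris and paris is in france".split()
spans = find_spans(doc, ["paris"])

# Made-up 2-dimensional "output vectors" and a stand-in scorer.
hidden = [[float(i), 1.0] for i in range(len(doc))]
toy_mlp = lambda v: 0.1 * sum(v)
score = unnormalized_answer_score(hidden, spans, toy_mlp)
```

Because the score sums over every matching span, an answer that appears several times in a document collects evidence from each occurrence. The result is unnormalized, matching the $\propto$ in the definition.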
Written on January 26, 2023