RA-DIT - Retrieval-Augmented Dual Instruction Tuning
Recent retrieval-augmented generation (RAG) models tend to train the retriever and the generator jointly. In contrast, the paper “RA-DIT: Retrieval-Augmented Dual Instruction Tuning” (Meta, 10/2/2023) proposes a lightweight alternative consisting of two separate fine-tuning steps: (1) fine-tune a pretrained LLM to make better use of retrieved information, and (2) fine-tune the retriever to return results preferred by the LLM. In other words, RA-DIT fine-tunes the LLM and the retriever separately.
Fine-Tuning Datasets
The authors used the datasets shown in the following figure (from the RA-DIT paper), spanning five categories (dialogue, open-domain QA, reading comprehension, summarization, and chain-of-thought reasoning), to fine-tune the pre-trained language model. For the retriever, they used the QA datasets in this collection, plus FreebaseQA and MS MARCO.
- $D_{L}$ and $D_{R}$ denote the datasets used to fine-tune the language model and the retriever, respectively.
LLM Fine-Tuning: Retrieval Augmented Language Model Fine-Tuning
Each example $(x_i, y_i) \in D_{L}$ consists of an instruction $x_i$ and output $y_i$.
- Based on the $x_i$, we first retrieve the top-$k$ relevant text chunks $C_{i} \subset C$.
- For each retrieved chunk $c_{i,j} \in C_{i}$, $c_{i,j}$ is prepended to $x_i$, yielding the augmented example $(c_{i,j} \circ x_{i}, y_{i})$, for $j = 1 \ldots k$.
- The loss is calculated only over the output tokens $y_{i}$. The LLM is fine-tuned using the next-token prediction objective: \(L(D_{L}) = - \sum_{i} \sum_{j} \log P_{LM}(y_{i} \mid c_{i,j} \circ x_{i})\)
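Below is a minimal PyTorch/Transformers sketch of this fine-tuning step. The model name and the way the chunk is joined to the instruction are placeholders (not the paper's exact prompt template); the key point is that prompt tokens are masked with `-100` so the loss is computed only over $y_i$, with one loss term per retrieved chunk.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; RA-DIT fine-tunes LLaMA-family models.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

def build_example(chunk: str, instruction: str, output: str):
    """Prepend a retrieved chunk to the instruction and mask the prompt
    tokens so the loss only covers the output tokens."""
    prompt_ids = tokenizer(chunk + "\n" + instruction, return_tensors="pt").input_ids
    output_ids = tokenizer(output, add_special_tokens=False, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, output_ids], dim=1)
    # -100 is the ignore index used by the cross-entropy loss in transformers.
    labels = torch.cat([torch.full_like(prompt_ids, -100), output_ids], dim=1)
    return input_ids, labels

def radit_lm_loss(instruction: str, output: str, retrieved_chunks: list[str]):
    """Sum of next-token-prediction losses, one term per retrieved chunk c_{i,j}."""
    total = 0.0
    for chunk in retrieved_chunks:  # top-k chunks retrieved for this instruction
        input_ids, labels = build_example(chunk, instruction, output)
        total = total + model(input_ids=input_ids, labels=labels).loss
    return total
```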
Retriever Fine-Tuning
The authors used a dual-encoder retriever: a document encoder maps each chunk $c$ to an embedding $E_{d}(c)$, and a query encoder maps a query $q$ to an embedding $E_{q}(q)$. The top-$k$ most relevant chunks for $q$ are retrieved according to the dot-product score $s(q, c) = E_{q}(q) \cdot E_{d}(c)$.
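A minimal sketch of this retrieval step, assuming the chunk embeddings $E_d(c)$ have already been precomputed and the encoder implementations are abstracted away as tensors:

```python
import torch

def top_k_chunks(query_emb: torch.Tensor, chunk_embs: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Score every chunk by its dot product with the query embedding and
    return the indices of the k highest-scoring chunks.

    query_emb : (d,)   output of the query encoder E_q(q)
    chunk_embs: (N, d) precomputed document-encoder embeddings E_d(c)
    """
    scores = chunk_embs @ query_emb   # s(q, c) = E_q(q) . E_d(c) for every chunk
    return torch.topk(scores, k).indices
```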
Here, the authors performed LM-supervised retrieval (LSR) fine-tuning, which leverages the LLM to provide supervision signals. For a training sample $(x, y)$ in the retriever fine-tuning dataset: \(p_{LSR}(c|x, y) \approx \frac{\text{exp}(p_{LM}(y | c \circ x) / \tau)}{\sum_{c' \in C'} \text{exp}(p_{LM}(y | c' \circ x)/ \tau)}\)
- Where $\tau$ is a temperature parameter, and $C' \subset C$ denotes the top-$k$ retrieved chunks for $x$. The retriever's own score distribution over $C'$ is trained to match $p_{LSR}$, so that the retriever learns to assign higher scores to chunks that increase the LLM's likelihood of generating the answer $y$.
- In practice, the authors only update the query encoder portion of the retriever.
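A minimal sketch of how $p_{LSR}$ could be turned into a training signal for the retriever, assuming the LM log-likelihoods for the top-$k$ chunks have already been computed. The KL-divergence form and the reuse of the same temperature for the retriever softmax are simplifying assumptions here, not the paper's exact formulation.

```python
import torch
import torch.nn.functional as F

def lsr_loss(lm_log_likelihoods: torch.Tensor,
             retriever_scores: torch.Tensor,
             tau: float = 0.1) -> torch.Tensor:
    """LM-supervised retrieval (LSR) loss for one training sample (x, y).

    lm_log_likelihoods: (k,) log p_LM(y | c' ∘ x) for the top-k retrieved chunks,
                        computed with the frozen language model.
    retriever_scores:   (k,) s(x, c') = E_q(x) · E_d(c') for the same chunks;
                        in practice only the query encoder E_q receives gradients.
    tau:                temperature (placeholder value).
    """
    # Target distribution p_LSR(c | x, y): softmax over LM likelihoods.
    # detach(): the language model only provides supervision, no gradients.
    p_lsr = F.softmax(lm_log_likelihoods / tau, dim=-1).detach()

    # Retriever's own distribution over the same candidate chunks.
    p_r = F.softmax(retriever_scores / tau, dim=-1)

    # KL divergence pulls p_r towards p_lsr, i.e. the retriever learns to
    # rank higher the chunks that raise the LLM's likelihood of producing y.
    return torch.sum(p_r * (torch.log(p_r) - torch.log(p_lsr)))
```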
Results
The paper reports state-of-the-art results on MMLU, Natural Questions, TriviaQA, and a subset of the tasks in the KILT benchmark.