SGPT - GPT Sentence Embeddings
The SGPT paper fine-tunes GPT-style decoder-only models on pairwise sentence datasets so that they can produce effective sentence embeddings. Related work includes Sentence-BERT and the all-mpnet-base-v2 model, which are sentence embedding models based on encoder Transformers. The SGPT paper “SGPT: GPT Sentence Embeddings for Semantic Search” was published in February 2022 and leveraged open-source GPT-style models from EleutherAI: GPT-Neo and GPT-J.
The paper experimented with two architectures:
- Cross-encoders:
  - Encode the query and document at the same time. For example, BERT can be used as a cross-encoder by concatenating:
    [CLS]<query>[SEP]<document>[SEP]
    and passing the concatenated text through the Transformer together.
  - However, each query must be concatenated separately with each individual document, making inference expensive.
- Bi-encoders:
  - An example is Sentence-BERT, which produces separate query and document representations. The resultant document representations can be cached (see the sketch below).
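The contrast between the two architectures can be illustrated with the sentence-transformers library. This is a minimal sketch; the model names and toy query/documents below are illustrative and are not the models used in the SGPT paper.

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

query = "how do transformers encode sentences?"
docs = ["Transformers build contextual token representations.",
        "BM25 is a lexical retrieval baseline."]

# Cross-encoder: query and document are concatenated and scored together,
# so every (query, document) pair needs its own full forward pass.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pair_scores = cross_encoder.predict([(query, d) for d in docs])

# Bi-encoder: query and documents are embedded independently; document
# embeddings can be precomputed and cached, then compared by cosine similarity.
bi_encoder = SentenceTransformer("all-mpnet-base-v2")
doc_embs = bi_encoder.encode(docs, convert_to_tensor=True)    # cacheable
query_emb = bi_encoder.encode(query, convert_to_tensor=True)
cos_scores = util.cos_sim(query_emb, doc_embs)

print(pair_scores, cos_scores)
```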
The paper also defined the following notions:
- Asymmetric search: queries and documents are not interchangeable, e.g. finding answers given a question.
- Symmetric search: queries and documents are interchangeable, e.g. finding duplicate questions where both queries and documents are questions.
SGPT Cross-Encoder (SGPT-CE)
Given a query $q$ and a document corpus $D$, we would like to find the most likely document $d^*$. This can be expressed as: \(d^* = \text{arg max}_{d \in D}P(d|q) = \text{arg max}_{d \in D} \frac{P(q|d)P(d)}{P(q)} \approx \text{arg max}_{d \in D} P(q|d)\) $P(q|d)$ is modeled as $p(q_i, \ldots, q_n \mid p_1, \ldots, p_{i-1})$, where the document tokens are embedded as part of a prompt $P$:
- An example of a prompt is:
Documents are searched to find matches with the same content. The document {doc} is a good search result for {query}
- Note that the query length above is $n - i + 1$. To capture the entire query, document tokens are truncated from the left to fit the model’s maximum sequence length.
- The paper first uses BM25 to find the top-$k$ documents, then applies this cross-encoder method to re-rank these $k$ documents (a code sketch of the scoring follows this list).
- The paper conducts experiments with pretrained decoders of size 125M, 1.3B, 2.7B, and 6.1B parameters.
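As a rough sketch of this scoring, the document is placed into the prompt and the summed log probability of the query tokens is used as the re-ranking score. The model name below is just the smallest GPT-Neo checkpoint, and details such as tokenization and left-truncation are simplified relative to the paper.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M").eval()

def query_log_prob(query: str, doc: str) -> float:
    """Approximate log P(q|d): sum of log probabilities of the query tokens,
    conditioned on a prompt that contains the candidate document."""
    prompt = ("Documents are searched to find matches with the same content. "
              f"The document {doc} is a good search result for")
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    query_ids = tokenizer(" " + query, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, query_ids], dim=1)
    with torch.no_grad():
        log_probs = model(input_ids).logits.log_softmax(dim=-1)
    # The logit at position t predicts the token at position t + 1, so the
    # query tokens are predicted by positions n_prompt - 1 ... end - 1.
    n_prompt = prompt_ids.shape[1]
    token_logps = log_probs[0, n_prompt - 1:-1].gather(
        1, query_ids[0].unsqueeze(-1)).squeeze(-1)
    return token_logps.sum().item()

# Re-rank BM25 candidates by descending log P(q|d).
candidates = ["first candidate document text", "second candidate document text"]
reranked = sorted(candidates, key=lambda d: query_log_prob("example query", d),
                  reverse=True)
```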
Evaluation
To recap, the candidate document and target query are embedded within a prompt, and the log probability of the target query is then calculated. The candidate document that produces the maximum probability for the query is selected as the predicted best document.
The paper performed a search over 12 prompts using the MS MARCO dataset as provided in BEIR, and then used the best prompt for their SGPT-CE (SGPT Cross-Encoder) model. Experiment summary:
- BEIR is a zero-shot evaluation benchmark for information retrieval (IR) models.
- Scores are nDCG@10.
- SGPT-CE 6.1B is slightly worse (0.462) than the best-performing model from BEIR, BM25+CE (0.476). BM25+CE uses a pretrained BERT that is further fine-tuned on MS MARCO in a supervised fashion. However, SGPT-CE 6.1B has almost 15x more parameters than BM25+CE.
SGPT Bi-Encoder (SGPT-BE)
The all-mpnet-base-v2 sentence embedding model uses contrastive learning with in-batch negatives and a cross-entropy loss, and it represents a sentence as the mean pooling of the tokens’ hidden representations. The SGPT authors take a very similar approach. The only difference is using position-weighted mean pooling:
\(v = \sum_{i=1}^{S} w_i h_i\)
\(w_i = \frac{i}{\sum_{i=1}^{S}i}\)
- $S$ is the sequence length, $h_i$ is the $i$th hidden state, $v$ is the resultant query or document embedding.
- The reason for giving later tokens a higher weight is that, in an autoregressive model, later tokens attend to more tokens than earlier tokens in the sequence (see the sketch below).
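A minimal sketch of this pooling in PyTorch, assuming right-padded sequences and a last-layer hidden-state tensor from a Hugging Face causal LM (tensor names are illustrative):

```python
import torch

def weighted_mean_pooling(hidden_states: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
    """Position-weighted mean pooling: token i gets weight i / sum(1..S).

    hidden_states: [batch, seq_len, hidden]  last-layer hidden states
    attention_mask: [batch, seq_len]         1 for real tokens, 0 for padding
    """
    # Weights 1, 2, ..., S, zeroed out on padding positions.
    weights = torch.arange(1, hidden_states.shape[1] + 1,
                           device=hidden_states.device).float()
    weights = weights.unsqueeze(0) * attention_mask        # [batch, seq_len]
    weights = weights / weights.sum(dim=1, keepdim=True)   # normalize per sequence
    # Weighted sum over the sequence dimension -> sentence embedding v.
    return torch.einsum("bs,bsh->bh", weights, hidden_states)
```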
The paper then performed supervised contrastive learning with in-batch negatives. Given matching query-document pairs $\{(q_i, d_i)\}_{i=1}^{M}$, optimize the cost function: \(J_{CL}(\theta) = \frac{1}{M} \sum_{i=1}^{M} \log \frac{\exp(\tau \cdot \text{sim}(f_{\theta}(q_i), f_{\theta}(d_i)))}{\sum_{j=1}^{M} \exp(\tau \cdot \text{sim}(f_{\theta}(q_i), f_{\theta}(d_j)))}\)
- $f_{\theta}$ is the position-weighted mean pooling function applied to the model’s hidden states, $\text{sim}(\cdot)$ is cosine similarity, and $\tau$ is a temperature parameter set to 20 in the paper’s experiments (see the sketch below).
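A minimal sketch of this in-batch-negatives objective, assuming `query_emb` and `doc_emb` are the pooled embeddings of $M$ matching pairs; in practice the loss is typically implemented as a cross-entropy over the similarity matrix, with the diagonal as the positive targets:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(query_emb: torch.Tensor, doc_emb: torch.Tensor,
                     tau: float = 20.0) -> torch.Tensor:
    """In-batch-negatives contrastive loss.

    query_emb, doc_emb: [M, hidden] pooled embeddings of matching (q_i, d_i) pairs.
    For query i, document i is the positive and the other M - 1 documents
    in the batch serve as negatives.
    """
    q = F.normalize(query_emb, dim=-1)
    d = F.normalize(doc_emb, dim=-1)
    sim = tau * q @ d.T                                  # [M, M] scaled cosine similarities
    labels = torch.arange(sim.shape[0], device=sim.device)
    # Cross-entropy with the diagonal as targets = -log softmax of the positive pair.
    return F.cross_entropy(sim, labels)
```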
The paper fine-tuned only the bias parameters while freezing the rest of the model (a sketch of this setup follows the list), over the following pairwise datasets:
- Stanford Natural Language Inference (SNLI) corpus: 570k sentence-pairs manually labeled as entailment, contradiction, and neutral.
- Multi-Genre Natural Language Inference (MNLI) dataset: crowd-sourced collection of 433k sentence pairs annotated with textual entailment information.
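A minimal sketch of this bias-only fine-tuning setup with a Hugging Face model: freeze every parameter, then unfreeze only those whose name contains "bias". The model name is a placeholder for whichever GPT-style checkpoint is being fine-tuned.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("EleutherAI/gpt-neo-125M")

# Freeze all weights, then unfreeze only the bias parameters.
for name, param in model.named_parameters():
    param.requires_grad = "bias" in name

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable: {n_trainable} / {n_total} parameters")
```

Only a small fraction of the parameters remain trainable, which keeps fine-tuning cheap relative to updating the full model.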