---
license: gemma
pipeline_tag: text-classification
tags:
  - transformers
  - sentence-transformers
language:
  - multilingual
---

Reranker

For more details, please refer to our GitHub repository: FlagEmbedding.

Unlike an embedding model, a reranker takes a question and a document as input and directly outputs a similarity score instead of an embedding. You get a relevance score by feeding a query and a passage to the reranker, and the score can be mapped to a float value in [0, 1] with the sigmoid function.
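As a minimal sketch of that mapping, the snippet below applies the logistic sigmoid to raw scores in plain Python. The helper name `to_unit_interval` is hypothetical and not part of FlagEmbedding; the input values are the raw bge-reranker-v2-m3 scores reported in the usage examples, and the outputs match (to rounding) the normalized values reported there.

```python
import math


def to_unit_interval(score: float) -> float:
    """Map a raw reranker score (a logit) into (0, 1) with the logistic sigmoid."""
    # Numerically stable form: never call exp() on a large positive number
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    z = math.exp(score)
    return z / (1.0 + z)


raw_scores = [-5.65234375, -8.1875, 5.26171875]
print([round(to_unit_interval(s), 6) for s in raw_scores])
```

Because the sigmoid is monotonic, normalizing never changes the ranking of passages; it only makes scores comparable to a fixed threshold.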

Here we introduce a lightweight reranker, bge-reranker-v2.5-gemma2-lightweight, a multilingual model trained on top of gemma2-9b. By integrating token compression and layerwise reduction, the model maintains outstanding performance while saving significant compute.

Our model primarily demonstrates the following capabilities:

  • Lightweight: The model can be made lightweight through token compression, layerwise reduction, or a combination of both.
  • Outstanding performance: The model has achieved new state-of-the-art (SOTA) performance on both BEIR and MIRACL.
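The following is a purely illustrative back-of-the-envelope sketch of why token compression saves compute. It assumes, without confirmation from this card, that one compression step merges every `compress_ratio` passage tokens into one while leaving the query and prompt tokens intact. The transformers example for the lightweight reranker passes `query_lengths` and `prompt_lengths` to the model, which is consistent with this, but the mechanism itself is not documented here. The function name `compressed_length` is hypothetical.

```python
import math


def compressed_length(query_len: int, passage_len: int, prompt_len: int,
                      compress_ratio: int) -> int:
    """Sequence length after one (assumed) compression step over the passage.

    Hypothetical model: query and prompt tokens are kept as-is, and every
    `compress_ratio` consecutive passage tokens collapse into one token.
    """
    if compress_ratio < 1:
        raise ValueError("compress_ratio must be >= 1")
    return query_len + math.ceil(passage_len / compress_ratio) + prompt_len


# A 16-token query, a 480-token passage and a 16-token instruction prompt
original = compressed_length(16, 480, 16, 1)
for ratio in (1, 2, 4, 8):
    n = compressed_length(16, 480, 16, ratio)
    print(ratio, n, f"{n / original:.0%} of the original length")
```

Under this assumption, every layer that runs after compression processes fewer tokens, so its cost drops roughly in proportion (more for the attention term, which grows quadratically with length). Stopping at an earlier output layer removes the remaining layers entirely.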

We will soon release a technical report on the lightweight reranker with more details.

Model List

| Model | Base model | Language | layerwise | compress ratio | compress layers | feature |
|:--|:--|:--|:--|:--|:--|:--|
| BAAI/bge-reranker-base | xlm-roberta-base | Chinese and English | - | - | - | Lightweight reranker model, easy to deploy, with fast inference. |
| BAAI/bge-reranker-large | xlm-roberta-large | Chinese and English | - | - | - | Lightweight reranker model, easy to deploy, with fast inference. |
| BAAI/bge-reranker-v2-m3 | bge-m3 | Multilingual | - | - | - | Lightweight reranker model, possesses strong multilingual capabilities, easy to deploy, with fast inference. |
| BAAI/bge-reranker-v2-gemma | gemma-2b | Multilingual | - | - | - | Suitable for multilingual contexts, performs well in both English proficiency and multilingual capabilities. |
| BAAI/bge-reranker-v2-minicpm-layerwise | MiniCPM-2B-dpo-bf16 | Multilingual | 8-40 | - | - | Suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers for output, facilitating accelerated inference. |
| BAAI/bge-reranker-v2.5-gemma2-lightweight | google/gemma-2-9b | Multilingual | 8-42 | 1, 2, 4, 8 | [8, 16, 24, 32, 40] | Suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers, compress ratio and compress layers for output, facilitating accelerated inference. |

You can select a model according to your scenario and available resources.
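Since each model in the list is loaded through a different FlagEmbedding class in the usage examples that follow, a small lookup table can help pick the right one. This is a hypothetical convenience helper, not part of FlagEmbedding; it only returns the class name as a string, so it runs without the library installed.

```python
# Model-to-class mapping taken from the usage examples in this card
RERANKER_CLASSES = {
    "BAAI/bge-reranker-base": "FlagReranker",
    "BAAI/bge-reranker-large": "FlagReranker",
    "BAAI/bge-reranker-v2-m3": "FlagReranker",
    "BAAI/bge-reranker-v2-gemma": "FlagLLMReranker",
    "BAAI/bge-reranker-v2-minicpm-layerwise": "LayerWiseFlagLLMReranker",
    "BAAI/bge-reranker-v2.5-gemma2-lightweight": "LightWeightFlagLLMReranker",
}


def reranker_class_for(model_name: str) -> str:
    """Return the FlagEmbedding class name used for `model_name` in this card."""
    try:
        return RERANKER_CLASSES[model_name]
    except KeyError:
        raise ValueError(f"No usage example for {model_name!r} in this card") from None


print(reranker_class_for("BAAI/bge-reranker-v2-m3"))
```

With FlagEmbedding installed, the returned name can be resolved to the class with `getattr(importlib.import_module("FlagEmbedding"), name)`.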

Usage

Using FlagEmbedding

git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .

For normal rerankers (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3):

Get relevance scores (higher scores indicate more relevance):

from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score) # -5.65234375

# You can map the scores into [0, 1] by setting "normalize=True", which applies the sigmoid function to the score
score = reranker.compute_score(['query', 'passage'], normalize=True)
print(score) # 0.003497010252573502

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores) # [-8.1875, 5.26171875]

# You can map the scores into [0, 1] by setting "normalize=True", which applies the sigmoid function to the score
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], normalize=True)
print(scores) # [0.00027803096387751553, 0.9948403768236574]
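`compute_score` returns one score per pair; to actually rerank a candidate list you sort the passages by those scores. The sketch below assumes only that the scoring function takes a list of `[query, passage]` pairs and returns floats, as in the examples above. `rerank` is a hypothetical helper, and a toy word-overlap scorer stands in for `reranker.compute_score` so the snippet runs without downloading a model.

```python
from typing import Callable, List, Optional, Sequence, Tuple, Union

ScoreFn = Callable[[List[List[str]]], Union[float, List[float]]]


def rerank(query: str, passages: Sequence[str], score_fn: ScoreFn,
           top_k: Optional[int] = None) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return passages, best first."""
    pairs = [[query, p] for p in passages]
    scores = score_fn(pairs)
    # A single pair may come back as a bare float rather than a list
    if isinstance(scores, (int, float)):
        scores = [scores]
    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]


def toy_score(pairs):
    """Stand-in for reranker.compute_score: counts shared lowercase words."""
    return [float(len(set(q.lower().split()) & set(p.lower().split())))
            for q, p in pairs]


passages = ["hi", "The giant panda is a bear species endemic to China.",
            "Pandas eat bamboo."]
for passage, score in rerank("what is panda?", passages, toy_score):
    print(f"{score:5.1f}  {passage}")
```

With FlagEmbedding installed, pass `lambda pairs: reranker.compute_score(pairs, normalize=True)` as `score_fn` to rank with the real model.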

For LLM-based reranker

from FlagEmbedding import FlagLLMReranker
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_bf16=True) # You can also set use_bf16=True to speed up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)

For LLM-based layerwise reranker

from FlagEmbedding import LayerWiseFlagLLMReranker
reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_bf16=True) # You can also set use_bf16=True to speed up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'], cutoff_layers=[28]) # Adjusting 'cutoff_layers' to pick which layers are used for computing the score.
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], cutoff_layers=[28])
print(scores)

For LLM-based lightweight reranker

from FlagEmbedding import LightWeightFlagLLMReranker
reranker = LightWeightFlagLLMReranker('BAAI/bge-reranker-v2.5-gemma2-lightweight', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'], cutoff_layers=[28], compress_ratio=2, compress_layer=[24, 40]) # Adjusting 'cutoff_layers' to pick which layers are used for computing the score.
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], cutoff_layers=[28], compress_ratio=2, compress_layer=[24, 40])
print(scores)
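The Model List table gives the option ranges for bge-reranker-v2.5-gemma2-lightweight: output layers 8-42, compression ratios 1, 2, 4 or 8, and compression layers drawn from [8, 16, 24, 32, 40]. A hypothetical pre-flight check like the one below can catch out-of-range arguments before a 9B model is loaded. The ranges are copied from that table; whether FlagEmbedding itself validates them is not stated here.

```python
from typing import Sequence

# Option ranges copied from the Model List table for the lightweight model
LAYER_RANGE = range(8, 43)  # layerwise output: 8-42 inclusive
COMPRESS_RATIOS = {1, 2, 4, 8}
COMPRESS_LAYERS = {8, 16, 24, 32, 40}


def check_lightweight_args(cutoff_layers: Sequence[int], compress_ratio: int,
                           compress_layer: Sequence[int]) -> None:
    """Raise ValueError if an argument falls outside the documented ranges."""
    bad = [layer for layer in cutoff_layers if layer not in LAYER_RANGE]
    if bad:
        raise ValueError(f"cutoff_layers {bad} outside 8-42")
    if compress_ratio not in COMPRESS_RATIOS:
        raise ValueError(f"compress_ratio {compress_ratio} not in {sorted(COMPRESS_RATIOS)}")
    bad = [layer for layer in compress_layer if layer not in COMPRESS_LAYERS]
    if bad:
        raise ValueError(f"compress_layer {bad} not in {sorted(COMPRESS_LAYERS)}")


# The arguments used in the example above pass the check
check_lightweight_args(cutoff_layers=[28], compress_ratio=2, compress_layer=[24, 40])
```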

Using Hugging Face Transformers

For normal rerankers (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3):

Get relevance scores (higher scores indicate more relevance):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    print(scores)
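The snippet above tokenizes and scores every pair in a single forward pass, which can run out of memory for long candidate lists. A common remedy is to score fixed-size batches and concatenate the results. `batched` and `score_in_batches` are hypothetical helpers; `score_batch` stands for any function mapping a list of pairs to a list of scores, for example a wrapper around the tokenizer and model calls above (shown commented out, since it needs torch and the downloaded model).

```python
from typing import Callable, Iterator, List, Sequence


def batched(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most `batch_size` items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def score_in_batches(pairs: Sequence[List[str]],
                     score_batch: Callable[[Sequence[List[str]]], List[float]],
                     batch_size: int = 32) -> List[float]:
    """Score pairs batch by batch, preserving the input order."""
    scores: List[float] = []
    for chunk in batched(pairs, batch_size):
        scores.extend(score_batch(chunk))
    return scores


# With the tokenizer and model from the example above:
# def score_batch(chunk):
#     with torch.no_grad():
#         inputs = tokenizer(list(chunk), padding=True, truncation=True,
#                            return_tensors='pt', max_length=512)
#         return model(**inputs, return_dict=True).logits.view(-1, ).float().tolist()

# Toy scorer so the sketch runs without a model: passage length in words
demo_pairs = [["q", "a b c"], ["q", "a"], ["q", "a b"]]
print(score_in_batches(demo_pairs, lambda chunk: [float(len(p.split())) for _, p in chunk],
                       batch_size=2))
```

Padding is computed per batch, so grouping passages of similar length into the same batch also reduces wasted computation.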

For LLM-based reranker

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
    return tokenizer.pad(
            inputs,
            padding=True,
            max_length=max_length + len(sep_inputs) + len(prompt_inputs),
            pad_to_multiple_of=8,
            return_tensors='pt',
    )

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-gemma')
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-gemma')
yes_loc = tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = get_inputs(pairs, tokenizer)
    scores = model(**inputs, return_dict=True).logits[:, -1, yes_loc].view(-1, ).float()
    print(scores)
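To make the layout produced by `get_inputs` concrete, the sketch below replays its truncation and concatenation steps in simplified form on plain lists of token ids, so no tokenizer is needed. The query (with its `A: ` prefix) is capped at 3/4 of `max_length` and a BOS token is prepended; the separator plus passage is then trimmed from the right so the pair fits in `max_length` (what `truncation='only_second'` does, assuming the default right-side truncation); the separator and instruction prompt are appended last. The function name `layout_ids` and the toy ids are hypothetical.

```python
from typing import List


def layout_ids(query_ids: List[int], passage_ids: List[int], sep_ids: List[int],
               prompt_ids: List[int], bos_id: int, max_length: int = 1024) -> List[int]:
    """Simplified replay of get_inputs() for one pair, on pre-tokenized ids.

    query_ids / passage_ids are assumed to already include the 'A: ' / 'B: '
    prefixes, as in get_inputs().
    """
    # 1. Query capped at 3/4 of max_length, then BOS prepended
    first = [bos_id] + query_ids[: max_length * 3 // 4]
    # 2. Passage capped at max_length, separator prepended
    second = sep_ids + passage_ids[:max_length]
    # 3. 'only_second' truncation: trim the second part from the right
    #    so that the pair fits into max_length tokens
    second = second[: max(max_length - len(first), 0)]
    # 4. Separator and instruction prompt always come last, never truncated
    return first + second + sep_ids + prompt_ids


BOS, SEP, PROMPT = 2, [108], [900, 901, 902]
ids = layout_ids(list(range(100, 120)), list(range(500, 600)), SEP, PROMPT, BOS,
                 max_length=16)
print(len(ids), ids)
```

Because the prompt is never truncated, every sequence is at most `max_length + len(sep_inputs) + len(prompt_inputs)` tokens long, which is the `max_length` that `get_inputs` passes to `tokenizer.pad`.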

For LLM-based layerwise reranker

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
    return tokenizer.pad(
            inputs,
            padding=True,
            max_length=max_length + len(sep_inputs) + len(prompt_inputs),
            pad_to_multiple_of=8,
            return_tensors='pt',
    )

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to('cuda')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = get_inputs(pairs, tokenizer).to(model.device)
    all_scores = model(**inputs, return_dict=True, cutoff_layers=[28])
    all_scores = [scores[:, -1].view(-1, ).float() for scores in all_scores[0]]
    print(all_scores)

For LLM-based lightweight reranker

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def last_logit_pool(logits: torch.Tensor,
                    attention_mask: torch.Tensor) -> torch.Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return logits[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = logits.shape[0]
        return torch.stack([logits[i, sequence_lengths[i]] for i in range(batch_size)], dim=0)

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    if prompt is None:
        prompt = "Predict whether passage B contains an answer to query A."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    query_lengths = []
    prompt_lengths = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
        query_lengths.append(len([tokenizer.bos_token_id] + query_inputs['input_ids'] + sep_inputs))
        prompt_lengths.append(len(sep_inputs + prompt_inputs))
        
    return tokenizer.pad(
            inputs,
            padding=True,
            max_length=max_length + len(sep_inputs) + len(prompt_inputs),
            pad_to_multiple_of=8,
            return_tensors='pt',
    ), query_lengths, prompt_lengths

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2.5-gemma2-lightweight', trust_remote_code=True)
tokenizer.padding_side = 'right'
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2.5-gemma2-lightweight', trust_remote_code=True)
model = model.to('cuda')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs, query_lengths, prompt_lengths = get_inputs(pairs, tokenizer)
    inputs = inputs.to(model.device)
    outputs = model(**inputs,
                    return_dict=True,
                    cutoff_layers=[28],
                    compress_ratio=2,
                    compress_layer=[24, 40],
                    query_lengths=query_lengths,
                    prompt_lengths=prompt_lengths)
    scores = []
    for i in range(len(outputs.logits)):
        logits = last_logit_pool(outputs.logits[i], outputs.attention_masks[i])
        scores.append(logits.cpu().float().tolist())
    print(scores)
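`last_logit_pool` is needed because this example switches the tokenizer to right padding, so the last real token of each sequence sits at a different position in each row. The plain-Python sketch below reproduces the same index arithmetic on toy attention masks, without torch; `last_token_positions` is a hypothetical name.

```python
from typing import List


def last_token_positions(attention_mask: List[List[int]]) -> List[int]:
    """Index of the last non-padding token in each row.

    Mirrors last_logit_pool(): if every row ends with a real token (left
    padding, or no padding at all), the answer is simply the last column;
    otherwise (right padding) it is the number of real tokens minus one.
    """
    width = len(attention_mask[0])
    if all(row[-1] == 1 for row in attention_mask):
        return [width - 1] * len(attention_mask)
    return [sum(row) - 1 for row in attention_mask]


right_padded = [[1, 1, 1, 0, 0],
                [1, 1, 1, 1, 1]]
print(last_token_positions(right_padded))  # [2, 4]
```

The loop at the end of the example applies this pooling to each entry of `outputs.logits` together with the matching entry of `outputs.attention_masks`, since compression can change the sequence length the model returns.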

Evaluation

  • BEIR:

| BEIR | bge-large-en-v1.5 | bge-reranker-v2-m3 | jina-reranker-v2-base-multilingual | bge-reranker-v2-gemma | bge-reranker-v2.5-gemma2-lightweight | bge-reranker-v2.5-gemma2-lightweight |
|:--|:--:|:--:|:--:|:--:|:--:|:--:|
| FLOPs saved | - | - | - | - | 60% | 0 |
| ArguAna | 63.54 | 37.7 | 52.23 | 78.68 | 86.04 | 86.16 |
| ClimateFEVER | 36.49 | 37.99 | 34.65 | 39.07 | 48.41 | 48.48 |
| CQA | 42.23 | 38.24 | 40.21 | 45.85 | 49.18 | 48.9 |
| DBPedia | 44.16 | 48.15 | 49.31 | 49.92 | 51.98 | 52.11 |
| FEVER | 87.17 | 90.15 | 92.44 | 90.15 | 94.71 | 94.69 |
| FiQA2018 | 44.97 | 49.32 | 45.88 | 49.32 | 60.48 | 60.95 |
| HotpotQA | 74.11 | 84.51 | 81.81 | 86.15 | 87.84 | 87.89 |
| MSMARCO | 42.48 | 47.79 | 47.83 | 48.07 | 47.23 | 47.26 |
| NFCorpus | 38.12 | 34.85 | 37.73 | 39.73 | 41.4 | 41.64 |
| NQ | 55.04 | 69.37 | 67.35 | 72.6 | 75.37 | 75.58 |
| QuoraRetrieval | 89.06 | 89.13 | 87.81 | 90.37 | 91.25 | 91.18 |
| SCIDOCS | 22.62 | 18.25 | 20.21 | 21.65 | 23.71 | 23.87 |
| SciFact | 74.64 | 73.08 | 76.93 | 77.22 | 80.5 | 80.38 |
| Touche2020 | 25.08 | 35.68 | 32.45 | 35.68 | 30.64 | 31.09 |
| TRECCOVID | 74.89 | 83.39 | 80.89 | 85.51 | 84.26 | 84.85 |
| Mean | 54.31 | 55.36 | 56.52 | 60.71 | 63.1 | 63.67 |


| BEIR | e5-mistral-7b-instruct | bge-reranker-v2-gemma | bge-reranker-v2.5-gemma2-lightweight | bge-reranker-v2.5-gemma2-lightweight |
|:--|:--:|:--:|:--:|:--:|
| FLOPs saved | - | - | 60% | 0 |
| ArguAna | 61.8 | 79.05 | 86.02 | 86.58 |
| ClimateFEVER | 38.37 | 37.66 | 47.27 | 47.13 |
| CQA | 42.97 | 46.16 | 49.06 | 49.53 |
| DBPedia | 48.84 | 50.77 | 52.45 | 52.87 |
| FEVER | 87.82 | 91.36 | 94.85 | 95.19 |
| FiQA2018 | 56.58 | 50.96 | 58.81 | 61.19 |
| HotpotQA | 75.72 | 86.99 | 88.49 | 88.82 |
| MSMARCO | 43.06 | 48.35 | 47.65 | 47.4 |
| NFCorpus | 38.58 | 39.25 | 42.28 | 42.17 |
| NQ | 63.56 | 73.44 | 75 | 76.28 |
| QuoraRetrieval | 89.59 | 90.44 | 91.09 | 91.18 |
| SCIDOCS | 16.3 | 20.77 | 22.2 | 22.69 |
| SciFact | 76.26 | 77.78 | 79.94 | 80.98 |
| Touche2020 | 26.24 | 35.79 | 28.69 | 31.17 |
| TRECCOVID | 87.07 | 88.13 | 86.61 | 87.36 |
| Mean | 56.85 | 61.13 | 63.36 | 64.04 |


  • MIRACL:

| MIRACL (dev, nDCG@10) | Average (18) | FLOPs saved | ar | bn | en | es | fa | fi | fr | hi | id | ja | ko | ru | sw | te | th | zh | de | yo |
|:--|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| bge-m3 (Dense) | 69.2 | - | 78.4 | 80.0 | 56.9 | 56.1 | 60.9 | 78.6 | 58.3 | 59.5 | 56.1 | 72.8 | 69.9 | 70.1 | 78.7 | 86.2 | 82.6 | 62.7 | 56.7 | 81.8 |
| jina-reranker-v2-base-multilingual | 69.6 | - | 73.4 | 81.9 | 58.9 | 58.6 | 60.5 | 77.2 | 56.1 | 62.7 | 59.6 | 72.7 | 74.0 | 67.1 | 78.1 | 85.8 | 81.2 | 63.0 | 58.2 | 84.2 |
| bge-reranker-v2-m3 | 74.4 | - | 81.7 | 84.6 | 63.5 | 64.4 | 65.7 | 82.4 | 63.7 | 68.5 | 62.7 | 80.0 | 73.8 | 76.9 | 82.3 | 89.4 | 85.3 | 65.2 | 62.7 | 87.4 |
| bge-reranker-v2-gemma | 75.0 | - | 82.3 | 85.0 | 66.6 | 65.3 | 65.5 | 82.6 | 65.4 | 69.4 | 61.2 | 79.7 | 75.1 | 78.3 | 81.8 | 89.6 | 86.1 | 66.8 | 64.0 | 85.9 |
| bge-reranker-v2.5-gemma2-lightweight | 77.1 | 60% | 82.5 | 87.8 | 68.6 | 67.6 | 67.5 | 82.8 | 68.5 | 71.4 | 63.8 | 82.8 | 75.9 | 79.8 | 84.8 | 90.8 | 88.1 | 69.9 | 65.8 | 89.6 |
| bge-reranker-v2.5-gemma2-lightweight | 77.3 | 0 | 82.8 | 87.6 | 69.3 | 67.8 | 67.4 | 83.3 | 68.5 | 71.3 | 63.8 | 83.6 | 75.7 | 80.1 | 85.1 | 90.8 | 88.7 | 69.9 | 65.6 | 89.8 |

Citation

If you find this repository useful, please consider giving it a star and citing our work:

@misc{li2023making,
      title={Making Large Language Models A Better Foundation For Dense Retrieval}, 
      author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
      year={2023},
      eprint={2312.15503},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{chen2024bge,
      title={BGE M3-Embedding: Multi-Lingual, Multi-Functionality, Multi-Granularity Text Embeddings Through Self-Knowledge Distillation}, 
      author={Jianlv Chen and Shitao Xiao and Peitian Zhang and Kun Luo and Defu Lian and Zheng Liu},
      year={2024},
      eprint={2402.03216},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}