---
license: mit
datasets:
- philipphager/baidu-ultr-pretrain
- philipphager/baidu-ultr_uva-mlm-ctr
metrics:
- log-likelihood
- dcg@1
- dcg@3
- dcg@5
- dcg@10
- ndcg@10
- mrr@10
co2_eq_emissions:
  emissions: 2090
  source: "Calculated using the [ML CO2 impact calculator](https://mlco2.github.io/impact/#compute), training for 4 x 45 hours with a carbon efficiency of 0.029 kg/kWh. You can inspect the carbon efficiency of the French national grid provider here: https://www.rte-france.com/eco2mix/les-emissions-de-co2-par-kwh-produit-en-france"
  training_type: "Pre-training"
  geographical_location: "Grenoble, France"
  hardware_used: "4 NVIDIA H100-80GB GPUs"
---

# Two Tower MonoBERT trained on Baidu-ULTR
A Flax-based MonoBERT cross-encoder trained on the [Baidu-ULTR](https://arxiv.org/abs/2207.03051) dataset with an **additive two-tower architecture**, as suggested by [Yan et al.](https://research.google/pubs/revisiting-two-tower-models-for-unbiased-learning-to-rank/) Similar to a position-based click model (PBM), a two-tower model jointly learns item relevance (with a BERT model) and position bias (in our case, using a single embedding per rank). For more info, [read our paper](https://arxiv.org/abs/2404.02543) and [find the code for this model here](https://github.com/philipphager/baidu-bert-model).
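The additive combination of the two towers can be sketched in a few lines of JAX. This is a minimal illustration under our own assumptions, not the code from the repository: `two_tower_click_logits`, `click_log_likelihood`, and the `position_bias` array (one scalar logit per rank) are hypothetical names, and we assume the relevance and bias scores are added in logit space before a sigmoid turns them into click probabilities.

```python
import jax
import jax.numpy as jnp


def two_tower_click_logits(relevance_logits, positions, position_bias):
    """Additive two-tower click logits (illustrative sketch).

    relevance_logits: float array [n_docs], output of the relevance tower (e.g., BERT).
    positions: int array [n_docs], 1-based ranks on the SERP.
    position_bias: float array [max_rank], one learned scalar per rank (hypothetical).
    """
    # Look up one bias logit per document; ranks are 1-based, indices 0-based.
    bias_logits = position_bias[positions - 1]
    # Additive model: click logit = relevance logit + position-bias logit.
    # For ranking, one would typically use relevance_logits alone.
    return relevance_logits + bias_logits


def click_log_likelihood(click_logits, clicks):
    """Mean binary cross-entropy of observed clicks (lower is better)."""
    log_p = jax.nn.log_sigmoid(click_logits)       # log P(click)
    log_not_p = jax.nn.log_sigmoid(-click_logits)  # log P(no click)
    return -jnp.mean(clicks * log_p + (1 - clicks) * log_not_p)


relevance = jnp.array([1.2, 0.3, -0.5, 0.1])
positions = jnp.array([1, 2, 3, 4])
position_bias = jnp.array([0.5, 0.0, -0.4, -0.8])  # hypothetical learned values
print(two_tower_click_logits(relevance, positions, position_bias))
```

Keeping the bias tower to a single scalar per rank means it can only explain how click-through varies with position, which pushes everything document-specific into the relevance tower.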

## Test Results on Baidu-ULTR
Ranking performance is measured in DCG, nDCG, and MRR on expert annotations (6,985 queries). Click prediction performance is measured in log-likelihood on one test partition of user clicks (≈297k queries).
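For reference, the ranking metrics for a single query can be sketched in JAX as follows. This sketch assumes the common exponential gain `2^label - 1` with a `log2(rank + 1)` discount and, for MRR, treats labels greater than or equal to a `threshold` as relevant. These conventions, the function names, and the default threshold are our assumptions and may differ from the evaluation script in the repository.

```python
import jax.numpy as jnp


def dcg_at_k(labels, scores, k):
    """DCG@k with gain 2^label - 1 and log2(rank + 1) discount (assumed convention)."""
    order = jnp.argsort(-scores)           # rank documents by descending score
    ranked = labels[order][:k]             # labels of the top-k documents
    ranks = jnp.arange(1, ranked.shape[0] + 1)
    gains = 2.0 ** ranked - 1.0
    discounts = jnp.log2(ranks + 1.0)
    return jnp.sum(gains / discounts)


def ndcg_at_k(labels, scores, k):
    """DCG@k divided by the DCG@k of the ideal (label-sorted) ranking."""
    ideal = dcg_at_k(labels, labels.astype(jnp.float32), k)
    # Queries without any relevant document get nDCG = 0.
    return jnp.where(ideal > 0, dcg_at_k(labels, scores, k) / ideal, 0.0)


def mrr_at_k(labels, scores, k, threshold=1):
    """Reciprocal rank of the first document with label >= threshold (hypothetical cutoff)."""
    order = jnp.argsort(-scores)
    ranked = labels[order][:k]
    ranks = jnp.arange(1, ranked.shape[0] + 1)
    reciprocal = jnp.where(ranked >= threshold, 1.0 / ranks, 0.0)
    # The largest reciprocal rank belongs to the first relevant document (0 if none).
    return jnp.max(reciprocal)


labels = jnp.array([0, 3, 1, 2])        # expert annotations for one query
scores = jnp.array([0.9, 0.8, 0.1, 0.5])  # model scores for the same documents
print(dcg_at_k(labels, scores, 3), ndcg_at_k(labels, scores, 3), mrr_at_k(labels, scores, 3))
```

Dataset-level numbers such as those in the table below would then be averages of these per-query values over all annotated queries.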

| Model                                                                                          | Log-likelihood | DCG@1 | DCG@3 | DCG@5 | DCG@10 | nDCG@10 | MRR@10 |
|------------------------------------------------------------------------------------------------|----------------|-------|-------|-------|--------|---------|--------|
| [Pointwise Naive](https://huggingface.co/philipphager/baidu-ultr_uva-bert_naive-pointwise)     | 0.227          | 1.641 | 3.462 | 4.752 | 7.251  | 0.357   | 0.609  |
| [Pointwise Two-Tower](https://huggingface.co/philipphager/baidu-ultr_uva-bert_twotower)        | 0.218          | 1.629 | 3.471 | 4.822 | 7.456  | 0.367   | 0.607  |
| [Pointwise IPS](https://huggingface.co/philipphager/baidu-ultr_uva-bert_ips-pointwise)         | 0.222          | 1.295 | 2.811 | 3.977 | 6.296  | 0.307   | 0.534  |
| [Listwise Naive](https://huggingface.co/philipphager/baidu-ultr_uva-bert_naive-listwise)       | -              | 1.947 | 4.108 | 5.614 | 8.478  | 0.405   | 0.639  |
| [Listwise IPS](https://huggingface.co/philipphager/baidu-ultr_uva-bert_ips-listwise)           | -              | 1.671 | 3.530 | 4.873 | 7.450  | 0.361   | 0.603  |
| [Listwise DLA](https://huggingface.co/philipphager/baidu-ultr_uva-bert_dla)                    | -              | 1.796 | 3.730 | 5.125 | 7.802  | 0.377   | 0.615  |


## Usage
Here is an example of downloading the model and calling it for inference on a mock batch of input data. For more details on how to use the model on the Baidu-ULTR dataset, take a look at our [training](https://github.com/philipphager/baidu-bert-model/blob/main/main.py) and [evaluation scripts](https://github.com/philipphager/baidu-bert-model/blob/main/eval.py) in our code repository.

```Python
import jax.numpy as jnp

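# PBMCrossEncoder is defined in the code repository: https://github.com/philipphager/baidu-bert-model
from src.model import PBMCrossEncoder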

model = PBMCrossEncoder.from_pretrained(
    "philipphager/baidu-ultr_uva-bert_twotower",
)

# Mock batch following Baidu-ULTR with 4 documents, each with 8 tokens
batch = {
    # Query_id for each document
    "query_id": jnp.array([1, 1, 1, 1]),
    # Document position in SERP
    "positions": jnp.array([1, 2, 3, 4]),
    # Token ids for: [CLS] Query [SEP] Document
    "tokens": jnp.array([
        [2, 21448, 21874, 21436, 1, 20206, 4012, 2860],
        [2, 21448, 21874, 21436, 1, 16794, 4522, 2082],
        [2, 21448, 21874, 21436, 1, 20206, 10082, 9773],
        [2, 21448, 21874, 21436, 1, 2618, 8520, 2860],
    ]),
    # Specify if a token id belongs to the query (0) or document (1)
    "token_types": jnp.array([
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
    ]),
    # Marks if a token should be attended to (True) or ignored, e.g., padding tokens (False):
    "attention_mask": jnp.array([
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
    ]),
}

outputs = model(batch, train=False)
print(outputs)
```
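To score your own query-document pairs instead of the hard-coded mock batch, a small helper can assemble the same dictionary. This is a hypothetical sketch: `build_batch` and the constants are our own names, the [CLS] (2) and [SEP] (1) ids are read off the mock batch above, the padding id 0 and the padding token type are assumptions, and the token-type convention (the [SEP] token counted with the document, as in the mock batch) mirrors the example rather than a documented spec. Please check the repository's data pipeline before relying on it.

```python
import jax.numpy as jnp

CLS_ID = 2  # [CLS] id as used in the mock batch
SEP_ID = 1  # [SEP] id as used in the mock batch
PAD_ID = 0  # hypothetical padding id; verify against the Baidu-ULTR vocabulary


def build_batch(query_id, query_tokens, documents, max_tokens):
    """Build a batch for one query; `documents` are token lists in SERP order."""
    tokens, token_types, attention_mask = [], [], []
    for doc_tokens in documents:
        # [CLS] Query [SEP] Document
        seq = [CLS_ID] + list(query_tokens) + [SEP_ID] + list(doc_tokens)
        # As in the mock batch: [CLS] + query -> 0, [SEP] + document -> 1
        types = [0] * (len(query_tokens) + 1) + [1] * (len(doc_tokens) + 1)
        # Truncate long sequences, then right-pad short ones
        seq, types = seq[:max_tokens], types[:max_tokens]
        pad = max_tokens - len(seq)
        tokens.append(seq + [PAD_ID] * pad)
        token_types.append(types + [0] * pad)  # padding type is an assumption
        attention_mask.append([True] * len(seq) + [False] * pad)
    n = len(documents)
    return {
        "query_id": jnp.full((n,), query_id),
        "positions": jnp.arange(1, n + 1),  # 1-based SERP positions
        "tokens": jnp.array(tokens),
        "token_types": jnp.array(token_types),
        "attention_mask": jnp.array(attention_mask),
    }


batch = build_batch(1, [21448, 21874, 21436], [[20206, 4012, 2860], [16794, 4522, 2082]], 8)
print(batch["tokens"])
```

The resulting dictionary has the same keys and shapes as the mock batch, so it can be passed to the model in the same way, e.g., `model(batch, train=False)`.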

## Reference
```
@inproceedings{Hager2024BaiduULTR,
  author = {Philipp Hager and Romain Deffayet and Jean-Michel Renders and Onno Zoeter and Maarten de Rijke},
  title = {Unbiased Learning to Rank Meets Reality: Lessons from Baidu’s Large-Scale Search Dataset},
  booktitle = {Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR '24)},
  organization = {ACM},
  year = {2024},
}
```