Commit 797f9d4 by philipphager
Parent(s): f770d24

Update README.md

README.md CHANGED
@@ -16,14 +16,46 @@ metrics:
 A flax-based MonoBERT cross encoder trained on the Baidu-ULTR dataset with a **listwise softmax cross-entropy loss on clicks**. The loss is called "naive" as we use user clicks as a signal of relevance without any additional position bias correction. For more info, read our paper here.
 
 ## Usage
-```
+```Python
+import jax.numpy as jnp
+
 from src.model import ListwiseCrossEncoder
 
 model = ListwiseCrossEncoder.from_pretrained(
     "philipphager/baidu-ultr_uva-bert_naive-listwise",
 )
 
-
+# Mock batch from Baidu-ULTR with 4 documents, each with 32 tokens
+batch = {
+    # Query_id for each document
+    "query_id": jnp.array([1, 1, 1, 1]),
+    # Document position in SERP
+    "positions": jnp.array([1, 2, 3, 4]),
+    # Token ids for each query/document combination
+    "tokens": jnp.array([
+        [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 20206, 4012, 2860, 5996, 9526, 10966, 11858, 15035, 2677, 21446, 21401, 21401, 10092, 250, 8547, 7936, 2677, 1, 21874],
+        [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 16794, 4522, 2082, 2860, 16923, 3186, 15035, 2677, 21446, 21401, 21401, 10092, 21448, 19087, 480, 21449, 21401, 8747, 21436],
+        [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 20206, 10082, 9773, 6164, 8825, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 10092, 21455, 4516, 2049, 20167, 15035],
+        [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 2618, 8520, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 10092, 21455, 2618, 8520, 2860, 5996, 9526, 21446, 21401],
+    ]),
+    # Specify if a token id belongs to the query (0) or document (1)
+    "token_types": jnp.array([
+        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+    ]),
+    # Marks if a token should be attended to (True) or ignored, e.g., padding tokens (False):
+    "attention_mask": jnp.array([
+        [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
+        [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
+        [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
+        [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
+    ]),
+}
+
+outputs = model(batch)
+print(outputs)
 ```
 
 ## Test Results on Baidu-ULTR Expert Annotations
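
Note on the model description in this diff: the README says the model is trained with a listwise softmax cross-entropy loss on clicks, without position bias correction. A minimal sketch of such a loss for a single query is given below. It is an illustration only and is not taken from the repository: the function name `listwise_softmax_cross_entropy` and the example scores and clicks are hypothetical, and the actual training code may differ (for example in how padding or queries without clicks are handled).

```python
import jax
import jax.numpy as jnp


def listwise_softmax_cross_entropy(scores, clicks):
    """Hypothetical sketch: softmax cross-entropy over one ranked list.

    scores: (n_docs,) predicted relevance scores for one query
    clicks: (n_docs,) raw click labels, used directly as relevance
            ("naive": no position bias correction is applied)
    """
    # Normalize scores across all documents of the query.
    log_probs = jax.nn.log_softmax(scores, axis=-1)
    # Penalize low probability mass on clicked documents.
    return -jnp.sum(clicks * log_probs, axis=-1)


# Made-up scores for 4 documents of one query, with a click on the first.
scores = jnp.array([2.0, 1.0, 0.5, -1.0])
clicks = jnp.array([1.0, 0.0, 0.0, 0.0])
loss = listwise_softmax_cross_entropy(scores, clicks)

# With identical scores, each document gets probability 1/4,
# so a single click yields a loss of log(4).
uniform_loss = listwise_softmax_cross_entropy(jnp.zeros(4), clicks)
```

Because the softmax normalizes over the whole list, raising the score of a clicked document lowers the loss only relative to the other documents shown for the same query, which is what makes the objective listwise rather than pointwise.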