Transformers
Safetensors
bert
Inference Endpoints
philipphager commited on
Commit
797f9d4
1 Parent(s): f770d24

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +34 -2
README.md CHANGED
@@ -16,14 +16,46 @@ metrics:
16
  A flax-based MonoBERT cross encoder trained on the Baidu-ULTR dataset with a **listwise softmax cross-entropy loss on clicks**. The loss is called "naive" as we use user clicks as a signal of relevance without any additional position bias correction. For more info, read our paper here.
17
 
18
  ## Usage
19
- ```
 
 
20
  from src.model import ListwiseCrossEncoder
21
 
22
  model = ListwiseCrossEncoder.from_pretrained(
23
  "philipphager/baidu-ultr_uva-bert_naive-listwise",
24
  )
25
 
26
- model(batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ```
28
 
29
  ## Test Results on Baidu-ULTR Expert Annotations
 
16
  A flax-based MonoBERT cross encoder trained on the Baidu-ULTR dataset with a **listwise softmax cross-entropy loss on clicks**. The loss is called "naive" as we use user clicks as a signal of relevance without any additional position bias correction. For more info, read our paper here.
17
 
18
  ## Usage
19
+ ```Python
20
+ import jax.numpy as jnp
21
+
22
  from src.model import ListwiseCrossEncoder
23
 
24
  model = ListwiseCrossEncoder.from_pretrained(
25
  "philipphager/baidu-ultr_uva-bert_naive-listwise",
26
  )
27
 
28
+ # Mock batch from Baidu-ULTR with 4 documents, each with 32 tokens
29
+ batch = {
30
+ # Query_id for each document
31
+ "query_id": jnp.array([1, 1, 1, 1]),
32
+ # Document position in SERP
33
+ "positions": jnp.array([1, 2, 3, 4]),
34
+ # Token ids for each query/document combination
35
+ "tokens": jnp.array([
36
+ [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 20206, 4012, 2860, 5996, 9526, 10966, 11858, 15035, 2677, 21446, 21401, 21401, 10092, 250, 8547, 7936, 2677, 1, 21874],
37
+ [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 16794, 4522, 2082, 2860, 16923, 3186, 15035, 2677, 21446, 21401, 21401, 10092, 21448, 19087, 480, 21449, 21401, 8747, 21436],
38
+ [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 20206, 10082, 9773, 6164, 8825, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 10092, 21455, 4516, 2049, 20167, 15035],
39
+ [2, 21448, 21874, 21436, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 1, 2618, 8520, 2860, 5996, 9526, 15035, 2677, 21446, 21401, 21401, 10092, 21455, 2618, 8520, 2860, 5996, 9526, 21446, 21401],
40
+ ]),
41
+ # Specify if a token id belongs to the query (0) or document (1)
42
+ "token_types": jnp.array([
43
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
44
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
45
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
46
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
47
+ ]),
48
+ # Marks if a token should be attended to (True) or ignored, e.g., padding tokens (False):
49
+ "attention_mask": jnp.array([
50
+ [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
51
+ [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
52
+ [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
53
+ [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True],
54
+ ]),
55
+ }
56
+
57
+ outputs = model(batch)
58
+ print(outputs)
59
  ```
60
 
61
  ## Test Results on Baidu-ULTR Expert Annotations