lv12 commited on
Commit
ccad832
1 Parent(s): 7cbebc1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -1
README.md CHANGED
@@ -22,11 +22,14 @@ Fine tunes a cross encoder on the Amazon ESCI dataset.
22
 
23
  # Usage
24
 
 
 
25
  <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
26
 
27
 
28
  ```python
29
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
30
 
31
  model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
32
 
@@ -56,7 +59,38 @@ inputs = tokenizer(
56
  truncation=True,
57
  return_tensors="pt",
58
  )
59
- scores = model(**inputs).logits.cpu().detach().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  print(scores)
61
  ```
62
 
 
22
 
23
  # Usage
24
 
25
+ ## Transformers
26
+
27
  <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
28
 
29
 
30
  ```python
31
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
32
+ from torch import no_grad
33
 
34
  model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
35
 
 
59
  truncation=True,
60
  return_tensors="pt",
61
  )
62
+
63
+ model.eval()
64
+ with no_grad():
65
+ scores = model(**inputs).logits.cpu().detach().numpy()
66
+ print(scores)
67
+ ```
68
+
69
+ ### Sentence Transformers
70
+
71
+ ```python
72
+ from sentence_transformers import CrossEncoder
73
+
74
+ model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
75
+
76
+ queries = [
77
+ "adidas shoes",
78
+ "adidas sambas",
79
+ "girls sandals",
80
+ "backpacks",
81
+ "shoes",
82
+ "mustard blouse"
83
+ ]
84
+ documents = [
85
+ "Nike Air Max, with air cushion",
86
+ "Adidas Ultraboost, the best boost you can get",
87
+ "Women's sandals wide width 9",
88
+ "Girl's surf backpack",
89
+ "Fresh watermelon, all you can eat",
90
+ "Floral yellow dress with frills and lace"
91
+ ]
92
+ model = CrossEncoder(model_name, max_length=512)
93
+ scores = model.predict([(q, d) for q, d in zip(queries, documents)])
94
  print(scores)
95
  ```
96