---
library_name: transformers
base_model: "cross-encoder/ms-marco-MiniLM-L-12-v2"
model-index:
- name: esci-ms-marco-MiniLM-L-12-v2
  results:
  - task:
      type: ranking
    metrics:
    - type: mrr@10
      value: 91.74
    - type: ndcg@10
      value: 84.83
tags: ["cross-encoder", "search", "product-search"]
---

# Model Description
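A cross-encoder fine-tuned from [cross-encoder/ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) on the Amazon ESCI (Shopping Queries) dataset, which labels query-product pairs as Exact, Substitute, Complement, or Irrelevant. Given a search query and a product document, the model outputs a single relevance score for the pair.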

# Usage

## Transformers
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

# queries[i] is scored against documents[i]
queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse",
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace",
]

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

# Passing two lists encodes them as sentence pairs: (queries[i], documents[i]).
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    # The model has a single output, so there is one raw logit per pair.
    scores = model(**inputs).logits.squeeze(-1).cpu().numpy()

print(scores)
```
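
The snippet above scores `queries[i]` against `documents[i]` only, because the tokenizer pairs its two list arguments element-wise. To score every query against every document, one option is to build the full cross product of pairs and reshape the flat scores into a grid. This is a minimal sketch: `score_grid` and `make_transformers_scorer` are hypothetical helpers, not part of this model or of `transformers`, and the model is imported lazily so that the grid logic can be used with any scoring function.

```python
from itertools import product
from typing import Callable, List, Sequence, Tuple

# A scorer maps a list of (query, document) pairs to one relevance score per pair.
Scorer = Callable[[List[Tuple[str, str]]], Sequence[float]]


def score_grid(queries: Sequence[str], documents: Sequence[str], scorer: Scorer) -> List[List[float]]:
    """Return grid[i][j] = score of documents[j] for queries[i] (hypothetical helper)."""
    # Row-major order: all documents for query 0, then all documents for query 1, ...
    pairs = list(product(queries, documents))
    flat = [float(s) for s in scorer(pairs)]
    n_docs = len(documents)
    return [flat[i * n_docs:(i + 1) * n_docs] for i in range(len(queries))]


def make_transformers_scorer(model_name: str = "lv12/esci-ms-marco-MiniLM-L-12-v2") -> Scorer:
    """Build a scorer backed by the fine-tuned model (downloads weights on first call)."""
    # Lazy imports: torch and transformers are only needed when a real model is used.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()

    def scorer(pairs: List[Tuple[str, str]]) -> List[float]:
        qs, ds = zip(*pairs)
        inputs = tokenizer(list(qs), list(ds), padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits  # shape: (num_pairs, 1)
        return logits.squeeze(-1).tolist()

    return scorer
```

With the lists from the example above, `score_grid(queries, documents, make_transformers_scorer())` returns a 6 x 6 grid. For long lists, the pairs should be split into batches before they are passed to the model.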

## Sentence Transformers
```python
from sentence_transformers import CrossEncoder

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse",
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace",
]

model = CrossEncoder(model_name, max_length=512)
# predict() takes a list of (query, document) pairs and returns one score per pair.
# For single-output models it may apply a sigmoid by default, so these values can
# differ from the raw logits returned by AutoModelForSequenceClassification.
scores = model.predict(list(zip(queries, documents)))
print(scores)
```
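
In a search pipeline, a cross-encoder is usually applied to rerank a short list of candidates returned by a cheaper first-stage retriever. The following is a minimal sketch that assumes the candidates have already been retrieved. `rerank` and `make_cross_encoder_scorer` are hypothetical helpers. Only `CrossEncoder(...)` and `CrossEncoder.predict(...)` come from sentence-transformers.

```python
from typing import Callable, List, Sequence, Tuple

# A scorer maps a list of (query, document) pairs to one relevance score per pair.
Scorer = Callable[[List[Tuple[str, str]]], Sequence[float]]


def rerank(query: str, candidates: Sequence[str], scorer: Scorer, top_k: int = 3) -> List[Tuple[str, float]]:
    """Score every candidate against the query and return the top_k, best first."""
    pairs = [(query, doc) for doc in candidates]
    scores = [float(s) for s in scorer(pairs)]
    # sorted() is stable, so tied candidates keep their retrieval order.
    ranked = sorted(zip(candidates, scores), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


def make_cross_encoder_scorer(model_name: str = "lv12/esci-ms-marco-MiniLM-L-12-v2") -> Scorer:
    """Wrap CrossEncoder.predict as a scorer (downloads weights on first call)."""
    # Lazy import: sentence-transformers is only needed when a real model is used.
    from sentence_transformers import CrossEncoder

    model = CrossEncoder(model_name, max_length=512)
    return lambda pairs: model.predict(pairs)
```

For example, `rerank("adidas shoes", documents, make_cross_encoder_scorer(), top_k=2)` returns the two highest-scoring documents along with their scores. Because the scorer is passed in as a parameter, the same ranking code also works with the Transformers-based scoring shown earlier.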

# Training

The model was trained with a mean squared error loss (`MSELoss`) on `<query, document>` pairs, using each pair's relevance grade as the label.
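
For a single-output cross-encoder, `MSELoss` averages the squared difference between the predicted score and the grade across a batch. The sketch below computes that quantity by hand in plain Python. It uses the same grades as the training samples, and the predictions are made up for illustration only. `mse` is a hypothetical helper, not part of the training code.

```python
from typing import Sequence


def mse(predictions: Sequence[float], grades: Sequence[float]) -> float:
    """Mean squared error, as torch.nn.MSELoss computes it with its default reduction="mean"."""
    if not grades or len(predictions) != len(grades):
        raise ValueError("predictions and grades must be non-empty and the same length")
    return sum((p - g) ** 2 for p, g in zip(predictions, grades)) / len(grades)


grades = [0.3, 0.8, 0.1]       # same labels as the training samples
predictions = [0.5, 0.6, 0.1]  # illustrative scores only, not real model outputs
loss = mse(predictions, grades)  # (0.2**2 + 0.2**2 + 0**2) / 3
```

A model that predicted every grade exactly would reach a loss of 0.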

```python
from sentence_transformers import InputExample

# Each example pairs a query with a document and a float relevance grade.
train_samples = [
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
```
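
One way to train on samples like these is the classic `CrossEncoder.fit` API from sentence-transformers 2.x/3.x, with `torch.nn.MSELoss` as the loss function. This is a sketch under stated assumptions, not the author's training script. The batch size, epoch count, warmup steps, and output directory are placeholder values, `train_cross_encoder` is a hypothetical helper, and newer sentence-transformers releases recommend `CrossEncoderTrainer` instead.

```python
from typing import List


def train_cross_encoder(
    train_samples: List["InputExample"],
    base_model: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
    output_path: str = "esci-cross-encoder",  # hypothetical output directory
    epochs: int = 1,                          # placeholder, not the author's setting
    batch_size: int = 16,                     # placeholder, not the author's setting
    warmup_steps: int = 100,                  # placeholder, not the author's setting
) -> None:
    """Fine-tune a single-output cross-encoder to regress (query, document) pairs onto grades."""
    # Lazy imports so this function can be defined without torch/sentence-transformers installed.
    import torch
    from sentence_transformers import CrossEncoder
    from torch.utils.data import DataLoader

    # num_labels=1 gives one regression score per pair, matching the float grades.
    model = CrossEncoder(base_model, num_labels=1)
    # CrossEncoder.fit (2.x/3.x) sets its own collate function, so a plain
    # DataLoader over InputExample objects is enough.
    loader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)
    model.fit(
        train_dataloader=loader,
        loss_fct=torch.nn.MSELoss(),  # regress the raw score onto the grade
        epochs=epochs,
        warmup_steps=warmup_steps,
    )
    model.save(output_path)
```

With the samples defined in the training snippet, the call is `train_cross_encoder(train_samples)`. In practice, the training list would contain the full set of ESCI query-product pairs, not three toy examples.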