---
library_name: transformers
base_model: "cross-encoder/ms-marco-MiniLM-L-12-v2"
model-index:
- name: esci-ms-marco-MiniLM-L-12-v2
  results:
  - task:
      type: Reranking
    metrics:
    - type: mrr@10
      value: 91.74
    - type: ndcg@10
      value: 84.83
tags: ["cross-encoder", "search", "product-search"]
---
# Model Description
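A cross-encoder fine-tuned from `cross-encoder/ms-marco-MiniLM-L-12-v2` on the Amazon ESCI (Shopping Queries) dataset for product-search reranking. ESCI labels each query-product pair as Exact, Substitute, Complement, or Irrelevant.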
# Usage
## Transformers
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch import no_grad

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
# queries[i] is scored against documents[i]
queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse"
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace"
]
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Passing two lists tokenizes them as index-aligned (query, document) pairs
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
model.eval()  # inference mode: disables dropout
with no_grad():
    # One relevance score (logit) per pair; higher means more relevant
    scores = model(**inputs).logits.cpu().numpy()
print(scores)
```
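Because `tokenizer(queries, documents)` pairs the two lists by position, the snippet above scores only index-aligned pairs (query *i* with document *i*). To score every query against every document, one option is to flatten the Cartesian product into two parallel lists and reshape the scores afterwards. This is a minimal sketch; the helper names `build_all_pairs` and `to_score_matrix` are hypothetical and not part of `transformers`, and the usage comments assume the model emits a single logit per pair.

```python
from itertools import product
from typing import List, Tuple


def build_all_pairs(queries: List[str], documents: List[str]) -> Tuple[List[str], List[str]]:
    """Return two parallel lists covering every (query, document) combination.

    Pairs are ordered query-major: all documents for queries[0] first,
    then all documents for queries[1], and so on.
    """
    pairs = list(product(queries, documents))
    flat_queries = [q for q, _ in pairs]
    flat_documents = [d for _, d in pairs]
    return flat_queries, flat_documents


def to_score_matrix(flat_scores: List[float], num_queries: int, num_documents: int) -> List[List[float]]:
    """Reshape query-major flat scores into a num_queries x num_documents matrix."""
    if len(flat_scores) != num_queries * num_documents:
        raise ValueError("expected exactly one score per (query, document) pair")
    return [
        list(flat_scores[i * num_documents:(i + 1) * num_documents])
        for i in range(num_queries)
    ]


# Usage with the `model`, `tokenizer`, and `no_grad` from the snippet above
# (assumes the model emits a single logit per pair):
# flat_q, flat_d = build_all_pairs(queries, documents)
# inputs = tokenizer(flat_q, flat_d, padding=True, truncation=True, return_tensors="pt")
# with no_grad():
#     flat_scores = model(**inputs).logits.squeeze(-1).tolist()
# matrix = to_score_matrix(flat_scores, len(queries), len(documents))
# matrix[i][j] is then the score of queries[i] against documents[j]
```

With six queries and six documents this produces 36 pairs in one batch; for larger inputs you would likely want to tokenize and score in chunks.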
## Sentence Transformers
```python
from sentence_transformers import CrossEncoder

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse"
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace"
]
model = CrossEncoder(model_name, max_length=512)
# predict() takes a list of (query, document) pairs and returns one score per pair
scores = model.predict([(q, d) for q, d in zip(queries, documents)])
print(scores)
```
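For reranking, a common pattern is to pair one query with many candidate documents and sort the candidates by score. The sketch below is model-agnostic: `rerank` and its `score_fn` parameter are hypothetical names, and any callable that maps a list of `(query, document)` pairs to one score per pair will work. The `model.predict` method from the snippet above fits that shape.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def rerank(
    query: str,
    documents: Sequence[str],
    score_fn: Callable[[List[Tuple[str, str]]], Sequence[float]],
    top_k: Optional[int] = None,
) -> List[Tuple[str, float]]:
    """Score each document against one query and return (document, score) pairs, best first."""
    pairs = [(query, doc) for doc in documents]
    scores = score_fn(pairs)
    # Sort by score, highest (most relevant) first
    ranked = sorted(zip(documents, scores), key=lambda item: item[1], reverse=True)
    ranked = [(doc, float(score)) for doc, score in ranked]
    return ranked if top_k is None else ranked[:top_k]
```

With the `CrossEncoder` loaded above, `rerank("adidas shoes", documents, model.predict, top_k=3)` would return the three highest-scoring documents for that query. Recent `sentence-transformers` releases also ship a built-in `CrossEncoder.rank` method for this; check the version you have installed.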
## Training
Trained with MSE loss (`MSELoss`) on `<query, document>` pairs, using each pair's `grade` as the regression label.
```python
from sentence_transformers import InputExample

# Each example is a (query, document) pair labeled with its relevance grade
train_samples = [
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
```
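With MSE loss, training pushes the model's score for each `<query, document>` pair toward that pair's grade by minimizing the mean of the squared differences. The plain-Python illustration below shows that computation. The `ESCI_GRADES` mapping is a hypothetical example of turning ESCI labels into grades; the grade values actually used for this model are not listed here, and the model scores are made up.

```python
from typing import Dict, List

# Hypothetical mapping from ESCI labels to numeric grades. This is an
# illustrative assumption, not necessarily what was used to train this model.
ESCI_GRADES: Dict[str, float] = {"E": 1.0, "S": 0.1, "C": 0.01, "I": 0.0}


def mse_loss(predictions: List[float], labels: List[float]) -> float:
    """Mean squared error: the average of (prediction - label) ** 2 over all pairs."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("predictions and labels must be non-empty and the same length")
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)


# Grades from the three InputExamples above, paired with made-up model scores.
labels = [0.3, 0.8, 0.1]
predictions = [0.5, 0.6, 0.1]
print(mse_loss(predictions, labels))

# Converting raw ESCI labels to grades with the hypothetical mapping.
esci_labels = ["E", "S", "I"]
print([ESCI_GRADES[label] for label in esci_labels])  # [1.0, 0.1, 0.0]
```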