---
library_name: transformers
tags:
- cross-encoder
- search
- product-search
base_model: cross-encoder/ms-marco-MiniLM-L-12-v2
---
# Model Description
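A cross-encoder fine-tuned from `cross-encoder/ms-marco-MiniLM-L-12-v2` on the Amazon ESCI (Exact, Substitute, Complement, Irrelevant) Shopping Queries dataset. Given a `(query, product)` pair, it outputs a relevance score that can be used to rank products for a search query.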
# Usage
## Transformers
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch import no_grad

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

# Each query is scored against the document at the same index.
queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse",
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace",
]

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Passing two lists encodes them as (query, document) sentence pairs.
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

model.eval()  # disable dropout for inference
with no_grad():
    # Raw logits, one row per (query, document) pair.
    scores = model(**inputs).logits.cpu().numpy()
print(scores)
```
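The Transformers snippet above prints raw logits. If you want scores in (0, 1), for example to apply a fixed threshold, a minimal sketch is to pass them through a logistic sigmoid. This assumes the model emits a single relevance logit per pair (`num_labels=1`), as its base model `cross-encoder/ms-marco-MiniLM-L-12-v2` does; `logits_to_scores` is a hypothetical helper, and the logits below are placeholder values rather than real model output.

```python
import numpy as np


def logits_to_scores(logits: np.ndarray) -> np.ndarray:
    """Map raw logits of shape (n_pairs, 1) to scores in (0, 1) of shape (n_pairs,).

    Assumes a single relevance logit per pair (num_labels=1).
    """
    logits = np.asarray(logits, dtype=np.float64)
    if logits.ndim == 2 and logits.shape[-1] == 1:
        logits = logits[:, 0]  # drop the singleton label dimension
    return 1.0 / (1.0 + np.exp(-logits))  # logistic sigmoid


# Placeholder logits standing in for the `scores` array printed above.
example_logits = np.array([[-2.0], [0.0], [3.0]])
print(logits_to_scores(example_logits).round(3))
```

Because the sigmoid is monotonic, it changes the scale of the scores but not the order in which products are ranked.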
## Sentence Transformers
```python
from sentence_transformers import CrossEncoder

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse",
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace",
]

model = CrossEncoder(model_name, max_length=512)
# predict() takes a list of (query, document) pairs and scores each pair.
scores = model.predict(list(zip(queries, documents)))
print(scores)
```
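The examples above pair each query with the document at the same index. A common way to use a cross-encoder in product search is instead to rerank a short candidate list for one query: score every `(query, candidate)` pair and sort by score. The sketch below assumes that workflow; `rerank` and `toy_scorer` are hypothetical helpers, and the word-overlap scorer only stands in for the model so the code runs without downloading weights. With the `CrossEncoder` loaded above, you would pass `model.predict` as `score_pairs`.

```python
from typing import Callable, List, Sequence, Tuple


def rerank(
    query: str,
    candidates: Sequence[str],
    score_pairs: Callable[[List[Tuple[str, str]]], Sequence[float]],
    top_k: int = 10,
) -> List[Tuple[str, float]]:
    """Score every (query, candidate) pair and return the top_k candidates, best first."""
    pairs = [(query, doc) for doc in candidates]
    scores = [float(s) for s in score_pairs(pairs)]
    ranked = sorted(zip(candidates, scores), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


def toy_scorer(pairs: List[Tuple[str, str]]) -> List[float]:
    """Hypothetical stand-in for model.predict: counts shared lowercase words."""
    return [float(len(set(q.lower().split()) & set(d.lower().split()))) for q, d in pairs]


# With the real model: rerank("adidas shoes", documents, model.predict, top_k=3)
print(rerank("adidas shoes", ["Nike Air Max", "Adidas Ultraboost shoes", "Fresh watermelon"], toy_scorer, top_k=2))
```

Scoring is one forward pass per pair, so this is usually applied to a candidate list already narrowed down by a cheaper first-stage retriever.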
## Training
Trained with `CrossEntropyLoss` on `<query, document>` pairs, using each pair's relevance `grade` as the label. Training samples take this form:
```python
from sentence_transformers import InputExample

# Each sample is a (query, document) pair labeled with its relevance grade.
train_samples = [
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
```
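How ESCI labels were converted into numeric grades is not documented here. As an illustration only, the sketch below turns `(query, product_title, esci_label)` rows into `(query, document, grade)` triples using a HYPOTHETICAL grade mapping; `HYPOTHETICAL_GRADES`, `esci_rows_to_pairs`, and the toy rows are all assumptions, not the mapping or data used to train this model.

```python
from typing import Dict, Iterable, List, Tuple

# HYPOTHETICAL mapping for illustration; not necessarily what this model used.
HYPOTHETICAL_GRADES: Dict[str, float] = {
    "E": 1.0,   # Exact
    "S": 0.1,   # Substitute
    "C": 0.01,  # Complement
    "I": 0.0,   # Irrelevant
}


def esci_rows_to_pairs(
    rows: Iterable[Tuple[str, str, str]],
    grades: Dict[str, float] = HYPOTHETICAL_GRADES,
) -> List[Tuple[str, str, float]]:
    """Convert (query, product_title, esci_label) rows into (query, document, grade) triples."""
    triples = []
    for query, title, label in rows:
        if label not in grades:
            raise ValueError(f"unknown ESCI label: {label!r}")
        triples.append((query, title, grades[label]))
    return triples


# Toy rows, invented for illustration.
rows = [
    ("adidas shoes", "Adidas Ultraboost running shoe", "E"),
    ("adidas shoes", "Nike Air Max", "S"),
    ("adidas shoes", "Shoe cleaning kit", "C"),
    ("adidas shoes", "Fresh watermelon", "I"),
]
for query, document, grade in esci_rows_to_pairs(rows):
    print(query, "|", document, "|", grade)
```

Each triple can then be wrapped as `InputExample(texts=[query, document], label=grade)`, matching the training samples shown above.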