---
library_name: transformers
tags:
  - cross-encoder
  - search
  - product-search
base_model: cross-encoder/ms-marco-MiniLM-L-12-v2
---

Model Description

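This model is the cross-encoder cross-encoder/ms-marco-MiniLM-L-12-v2 fine-tuned on the Amazon ESCI (Exact, Substitute, Complement, Irrelevant) product search dataset. It scores how relevant a product text is to a search query.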

Usage

Transformers

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch import no_grad

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

# Each query is scored against the document at the same position.
queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse",
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace",
]

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Encode every (query, document) pair as one sequence pair.
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

model.eval()
with no_grad():
    # Raw relevance logits for each pair; higher means more relevant.
    scores = model(**inputs).logits.cpu().numpy()
print(scores)
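The Transformers snippet prints raw logits. If you want scores on a 0 to 1 scale, comparable to what sentence-transformers' CrossEncoder.predict returns for single-output models (it applies a sigmoid by default), you can apply the logistic function yourself. The following is a minimal stdlib sketch; the helper name to_probability is hypothetical, and the example logits are made-up values, not outputs of this model.

```python
import math


def to_probability(logit: float) -> float:
    """Map a raw relevance logit to (0, 1) with the logistic sigmoid."""
    # Numerically stable form: never call exp() on a large positive number.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# Illustrative logits only (not real model outputs).
example_logits = [-8.2, 3.1, 0.0]
print([round(to_probability(x), 4) for x in example_logits])
```

Because the sigmoid is monotonic, ranking pairs by raw logit or by probability gives the same order; the conversion only matters when you need a thresholdable score.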

Sentence Transformers

from sentence_transformers import CrossEncoder

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

queries = [
    "adidas shoes",
    "adidas sambas",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard blouse",
]
documents = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
    "Girl's surf backpack",
    "Fresh watermelon, all you can eat",
    "Floral yellow dress with frills and lace",
]

model = CrossEncoder(model_name, max_length=512)
scores = model.predict(list(zip(queries, documents)))
print(scores)
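The example above scores each query against the document at the same position. In product search, a cross-encoder is more often used to rerank a candidate list for a single query: pair the query with every candidate, score all pairs, and sort by score. The sketch below shows that pattern with a hypothetical stand-in scorer (overlap_scorer, plain word overlap) so it runs without downloading the model; with the model loaded as above, you would pass model.predict as score_pairs instead. The names rerank and overlap_scorer are illustrative, not part of any library.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]


def rerank(
    query: str,
    candidates: Sequence[str],
    score_pairs: Callable[[List[Pair]], Sequence[float]],
) -> List[Tuple[str, float]]:
    """Score (query, candidate) pairs and return candidates sorted best first."""
    pairs = [(query, doc) for doc in candidates]
    scores = score_pairs(pairs)
    # sorted() is stable, so tied candidates keep their input order.
    ranked = sorted(zip(candidates, scores), key=lambda item: item[1], reverse=True)
    return [(doc, float(score)) for doc, score in ranked]


def overlap_scorer(pairs: List[Pair]) -> List[float]:
    """Hypothetical stand-in for a cross-encoder: counts shared lowercase words."""
    return [
        float(len(set(q.lower().split()) & set(d.lower().split())))
        for q, d in pairs
    ]


candidates = [
    "Nike Air Max, with air cushion",
    "Adidas Ultraboost, the best boost you can get",
    "Women's sandals wide width 9",
]
for doc, score in rerank("adidas shoes", candidates, overlap_scorer):
    print(f"{score:.1f}  {doc}")
```

Passing the scorer as a function keeps the pairing and sorting logic independent of how scores are produced, so the same helper works with the real cross-encoder or a cheap baseline.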

Training

Trained with CrossEntropyLoss on <query, document> pairs, with each pair's relevance grade as the label.

from sentence_transformers import InputExample

# Each example is a (query, document) pair labeled with its relevance grade.
train_samples = [
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
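The card does not say how ESCI's four labels (Exact, Substitute, Complement, Irrelevant) were turned into grades. The sketch below assumes a hypothetical mapping (ESCI_GRADES; the values are illustrative, not the ones used to train this model) and, assuming the model emits a single logit per pair, shows how a graded label enters a cross-entropy loss: the binary cross-entropy between the sigmoid of the logit and the grade. The rows are made-up examples, not taken from the dataset, and the helper names are hypothetical. Each (query, document, grade) tuple it produces is what you would wrap as InputExample(texts=[query, document], label=grade).

```python
import math
from typing import Dict, List, Tuple

# Hypothetical ESCI label -> relevance grade mapping (illustrative values only).
ESCI_GRADES: Dict[str, float] = {"E": 1.0, "S": 0.1, "C": 0.01, "I": 0.0}


def to_graded_pairs(rows: List[Tuple[str, str, str]]) -> List[Tuple[str, str, float]]:
    """Convert (query, product_title, esci_label) rows into (query, document, grade)."""
    return [(query, title, ESCI_GRADES[label]) for query, title, label in rows]


def soft_label_bce(logit: float, grade: float) -> float:
    """Binary cross-entropy between sigmoid(logit) and a soft target in [0, 1].

    Uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    return max(logit, 0.0) - logit * grade + math.log1p(math.exp(-abs(logit)))


# Made-up rows for illustration.
rows = [
    ("adidas sambas", "Adidas Samba OG sneakers", "E"),
    ("adidas sambas", "Adidas Gazelle sneakers", "S"),
    ("adidas sambas", "Suede shoe cleaning kit", "C"),
]
for query, doc, grade in to_graded_pairs(rows):
    # Loss if the model output a logit of 2.0 for this pair.
    print(f"{query} | {doc} | grade={grade} | loss={soft_label_bce(2.0, grade):.4f}")
```

With a soft target, the loss is smallest when sigmoid(logit) equals the grade, so a pair graded 0.1 is pushed toward a low but nonzero score rather than all the way to zero.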