hellonlp-embedding
A collection of 4 items, focused on sentence similarity.
The following datasets are all in Chinese.
Dataset | Train size | Valid size | Test size |
---|---|---|---|
ATEC | 62477 | 20000 | 20000 |
BQ | 100000 | 10000 | 10000 |
LCQMC | 238766 | 8802 | 12500 |
PAWSX | 49401 | 2000 | 2000 |
STS-B | 5231 | 1458 | 1361 |
SNLI | 146828 | 2699 | 2618 |
MNLI | 122547 | 2932 | 2397 |
The evaluation datasets are in Chinese, and the same underlying language model (RoBERTa) is used across the different methods. In addition, because the test sets of some datasets are small, which could lead to large variance in the evaluated accuracy, the evaluation here uses the train, valid, and test splits together, and the final result is reported as a weighted average (w-avg).
Model | STS-B(w-avg) | ATEC | BQ | LCQMC | PAWSX | Avg. |
---|---|---|---|---|---|---|
BAAI/bge-large-zh | 78.61 | - | - | - | - | - |
BAAI/bge-large-zh-v1.5 | 79.07 | - | - | - | - | - |
hellonlp/simcse-large-zh | 81.32 | - | - | - | - | - |
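To make the w-avg scheme concrete, here is a minimal sketch. It assumes the per-split scores are weighted by split size, which the text does not state explicitly; the split sizes are STS-B's from the table above, and the per-split scores are placeholders.

```python
def weighted_average(scores, sizes):
    """Size-weighted average of per-split evaluation scores."""
    return sum(s * n for s, n in zip(scores, sizes)) / sum(sizes)

# STS-B split sizes (train, valid, test) from the dataset table above.
sizes = [5231, 1458, 1361]
scores = [80.0, 82.0, 81.0]  # placeholder per-split scores, for illustration only
print(round(weighted_average(scores, sizes), 2))  # 80.53
```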
You can use our model to encode sentences into embeddings:
```python
import torch
from transformers import BertTokenizer
from transformers import BertModel
from sklearn.metrics.pairwise import cosine_similarity

# model
simcse_sup_path = "hellonlp/simcse-roberta-large-zh"
tokenizer = BertTokenizer.from_pretrained(simcse_sup_path)
MODEL = BertModel.from_pretrained(simcse_sup_path)

def get_vector_simcse(sentence):
    """
    Compute the SimCSE semantic vector of a sentence.
    """
    input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0)
    with torch.no_grad():
        output = MODEL(input_ids)
    # Use the [CLS] token representation as the sentence embedding.
    return output.last_hidden_state[:, 0].squeeze(0)

embeddings = get_vector_simcse("武汉是一个美丽的城市。")  # "Wuhan is a beautiful city."
print(embeddings.shape)
# torch.Size([1024])
```
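To encode many sentences at once, a batched variant is sketched below. It reuses the tokenizer and model loaded above and relies on the standard `transformers` batch API with padding and attention masks; this helper is not part of the original example.

```python
def get_vectors_simcse(sentences):
    """Batch-encode a list of sentences into [CLS] embeddings."""
    inputs = tokenizer(sentences, padding=True, truncation=True,
                       return_tensors="pt")
    with torch.no_grad():
        output = MODEL(**inputs)
    # One 1024-dim vector per sentence.
    return output.last_hidden_state[:, 0]

vectors = get_vectors_simcse(["你好吗", "你还好吗"])
print(vectors.shape)  # torch.Size([2, 1024])
```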
You can also compute the cosine similarity between two sentences:
```python
def get_similarity_two(sentence1, sentence2):
    """Cosine similarity between the embeddings of two sentences."""
    vec1 = get_vector_simcse(sentence1).tolist()
    vec2 = get_vector_simcse(sentence2).tolist()
    similarity = cosine_similarity([vec1], [vec2]).tolist()[0][0]
    return similarity

sentence1 = '你好吗'    # "How are you?"
sentence2 = '你还好吗'  # "Are you doing okay?"
result = get_similarity_two(sentence1, sentence2)
print(result)
# 0.848331
```
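The round-trip through Python lists and scikit-learn can be avoided; as an alternative sketch (not the original method), `torch.nn.functional.cosine_similarity` computes the same value directly on the embedding tensors:

```python
import torch.nn.functional as F

def get_similarity_torch(sentence1, sentence2):
    """Cosine similarity computed directly on the embedding tensors."""
    vec1 = get_vector_simcse(sentence1)
    vec2 = get_vector_simcse(sentence2)
    return F.cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0)).item()

print(get_similarity_torch('你好吗', '你还好吗'))  # ≈ 0.848
```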