michael-guenther committed on
Commit
e604d96
1 Parent(s): 035ce47

Upload evaluate_model.py

Browse files
Files changed (1) hide show
  1. evaluate_model.py +118 -0
evaluate_model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script for evaluating Jina Embedding Models on the MTEB benchmark.
3
+
4
+ This script is based on the MTEB example:
5
+ https://github.com/embeddings-benchmark/mteb/blob/main/scripts/run_mteb_english.py
6
+ """
7
+
8
+ import logging
9
+
10
+ from mteb import MTEB
11
+ from sentence_transformers import SentenceTransformer
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+
15
+ logger = logging.getLogger("main")
16
+
17
+ TASK_LIST_CLASSIFICATION = [
18
+ "AmazonCounterfactualClassification",
19
+ "AmazonPolarityClassification",
20
+ "AmazonReviewsClassification",
21
+ "Banking77Classification",
22
+ "EmotionClassification",
23
+ "ImdbClassification",
24
+ "MassiveIntentClassification",
25
+ "MassiveScenarioClassification",
26
+ "MTOPDomainClassification",
27
+ "MTOPIntentClassification",
28
+ "ToxicConversationsClassification",
29
+ "TweetSentimentExtractionClassification",
30
+ ]
31
+
32
+ TASK_LIST_CLUSTERING = [
33
+ "ArxivClusteringP2P",
34
+ "ArxivClusteringS2S",
35
+ "BiorxivClusteringP2P",
36
+ "BiorxivClusteringS2S",
37
+ "MedrxivClusteringP2P",
38
+ "MedrxivClusteringS2S",
39
+ "RedditClustering",
40
+ "RedditClusteringP2P",
41
+ "StackExchangeClustering",
42
+ "StackExchangeClusteringP2P",
43
+ "TwentyNewsgroupsClustering",
44
+ ]
45
+
46
+ TASK_LIST_PAIR_CLASSIFICATION = [
47
+ "SprintDuplicateQuestions",
48
+ "TwitterSemEval2015",
49
+ "TwitterURLCorpus",
50
+ ]
51
+
52
+ TASK_LIST_RERANKING = [
53
+ "AskUbuntuDupQuestions",
54
+ "MindSmallReranking",
55
+ "SciDocsRR",
56
+ "StackOverflowDupQuestions",
57
+ ]
58
+
59
+ TASK_LIST_RETRIEVAL = [
60
+ "ArguAna",
61
+ "ClimateFEVER",
62
+ "CQADupstackAndroidRetrieval",
63
+ "CQADupstackEnglishRetrieval",
64
+ "CQADupstackGamingRetrieval",
65
+ "CQADupstackGisRetrieval",
66
+ "CQADupstackMathematicaRetrieval",
67
+ "CQADupstackPhysicsRetrieval",
68
+ "CQADupstackProgrammersRetrieval",
69
+ "CQADupstackStatsRetrieval",
70
+ "CQADupstackTexRetrieval",
71
+ "CQADupstackUnixRetrieval",
72
+ "CQADupstackWebmastersRetrieval",
73
+ "CQADupstackWordpressRetrieval",
74
+ "DBPedia",
75
+ "FEVER",
76
+ "FiQA2018",
77
+ "HotpotQA",
78
+ "MSMARCO",
79
+ "NFCorpus",
80
+ "NQ",
81
+ "QuoraRetrieval",
82
+ "SCIDOCS",
83
+ "SciFact",
84
+ "Touche2020",
85
+ "TRECCOVID",
86
+ ]
87
+
88
+ TASK_LIST_STS = [
89
+ "BIOSSES",
90
+ "SICK-R",
91
+ "STS12",
92
+ "STS13",
93
+ "STS14",
94
+ "STS15",
95
+ "STS16",
96
+ "STS17",
97
+ "STS22",
98
+ "STSBenchmark",
99
+ "SummEval",
100
+ ]
101
+
102
+ TASK_LIST = (
103
+ TASK_LIST_CLASSIFICATION
104
+ + TASK_LIST_CLUSTERING
105
+ + TASK_LIST_PAIR_CLASSIFICATION
106
+ + TASK_LIST_RERANKING
107
+ + TASK_LIST_RETRIEVAL
108
+ + TASK_LIST_STS
109
+ )
110
+
111
+ model_name = "jinaai/jina-embedding-b-en-v1"
112
+ model = SentenceTransformer(model_name)
113
+
114
+ for task in TASK_LIST:
115
+ logger.info(f"Running task: {task}")
116
+ eval_splits = ["dev"] if task == "MSMARCO" else ["test"]
117
+ evaluation = MTEB(tasks=[task], task_langs=["en"]) # Remove "en" for running all languages
118
+ evaluation.run(model, output_folder=f"results/{model_name}", eval_splits=eval_splits)