tomaarsen committed
Commit
82dda1b
1 Parent(s): 918cbe4

Create training_nli_matryoshka.py

Files changed (1)
  1. training_nli_matryoshka.py +106 -0
training_nli_matryoshka.py ADDED
@@ -0,0 +1,106 @@
+ # Matryoshka test
+ from collections import defaultdict
+ from typing import Dict
+ import datasets
+ from datasets import Dataset
+ from sentence_transformers import (
+     SentenceTransformer,
+     SentenceTransformerTrainer,
+     losses,
+     evaluation,
+     TrainingArguments
+ )
+ from sentence_transformers.models import Transformer, Pooling
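+ # NOTE: this script uses the SentenceTransformerTrainer-based training API (sentence-transformers v3-style).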
+
+ def to_triplets(dataset):
+     # Group hypotheses per premise, keyed by NLI label: 0 = entailment, 1 = neutral, 2 = contradiction
+     premises = defaultdict(dict)
+     for sample in dataset:
+         premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
+     queries = []
+     positives = []
+     negatives = []
+     for premise, sentences in premises.items():
+         if 0 in sentences and 2 in sentences:
+             queries.append(premise)
+             positives.append(sentences[0])  # <- entailment
+             negatives.append(sentences[2])  # <- contradiction
+     return Dataset.from_dict({
+         "anchor": queries,
+         "positive": positives,
+         "negative": negatives,
+     })
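+ # Each resulting row pairs a premise (anchor) with its entailed hypothesis (positive) and its
+ # contradicted hypothesis (negative): the triplet format MultipleNegativesRankingLoss accepts.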
+
+ snli_ds = datasets.load_dataset("snli")
+ snli_ds = datasets.DatasetDict({
+     "train": to_triplets(snli_ds["train"]),
+     "validation": to_triplets(snli_ds["validation"]),
+     "test": to_triplets(snli_ds["test"]),
+ })
+ multi_nli_ds = datasets.load_dataset("multi_nli")
+ multi_nli_ds = datasets.DatasetDict({
+     "train": to_triplets(multi_nli_ds["train"]),
+     "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
+ })
+
+ all_nli_ds = datasets.DatasetDict({
+     "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),
+     "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
+     "test": snli_ds["test"]
+ })
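+ # MultiNLI's test labels are not publicly released, so SNLI's labeled test split serves as the AllNLI test set.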
+
+ stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
+ stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")
+
+ training_args = TrainingArguments(
+     output_dir="checkpoints",
+     num_train_epochs=1,
+     seed=42,
+     per_device_train_batch_size=64,
+     per_device_eval_batch_size=64,
+     learning_rate=2e-5,
+     warmup_ratio=0.1,
+     bf16=True,
+     logging_steps=10,
+     evaluation_strategy="steps",
+     eval_steps=300,
+     save_steps=1000,
+     save_total_limit=2,
+     metric_for_best_model="spearman_cosine",  # reported by the dev evaluator below
+     greater_is_better=True,
+ )
+
+ transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
+ pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
+ model = SentenceTransformer(modules=[transformer, pooling])
+
+ tokenizer = model.tokenizer
+ loss = losses.MultipleNegativesRankingLoss(model)
+ loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])
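+ # MatryoshkaLoss applies the wrapped loss at each listed embedding size (768 down to 64),
+ # so the leading dimensions remain useful as standalone, truncated embeddings.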
+
+ dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
+     stsb_dev["sentence1"],
+     stsb_dev["sentence2"],
+     [score / 5 for score in stsb_dev["score"]],  # STS-B gold scores are 0-5; normalize to [0, 1]
+     main_similarity=evaluation.SimilarityFunction.COSINE,
+     name="sts-dev",
+ )
+
+ trainer = SentenceTransformerTrainer(
+     model=model,
+     evaluator=dev_evaluator,
+     args=training_args,
+     train_dataset=all_nli_ds["train"],
+     # eval_dataset=all_nli_ds["validation"],
+     loss=loss,
+ )
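+ # With eval_dataset commented out, each evaluation step runs only the dev evaluator,
+ # which supplies the "spearman_cosine" metric named in metric_for_best_model.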
+ trainer.train()
+
+ test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
+     stsb_test["sentence1"],
+     stsb_test["sentence2"],
+     [score / 5 for score in stsb_test["score"]],
+     main_similarity=evaluation.SimilarityFunction.COSINE,
+     name="sts-test",
+ )
+ results = test_evaluator(model)
+ print(results)
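
A minimal sketch (not part of the commit) of how the Matryoshka-trained model could be queried at a reduced embedding size: it assumes numpy, reuses the in-memory `model` from the script above, and the example sentences are illustrative.

    import numpy as np

    sentences = ["A man is eating food.", "Someone is having a meal."]
    embeddings = model.encode(sentences)           # shape (2, 768) for mpnet-base
    truncated = embeddings[:, :256]                # keep only the leading 256 dimensions
    truncated = truncated / np.linalg.norm(truncated, axis=1, keepdims=True)
    print(truncated[0] @ truncated[1])             # cosine similarity at 256 dims

Because MatryoshkaLoss optimized the wrapped loss at 256 dimensions as well, truncating and re-normalizing in this way should degrade similarity quality only modestly compared to the full 768 dimensions.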