---
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
datasets:
- sbx/superlim-2
language:
- sv
---

# jzju/sbert-sv-lim2

This model is fine-tuned from [KBLab/bert-base-swedish-cased-new](https://huggingface.co/KBLab/bert-base-swedish-cased-new) on data from [sbx/superlim-2](https://huggingface.co/datasets/sbx/superlim-2).

This is a [sentence-transformers](https://www.SBERT.net) model: it maps sentences and paragraphs to a 256-dimensional dense vector space and can be used for tasks like clustering or semantic search.

## Usage (Sentence-Transformers)

Using this model is easy once you have [sentence-transformers](https://www.SBERT.net) installed:

```bash
pip install -U sentence-transformers
```

Then you can use the model like this:

```python
from sentence_transformers import SentenceTransformer

# Load the fine-tuned model by its Hugging Face Hub id.
model = SentenceTransformer("jzju/sbert-sv-lim2")

sentences = ["This is an example sentence", "Each sentence is converted"]
# One dense vector is produced per input sentence.
embeddings = model.encode(sentences)
print(embeddings)
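```

## Training Code

The script below fine-tunes the base model in two stages: first on the graded-similarity subsets `swepar`, `swesim_relatedness` and `swesim_similarity`, then on `swenli`. Note that it concatenates every split that `load_dataset` returns for each subset, so any held-out evaluation split is also used for training.

```python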
import random
from collections import defaultdict

from datasets import load_dataset, concatenate_datasets
from sentence_transformers import (
    SentenceTransformer,
    InputExample,
    losses,
    models,
    util,
    datasets,
)
from torch import nn
from torch.utils.data import DataLoader

# Architecture: Swedish BERT encoder -> pooling -> 256-unit tanh projection.
word_embedding_model = models.Transformer(
    "KBLab/bert-base-swedish-cased-new",
    max_seq_length=256,  # maximum input length, in tokens
)
token_dim = word_embedding_model.get_word_embedding_dimension()

# Pooling's default mode is used; its output width feeds the projection.
pooling_model = models.Pooling(token_dim)
pooled_dim = pooling_model.get_sentence_embedding_dimension()

dense_model = models.Dense(
    in_features=pooled_dim,
    out_features=256,
    activation_function=nn.Tanh(),
)

model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model, dense_model]
)


def pair():
    """Fine-tune the module-level ``model`` on SuperLim 2 graded-similarity pairs.

    Loads the ``swepar``, ``swesim_relatedness`` and ``swesim_similarity``
    subsets. Each label is divided by the largest label of its own subset,
    so every subset's targets share the same maximum of 1. The model is
    trained with ``CosineSimilarityLoss`` for 10 epochs.
    """
    train_examples = []
    for subset in ["swepar", "swesim_relatedness", "swesim_similarity"]:
        # NOTE: every split load_dataset returns is concatenated, not only
        # "train" -- held-out splits end up in the training data.
        ds = concatenate_datasets(
            list(load_dataset("sbx/superlim-2", subset).values())
        )

        # A subset stores its text pair either as sentence_1/sentence_2 or
        # as word_1/word_2.
        if "sentence_1" in ds.features:
            firsts, seconds = ds["sentence_1"], ds["sentence_2"]
        else:
            firsts, seconds = ds["word_1"], ds["word_2"]

        # Materialize once: the column is iterated twice (max and zip).
        labels = list(ds["label"])
        # Per-subset scale computed explicitly instead of via a closure that
        # late-binds a loop variable. Assumes labels are non-negative, so
        # targets land in [0, 1] -- TODO confirm each subset's minimum.
        scale = max(labels)
        train_examples.extend(
            InputExample(texts=[first, second], label=label / scale)
            for first, second, label in zip(firsts, seconds, labels)
        )

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)
    train_loss = losses.CosineSimilarityLoss(model)
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=10,
        warmup_steps=100,
    )


def nli(positive_label=0, negative_label=1):
    """Fine-tune the module-level ``model`` on SweNLI triplets.

    All splits of the ``swenli`` subset are merged. For every sentence, the
    sentences it was paired with are grouped by label id. Pairs are recorded
    in both directions: premise -> hypothesis and hypothesis -> premise.
    A sentence with at least one ``positive_label`` partner and at least one
    ``negative_label`` partner yields two (anchor, positive, hard negative)
    triplets. Training uses ``MultipleNegativesRankingLoss`` for one epoch.

    Args:
        positive_label: label id whose partners are used as positives.
        negative_label: label id whose partners are used as hard negatives.

    NOTE(review): the defaults (0, 1) reproduce the original script. If the
    swenli ``label`` feature follows the MultiNLI order (entailment,
    neutral, contradiction), then 1 is *neutral* and ``negative_label=2``
    is the conventional hard negative. Confirm via ``ds.features["label"]``.
    """
    ds = concatenate_datasets(
        list(load_dataset("sbx/superlim-2", "swenli").values())
    )

    # sentence -> label id -> set of partner sentences under that label.
    train_data = defaultdict(lambda: defaultdict(set))
    for premise, hypothesis, label in zip(
        ds["premise"], ds["hypothesis"], ds["label"]
    ):
        train_data[premise][label].add(hypothesis)
        train_data[hypothesis][label].add(premise)

    train_samples = []
    for sent1, others in train_data.items():
        # sorted() pins the candidate order. Iteration order of a set of str
        # depends on per-process hash randomization, so random.choice over
        # list(set) is not reproducible even with a fixed random.seed().
        positives = sorted(others.get(positive_label, ()))
        negatives = sorted(others.get(negative_label, ()))
        if not positives or not negatives:
            continue
        train_samples.append(
            InputExample(
                texts=[sent1, random.choice(positives), random.choice(negatives)]
            )
        )
        # Mirrored triplet: a positive partner acts as the anchor.
        train_samples.append(
            InputExample(
                texts=[random.choice(positives), sent1, random.choice(negatives)]
            )
        )

    train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=64)
    train_loss = losses.MultipleNegativesRankingLoss(model)
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=1,
        warmup_steps=100,
    )


if __name__ == "__main__":
    # Stage 1 (similarity pairs), then stage 2 (NLI triplets); both stages
    # update the same module-level model.
    pair()
    nli()
    # SentenceTransformer.save() takes the output directory as a required
    # argument; calling it with no path fails after training has finished.
    model.save("sbert-sv-lim2")
```