Integrate with Sentence Transformers (#3)
Browse files- Integrate with Sentence Transformers + README (bed6830fdf35eb145747402f8dea5803018d07ab)
- Bump up minimum version (aea6a04d99a95e89ffe0b4f5fc20dcd6e75fa85a)
- Replace local-only "." with "jxm/cde-small-v1" (9677008ed99e455a9026e1ea0062fe2cbaf73de5)
- README.md +186 -6
- config_sentence_transformers.json +13 -0
- modules.json +9 -0
- sentence_bert_config.json +1 -0
- sentence_transformers_impl.py +156 -0
README.md
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
---
|
2 |
tags:
|
3 |
- mteb
|
|
|
|
|
4 |
model-index:
|
5 |
- name: cde-small-v1
|
6 |
results:
|
@@ -8660,8 +8662,184 @@ Our new model that naturally integrates "context tokens" into the embedding proc
|
|
8660 |
|
8661 |
Our embedding model needs to be used in *two stages*. The first stage is to gather some dataset information by embedding a subset of the corpus using our "first-stage" model. The second stage is to actually embed queries and documents, conditioning on the corpus information from the first stage. Note that we can do the first stage part offline and only use the second-stage weights at inference time.
|
8662 |
|
|
|
8663 |
|
8664 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8665 |
|
8666 |
Our model can be loaded using `transformers` out-of-the-box with "trust remote code" enabled. We use the default BERT uncased tokenizer:
|
8667 |
```python
|
@@ -8680,7 +8858,7 @@ query_prefix = "search_query: "
|
|
8680 |
document_prefix = "search_document: "
|
8681 |
```
|
8682 |
|
8683 |
-
|
8684 |
|
8685 |
```python
|
8686 |
minicorpus_size = model.config.transductive_corpus_size
|
@@ -8692,7 +8870,7 @@ minicorpus_docs = tokenizer(
|
|
8692 |
padding=True,
|
8693 |
max_length=512,
|
8694 |
return_tensors="pt"
|
8695 |
-
)
|
8696 |
import torch
|
8697 |
from tqdm.autonotebook import tqdm
|
8698 |
|
@@ -8709,7 +8887,7 @@ for i in tqdm(range(0, len(minicorpus_docs["input_ids"]), batch_size)):
|
|
8709 |
dataset_embeddings = torch.cat(dataset_embeddings)
|
8710 |
```
|
8711 |
|
8712 |
-
|
8713 |
|
8714 |
Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prefix for documents:
|
8715 |
```python
|
@@ -8719,7 +8897,7 @@ docs = tokenizer(
|
|
8719 |
padding=True,
|
8720 |
max_length=512,
|
8721 |
return_tensors="pt"
|
8722 |
-
).to(device)
|
8723 |
|
8724 |
with torch.no_grad():
|
8725 |
doc_embeddings = model.second_stage_model(
|
@@ -8739,7 +8917,7 @@ queries = tokenizer(
|
|
8739 |
padding=True,
|
8740 |
max_length=512,
|
8741 |
return_tensors="pt"
|
8742 |
-
).to(device)
|
8743 |
|
8744 |
with torch.no_grad():
|
8745 |
query_embeddings = model.second_stage_model(
|
@@ -8752,6 +8930,8 @@ query_embeddings /= query_embeddings.norm(p=2, dim=1, keepdim=True)
|
|
8752 |
|
8753 |
these embeddings can be compared using dot product, since they're normalized.
|
8754 |
|
|
|
|
|
8755 |
### What if I don't know what my corpus will be ahead of time?
|
8756 |
|
8757 |
If you can't obtain corpus information ahead of time, you still have to pass *something* as the dataset embeddings; our model will work fine in this case, but not quite as well; without corpus information, our model performance drops from 65.0 to 63.8 on MTEB. We provide [some random strings](https://huggingface.co/jxm/cde-small-v1/resolve/main/random_strings.txt) that worked well for us that can be used as a substitute for corpus sampling.
|
|
|
1 |
---
|
2 |
tags:
|
3 |
- mteb
|
4 |
+
- transformers
|
5 |
+
- sentence-transformers
|
6 |
model-index:
|
7 |
- name: cde-small-v1
|
8 |
results:
|
|
|
8662 |
|
8663 |
Our embedding model needs to be used in *two stages*. The first stage is to gather some dataset information by embedding a subset of the corpus using our "first-stage" model. The second stage is to actually embed queries and documents, conditioning on the corpus information from the first stage. Note that we can do the first stage part offline and only use the second-stage weights at inference time.
|
8664 |
|
8665 |
+
## With Sentence Transformers
|
8666 |
|
8667 |
+
<details open="">
|
8668 |
+
<summary>Click to learn how to use cde-small-v1 with Sentence Transformers</summary>
|
8669 |
+
|
8670 |
+
### Loading the model
|
8671 |
+
|
8672 |
+
Our model can be loaded using `sentence-transformers` out-of-the-box with "trust remote code" enabled:
|
8673 |
+
```python
|
8674 |
+
from sentence_transformers import SentenceTransformer
|
8675 |
+
|
8676 |
+
model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)
|
8677 |
+
```
|
8678 |
+
|
8679 |
+
#### Note on prefixes
|
8680 |
+
|
8681 |
+
*Nota bene*: Like all state-of-the-art embedding models, our model was trained with task-specific prefixes. To do retrieval, you can use `prompt_name="query"` and `prompt_name="document"` in the `encode` method of the model when embedding queries and documents, respectively.
|
8682 |
+
|
8683 |
+
### First stage
|
8684 |
+
|
8685 |
+
```python
|
8686 |
+
minicorpus_size = model[0].config.transductive_corpus_size
|
8687 |
+
minicorpus_docs = [ ... ] # Put some strings here that are representative of your corpus, for example by calling random.sample(corpus, k=minicorpus_size)
|
8688 |
+
assert len(minicorpus_docs) == minicorpus_size # You must use exactly this many documents in the minicorpus. You can oversample if your corpus is smaller.
|
8689 |
+
|
8690 |
+
dataset_embeddings = model.encode(
|
8691 |
+
minicorpus_docs,
|
8692 |
+
prompt_name="document",
|
8693 |
+
convert_to_tensor=True
|
8694 |
+
)
|
8695 |
+
```
|
8696 |
+
|
8697 |
+
### Running the second stage
|
8698 |
+
|
8699 |
+
Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prompt for documents:
|
8700 |
+
|
8701 |
+
```python
|
8702 |
+
docs = [...]
|
8703 |
+
queries = [...]
|
8704 |
+
|
8705 |
+
doc_embeddings = model.encode(
|
8706 |
+
docs,
|
8707 |
+
prompt_name="document",
|
8708 |
+
dataset_embeddings=dataset_embeddings,
|
8709 |
+
convert_to_tensor=True,
|
8710 |
+
)
|
8711 |
+
query_embeddings = model.encode(
|
8712 |
+
queries,
|
8713 |
+
prompt_name="query",
|
8714 |
+
dataset_embeddings=dataset_embeddings,
|
8715 |
+
convert_to_tensor=True,
|
8716 |
+
)
|
8717 |
+
```
|
8718 |
+
|
8719 |
+
these embeddings can be compared using cosine similarity via `model.similarity`:
|
8720 |
+
```python
|
8721 |
+
similarities = model.similarity(query_embeddings, doc_embeddings)
|
8722 |
+
topk_values, topk_indices = similarities.topk(5)
|
8723 |
+
```
|
8724 |
+
|
8725 |
+
<details>
|
8726 |
+
<summary>Click here for a full copy-paste ready example</summary>
|
8727 |
+
|
8728 |
+
```python
|
8729 |
+
from sentence_transformers import SentenceTransformer
|
8730 |
+
from datasets import load_dataset
|
8731 |
+
|
8732 |
+
# 1. Load the Sentence Transformer model
|
8733 |
+
model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)
|
8734 |
+
context_docs_size = model[0].config.transductive_corpus_size # 512
|
8735 |
+
|
8736 |
+
# 2. Load the dataset: context dataset, docs, and queries
|
8737 |
+
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
|
8738 |
+
dataset.shuffle(seed=42)
|
8739 |
+
# 10 queries, 512 context docs, 500 docs
|
8740 |
+
queries = dataset["query"][:10]
|
8741 |
+
docs = dataset["answer"][:2000]
|
8742 |
+
context_docs = dataset["answer"][-context_docs_size:] # Last 512 docs
|
8743 |
+
|
8744 |
+
# 3. First stage: embed the context docs
|
8745 |
+
dataset_embeddings = model.encode(
|
8746 |
+
context_docs,
|
8747 |
+
prompt_name="document",
|
8748 |
+
convert_to_tensor=True,
|
8749 |
+
)
|
8750 |
+
|
8751 |
+
# 4. Second stage: embed the docs and queries
|
8752 |
+
doc_embeddings = model.encode(
|
8753 |
+
docs,
|
8754 |
+
prompt_name="document",
|
8755 |
+
dataset_embeddings=dataset_embeddings,
|
8756 |
+
convert_to_tensor=True,
|
8757 |
+
)
|
8758 |
+
query_embeddings = model.encode(
|
8759 |
+
queries,
|
8760 |
+
prompt_name="query",
|
8761 |
+
dataset_embeddings=dataset_embeddings,
|
8762 |
+
convert_to_tensor=True,
|
8763 |
+
)
|
8764 |
+
|
8765 |
+
# 5. Compute the similarity between the queries and docs
|
8766 |
+
similarities = model.similarity(query_embeddings, doc_embeddings)
|
8767 |
+
topk_values, topk_indices = similarities.topk(5)
|
8768 |
+
print(topk_values)
|
8769 |
+
print(topk_indices)
|
8770 |
+
|
8771 |
+
"""
|
8772 |
+
tensor([[0.5495, 0.5426, 0.5423, 0.5292, 0.5286],
|
8773 |
+
[0.6357, 0.6334, 0.6177, 0.5862, 0.5794],
|
8774 |
+
[0.7648, 0.5452, 0.5000, 0.4959, 0.4881],
|
8775 |
+
[0.6802, 0.5225, 0.5178, 0.5160, 0.5075],
|
8776 |
+
[0.6947, 0.5843, 0.5619, 0.5344, 0.5298],
|
8777 |
+
[0.7742, 0.7742, 0.7742, 0.7231, 0.6224],
|
8778 |
+
[0.8853, 0.6667, 0.5829, 0.5795, 0.5769],
|
8779 |
+
[0.6911, 0.6127, 0.6003, 0.5986, 0.5936],
|
8780 |
+
[0.6796, 0.6053, 0.6000, 0.5911, 0.5884],
|
8781 |
+
[0.7624, 0.5589, 0.5428, 0.5278, 0.5275]], device='cuda:0')
|
8782 |
+
tensor([[ 0, 296, 234, 1651, 1184],
|
8783 |
+
[1542, 466, 438, 1207, 1911],
|
8784 |
+
[ 2, 1562, 632, 1852, 382],
|
8785 |
+
[ 3, 694, 932, 1765, 662],
|
8786 |
+
[ 4, 35, 747, 26, 432],
|
8787 |
+
[ 534, 175, 5, 1495, 575],
|
8788 |
+
[ 6, 1802, 1875, 747, 21],
|
8789 |
+
[ 7, 1913, 1936, 640, 6],
|
8790 |
+
[ 8, 747, 167, 1318, 1743],
|
8791 |
+
[ 9, 1583, 1145, 219, 357]], device='cuda:0')
|
8792 |
+
"""
|
8793 |
+
# As you can see, almost every query_i has document_i as the most similar document.
|
8794 |
+
|
8795 |
+
# 6. Print the top-k results
|
8796 |
+
for query_idx, top_doc_idx in enumerate(topk_indices[:, 0]):
|
8797 |
+
print(f"Query {query_idx}: {queries[query_idx]}")
|
8798 |
+
print(f"Top Document: {docs[top_doc_idx]}")
|
8799 |
+
print()
|
8800 |
+
"""
|
8801 |
+
Query 0: when did richmond last play in a preliminary final
|
8802 |
+
Top Document: Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next.
|
8803 |
+
|
8804 |
+
Query 1: who sang what in the world's come over you
|
8805 |
+
Top Document: Life's What You Make It (Talk Talk song) "Life's What You Make It" is a song by the English band Talk Talk. It was released as a single in 1986, the first from the band's album The Colour of Spring. The single was a hit in the UK, peaking at No. 16, and charted in numerous other countries, often reaching the Top 20.
|
8806 |
+
|
8807 |
+
Query 2: who produces the most wool in the world
|
8808 |
+
Top Document: Wool Global wool production is about 2 million tonnes per year, of which 60% goes into apparel. Wool comprises ca 3% of the global textile market, but its value is higher owing to dying and other modifications of the material.[1] Australia is a leading producer of wool which is mostly from Merino sheep but has been eclipsed by China in terms of total weight.[30] New Zealand (2016) is the third-largest producer of wool, and the largest producer of crossbred wool. Breeds such as Lincoln, Romney, Drysdale, and Elliotdale produce coarser fibers, and wool from these sheep is usually used for making carpets.
|
8809 |
+
|
8810 |
+
Query 3: where does alaska the last frontier take place
|
8811 |
+
Top Document: Alaska: The Last Frontier Alaska: The Last Frontier is an American reality cable television series on the Discovery Channel, currently in its 7th season of broadcast. The show documents the extended Kilcher family, descendants of Swiss immigrants and Alaskan pioneers, Yule and Ruth Kilcher, at their homestead 11 miles outside of Homer.[1] By living without plumbing or modern heating, the clan chooses to subsist by farming, hunting and preparing for the long winters.[2] The Kilcher family are relatives of the singer Jewel,[1][3] who has appeared on the show.[4]
|
8812 |
+
|
8813 |
+
Query 4: a day to remember all i want cameos
|
8814 |
+
Top Document: All I Want (A Day to Remember song) The music video for the song, which was filmed in October 2010,[4] was released on January 6, 2011.[5] It features cameos of numerous popular bands and musicians. The cameos are: Tom Denney (A Day to Remember's former guitarist), Pete Wentz, Winston McCall of Parkway Drive, The Devil Wears Prada, Bring Me the Horizon, Sam Carter of Architects, Tim Lambesis of As I Lay Dying, Silverstein, Andrew WK, August Burns Red, Seventh Star, Matt Heafy of Trivium, Vic Fuentes of Pierce the Veil, Mike Herrera of MxPx, and Set Your Goals.[5] Rock Sound called the video "quite excellent".[5]
|
8815 |
+
|
8816 |
+
Query 5: what does the red stripes mean on the american flag
|
8817 |
+
Top Document: Flag of the United States The flag of the United States of America, often referred to as the American flag, is the national flag of the United States. It consists of thirteen equal horizontal stripes of red (top and bottom) alternating with white, with a blue rectangle in the canton (referred to specifically as the "union") bearing fifty small, white, five-pointed stars arranged in nine offset horizontal rows, where rows of six stars (top and bottom) alternate with rows of five stars. The 50 stars on the flag represent the 50 states of the United States of America, and the 13 stripes represent the thirteen British colonies that declared independence from the Kingdom of Great Britain, and became the first states in the U.S.[1] Nicknames for the flag include The Stars and Stripes,[2] Old Glory,[3] and The Star-Spangled Banner.
|
8818 |
+
|
8819 |
+
Query 6: where did they film diary of a wimpy kid
|
8820 |
+
Top Document: Diary of a Wimpy Kid (film) Filming of Diary of a Wimpy Kid was in Vancouver and wrapped up on October 16, 2009.
|
8821 |
+
|
8822 |
+
Query 7: where was beasts of the southern wild filmed
|
8823 |
+
Top Document: Beasts of the Southern Wild The film's fictional setting, "Isle de Charles Doucet", known to its residents as the Bathtub, was inspired by several isolated and independent fishing communities threatened by erosion, hurricanes and rising sea levels in Louisiana's Terrebonne Parish, most notably the rapidly eroding Isle de Jean Charles. It was filmed in Terrebonne Parish town Montegut.[5]
|
8824 |
+
|
8825 |
+
Query 8: what part of the country are you likely to find the majority of the mollisols
|
8826 |
+
Top Document: Mollisol Mollisols occur in savannahs and mountain valleys (such as Central Asia, or the North American Great Plains). These environments have historically been strongly influenced by fire and abundant pedoturbation from organisms such as ants and earthworms. It was estimated that in 2003, only 14 to 26 percent of grassland ecosystems still remained in a relatively natural state (that is, they were not used for agriculture due to the fertility of the A horizon). Globally, they represent ~7% of ice-free land area. As the world's most agriculturally productive soil order, the Mollisols represent one of the more economically important soil orders.
|
8827 |
+
|
8828 |
+
Query 9: when did fosters home for imaginary friends start
|
8829 |
+
Top Document: Foster's Home for Imaginary Friends McCracken conceived the series after adopting two dogs from an animal shelter and applying the concept to imaginary friends. The show first premiered on Cartoon Network on August 13, 2004, as a 90-minute television film. On August 20, it began its normal run of twenty-to-thirty-minute episodes on Fridays, at 7 pm. The series finished its run on May 3, 2009, with a total of six seasons and seventy-nine episodes. McCracken left Cartoon Network shortly after the series ended. Reruns have aired on Boomerang from August 11, 2012 to November 3, 2013 and again from June 1, 2014 to April 3, 2017.
|
8830 |
+
"""
|
8831 |
+
```
|
8832 |
+
|
8833 |
+
</details>
|
8834 |
+
|
8835 |
+
</details>
|
8836 |
+
|
8837 |
+
## With Transformers
|
8838 |
+
|
8839 |
+
<details>
|
8840 |
+
<summary>Click to learn how to use cde-small-v1 with Transformers</summary>
|
8841 |
+
|
8842 |
+
### Loading the model
|
8843 |
|
8844 |
Our model can be loaded using `transformers` out-of-the-box with "trust remote code" enabled. We use the default BERT uncased tokenizer:
|
8845 |
```python
|
|
|
8858 |
document_prefix = "search_document: "
|
8859 |
```
|
8860 |
|
8861 |
+
### First stage
|
8862 |
|
8863 |
```python
|
8864 |
minicorpus_size = model.config.transductive_corpus_size
|
|
|
8870 |
padding=True,
|
8871 |
max_length=512,
|
8872 |
return_tensors="pt"
|
8873 |
+
).to(model.device)
|
8874 |
import torch
|
8875 |
from tqdm.autonotebook import tqdm
|
8876 |
|
|
|
8887 |
dataset_embeddings = torch.cat(dataset_embeddings)
|
8888 |
```
|
8889 |
|
8890 |
+
### Running the second stage
|
8891 |
|
8892 |
Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prefix for documents:
|
8893 |
```python
|
|
|
8897 |
padding=True,
|
8898 |
max_length=512,
|
8899 |
return_tensors="pt"
|
8900 |
+
).to(model.device)
|
8901 |
|
8902 |
with torch.no_grad():
|
8903 |
doc_embeddings = model.second_stage_model(
|
|
|
8917 |
padding=True,
|
8918 |
max_length=512,
|
8919 |
return_tensors="pt"
|
8920 |
+
).to(model.device)
|
8921 |
|
8922 |
with torch.no_grad():
|
8923 |
query_embeddings = model.second_stage_model(
|
|
|
8930 |
|
8931 |
these embeddings can be compared using dot product, since they're normalized.
|
8932 |
|
8933 |
+
</details>
|
8934 |
+
|
8935 |
### What if I don't know what my corpus will be ahead of time?
|
8936 |
|
8937 |
If you can't obtain corpus information ahead of time, you still have to pass *something* as the dataset embeddings; our model will work fine in this case, but not quite as well; without corpus information, our model performance drops from 65.0 to 63.8 on MTEB. We provide [some random strings](https://huggingface.co/jxm/cde-small-v1/resolve/main/random_strings.txt) that worked well for us that can be used as a substitute for corpus sampling.
|
config_sentence_transformers.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "3.1.0",
|
4 |
+
"transformers": "4.43.4",
|
5 |
+
"pytorch": "2.5.0.dev20240807+cu121"
|
6 |
+
},
|
7 |
+
"prompts": {
|
8 |
+
"query": "search_query: ",
|
9 |
+
"document": "search_document: "
|
10 |
+
},
|
11 |
+
"default_prompt_name": null,
|
12 |
+
"similarity_fn_name": "cosine"
|
13 |
+
}
|
modules.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "sentence_transformers_impl.Transformer",
|
7 |
+
"kwargs": ["dataset_embeddings"]
|
8 |
+
}
|
9 |
+
]
|
sentence_bert_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
sentence_transformers_impl.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from typing import Any, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
class Transformer(nn.Module):
|
16 |
+
"""Hugging Face AutoModel to generate token embeddings.
|
17 |
+
Loads the correct class, e.g. BERT / RoBERTa etc.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
model_name_or_path: Hugging Face models name
|
21 |
+
(https://huggingface.co/models)
|
22 |
+
max_seq_length: Truncate any inputs longer than max_seq_length
|
23 |
+
model_args: Keyword arguments passed to the Hugging Face
|
24 |
+
Transformers model
|
25 |
+
tokenizer_args: Keyword arguments passed to the Hugging Face
|
26 |
+
Transformers tokenizer
|
27 |
+
config_args: Keyword arguments passed to the Hugging Face
|
28 |
+
Transformers config
|
29 |
+
cache_dir: Cache dir for Hugging Face Transformers to store/load
|
30 |
+
models
|
31 |
+
do_lower_case: If true, lowercases the input (independent if the
|
32 |
+
model is cased or not)
|
33 |
+
tokenizer_name_or_path: Name or path of the tokenizer. When
|
34 |
+
None, then model_name_or_path is used
|
35 |
+
backend: Backend used for model inference. Can be `torch`, `onnx`,
|
36 |
+
or `openvino`. Default is `torch`.
|
37 |
+
"""
|
38 |
+
|
39 |
+
save_in_root: bool = True
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
model_name_or_path: str,
|
44 |
+
model_args: dict[str, Any] | None = None,
|
45 |
+
tokenizer_args: dict[str, Any] | None = None,
|
46 |
+
config_args: dict[str, Any] | None = None,
|
47 |
+
cache_dir: str | None = None,
|
48 |
+
**kwargs,
|
49 |
+
) -> None:
|
50 |
+
super().__init__()
|
51 |
+
if model_args is None:
|
52 |
+
model_args = {}
|
53 |
+
if tokenizer_args is None:
|
54 |
+
tokenizer_args = {}
|
55 |
+
if config_args is None:
|
56 |
+
config_args = {}
|
57 |
+
|
58 |
+
if not model_args.get("trust_remote_code", False):
|
59 |
+
raise ValueError(
|
60 |
+
"You need to set `trust_remote_code=True` to load this model."
|
61 |
+
)
|
62 |
+
|
63 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
64 |
+
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
|
65 |
+
|
66 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
67 |
+
"bert-base-uncased",
|
68 |
+
cache_dir=cache_dir,
|
69 |
+
**tokenizer_args,
|
70 |
+
)
|
71 |
+
|
72 |
+
def __repr__(self) -> str:
|
73 |
+
return f"Transformer({self.get_config_dict()}) with Transformer model: {self.auto_model.__class__.__name__} "
|
74 |
+
|
75 |
+
def forward(self, features: dict[str, torch.Tensor], dataset_embeddings: Optional[torch.Tensor] = None, **kwargs) -> dict[str, torch.Tensor]:
|
76 |
+
"""Returns token_embeddings, cls_token"""
|
77 |
+
# If we don't have embeddings, then run the 1st stage model.
|
78 |
+
# If we do, then run the 2nd stage model.
|
79 |
+
if dataset_embeddings is None:
|
80 |
+
sentence_embedding = self.auto_model.first_stage_model(
|
81 |
+
input_ids=features["input_ids"],
|
82 |
+
attention_mask=features["attention_mask"],
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
sentence_embedding = self.auto_model.second_stage_model(
|
86 |
+
input_ids=features["input_ids"],
|
87 |
+
attention_mask=features["attention_mask"],
|
88 |
+
dataset_embeddings=dataset_embeddings,
|
89 |
+
)
|
90 |
+
|
91 |
+
features["sentence_embedding"] = sentence_embedding
|
92 |
+
return features
|
93 |
+
|
94 |
+
def get_word_embedding_dimension(self) -> int:
|
95 |
+
return self.auto_model.config.hidden_size
|
96 |
+
|
97 |
+
def tokenize(
|
98 |
+
self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
|
99 |
+
) -> dict[str, torch.Tensor]:
|
100 |
+
"""Tokenizes a text and maps tokens to token-ids"""
|
101 |
+
output = {}
|
102 |
+
if isinstance(texts[0], str):
|
103 |
+
to_tokenize = [texts]
|
104 |
+
elif isinstance(texts[0], dict):
|
105 |
+
to_tokenize = []
|
106 |
+
output["text_keys"] = []
|
107 |
+
for lookup in texts:
|
108 |
+
text_key, text = next(iter(lookup.items()))
|
109 |
+
to_tokenize.append(text)
|
110 |
+
output["text_keys"].append(text_key)
|
111 |
+
to_tokenize = [to_tokenize]
|
112 |
+
else:
|
113 |
+
batch1, batch2 = [], []
|
114 |
+
for text_tuple in texts:
|
115 |
+
batch1.append(text_tuple[0])
|
116 |
+
batch2.append(text_tuple[1])
|
117 |
+
to_tokenize = [batch1, batch2]
|
118 |
+
|
119 |
+
max_seq_length = self.config.max_seq_length
|
120 |
+
output.update(
|
121 |
+
self.tokenizer(
|
122 |
+
*to_tokenize,
|
123 |
+
padding=padding,
|
124 |
+
truncation="longest_first",
|
125 |
+
return_tensors="pt",
|
126 |
+
max_length=max_seq_length,
|
127 |
+
)
|
128 |
+
)
|
129 |
+
return output
|
130 |
+
|
131 |
+
def get_config_dict(self) -> dict[str, Any]:
|
132 |
+
return {}
|
133 |
+
|
134 |
+
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
135 |
+
self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
|
136 |
+
self.tokenizer.save_pretrained(output_path)
|
137 |
+
|
138 |
+
with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
|
139 |
+
json.dump(self.get_config_dict(), fOut, indent=2)
|
140 |
+
|
141 |
+
@classmethod
|
142 |
+
def load(cls, input_path: str) -> Transformer:
|
143 |
+
sbert_config_path = os.path.join(input_path, "sentence_bert_config.json")
|
144 |
+
if not os.path.exists(sbert_config_path):
|
145 |
+
return cls(model_name_or_path=input_path)
|
146 |
+
|
147 |
+
with open(sbert_config_path) as fIn:
|
148 |
+
config = json.load(fIn)
|
149 |
+
# Don't allow configs to set trust_remote_code
|
150 |
+
if "model_args" in config and "trust_remote_code" in config["model_args"]:
|
151 |
+
config["model_args"].pop("trust_remote_code")
|
152 |
+
if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
|
153 |
+
config["tokenizer_args"].pop("trust_remote_code")
|
154 |
+
if "config_args" in config and "trust_remote_code" in config["config_args"]:
|
155 |
+
config["config_args"].pop("trust_remote_code")
|
156 |
+
return cls(model_name_or_path=input_path, **config)
|