edwardjross
commited on
Commit
•
7e1ab62
1
Parent(s):
120c5ad
Enable saving in SentenceTransformers by adding get_config_dict
Browse filesAdd `get_config_dict` method in custom_st `Transformer` for usage in SentenceTransformers that is used in the `save` method.
This fixes a bug that calling `.save` after loading with sentence-transformers 3.1 raised an `AttributeError` so now the following code successfully saves the model to local directory `jina-embeddings-v3`:
```
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
model.save("jina-embeddings-v3")
```
- custom_st.py +3 -0
custom_st.py
CHANGED
@@ -160,6 +160,9 @@ class Transformer(nn.Module):
|
|
160 |
)
|
161 |
return output
|
162 |
|
|
|
|
|
|
|
163 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
164 |
self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
|
165 |
self.tokenizer.save_pretrained(output_path)
|
|
|
160 |
)
|
161 |
return output
|
162 |
|
163 |
+
def get_config_dict(self) -> dict[str, Any]:
|
164 |
+
return {key: self.__dict__[key] for key in self.config_keys}
|
165 |
+
|
166 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
167 |
self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
|
168 |
self.tokenizer.save_pretrained(output_path)
|