etownsupport committed on
Commit
4bd67d4
1 Parent(s): 0951ee0

Update etown_mxbai/router.py

Browse files
Files changed (1) hide show
  1. etown_mxbai/router.py +52 -69
etown_mxbai/router.py CHANGED
@@ -1,70 +1,53 @@
1
- from pydantic import BaseModel
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import JSONResponse
4
- # from sentence_transformers import SentenceTransformer
5
- # from sentence_transformers.util import cos_sim
6
- from typing import List
7
- import os, platform, time
8
- from transformers import AutoTokenizer
9
- import fastembed
10
- from fastembed import SparseEmbedding, SparseTextEmbedding, TextEmbedding
11
- import numpy as np
12
-
13
-
14
- sparse_model_name = "prithvida/Splade_PP_en_v1"
15
- sparse_model = SparseTextEmbedding(model_name=sparse_model_name, batch_size=32)
16
-
17
- class Validation(BaseModel):
18
- prompt: List[str]
19
-
20
- from etown_mxbai import app
21
-
22
- app.add_middleware(
23
- CORSMiddleware,
24
- allow_origins=["*"],
25
- allow_credentials=True,
26
- allow_methods=["*"],
27
- allow_headers=["*"],
28
- )
29
-
30
- @app.post("/api/generate", summary="Generate embeddings", tags=["Generate"])
31
- def inference(item: Validation):
32
- try:
33
- start_time = time.time()
34
- embeddings = list(sparse_model.embed(item.prompt, batch_size=5)) # Assuming 'model' is defined elsewhere
35
-
36
- serializable_embeddings = []
37
- for embedding in embeddings:
38
- # Assuming embedding object has attributes values and indices
39
- if isinstance(embedding, SparseEmbedding):
40
- values = embedding.values
41
- indices = embedding.indices
42
- serializable_embeddings.append({
43
- "values": values.tolist() if isinstance(values, np.ndarray) else values,
44
- "indices": indices.tolist() if isinstance(indices, np.ndarray) else indices
45
- })
46
- else:
47
- # Fallback for other types of embeddings
48
- serializable_embeddings.append({
49
- "values": embedding.tolist() if isinstance(embedding, np.ndarray) else str(embedding),
50
- "indices": list(range(len(embedding))) if isinstance(embedding, (np.ndarray, list)) else []
51
- })
52
-
53
- end_time = time.time()
54
- time_taken = end_time - start_time # Calculate the time taken
55
-
56
- return JSONResponse(content={
57
- "embeddings": serializable_embeddings,
58
- "time_taken": f"{time_taken:.2f} seconds",
59
- "Number_of_sentence_processed": len(item.prompt), # Assuming you want to count words, not characters
60
- "Model_response_space" : "prithvida/Splade_PP_en_v1",
61
- "status_code" : 200
62
- })
63
- except Exception as e:
64
- print(f"An error occurred: {str(e)}") # Simple print statement for logging; consider using proper logging
65
- return JSONResponse(content={
66
- "error": "An error occurred during processing.",
67
- "details": str(e),
68
- "Model_response_space" : "prithvida/Splade_PP_en_v1",
69
- "status_code" : 500
70
  })
 
1
+ from pydantic import BaseModel
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ # from sentence_transformers import SentenceTransformer
5
+ # from sentence_transformers.util import cos_sim
6
+ from typing import List
7
+ import os, platform, time
8
+ from transformers import AutoTokenizer
9
+ import fastembed
10
+ from fastembed import SparseEmbedding, SparseTextEmbedding, TextEmbedding
11
+ import numpy as np
12
+
13
+
14
+ sparse_model_name = "prithvida/Splade_PP_en_v1"
15
+ sparse_model = SparseTextEmbedding(model_name=sparse_model_name, batch_size=32)
16
+
17
+ class Validation(BaseModel):
18
+ prompt: List[str]
19
+
20
+ from etown_mxbai import app
21
+
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"],
25
+ allow_credentials=True,
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+
30
@app.post("/api/generate", summary="Generate embeddings", tags=["Generate"])
def inference(item: Validation):
    """Embed every prompt with the sparse model and return the result as JSON.

    Args:
        item: Request body whose ``prompt`` field is the list of strings to
            embed.

    Returns:
        A ``JSONResponse``. On success the payload holds one
        ``{"values": [...], "indices": [...]}`` entry per prompt, the time
        spent embedding, the number of prompts, the model name and
        ``"status_code": 200``. On failure the payload holds an error
        message, the exception text and ``"status_code": 500``.
    """
    try:
        start_time = time.time()
        # `sparse_model` is the module-level SparseTextEmbedding instance;
        # embed() is wrapped in list() so the work happens inside the timer.
        embeddings = list(sparse_model.embed(item.prompt, batch_size=32))
        end_time = time.time()
        time_taken = end_time - start_time

        # JSONResponse encodes its content with the standard JSON encoder,
        # which cannot serialize embedding objects or NumPy arrays. Convert
        # each embedding to plain Python lists before building the response.
        serializable_embeddings = [
            {
                "values": np.asarray(embedding.values).tolist(),
                "indices": np.asarray(embedding.indices).tolist(),
            }
            for embedding in embeddings
        ]

        return JSONResponse(content={
            "embeddings": serializable_embeddings,
            "time_taken": f"{time_taken:.2f} seconds",
            # Number of prompt strings received, not words or characters.
            "Number_of_sentence_processed": len(item.prompt),
            "Model_response_space": "prithvida/Splade_PP_en_v1",
            "status_code": 200,
        })
    except Exception as e:
        # Top-level boundary of the endpoint: report the failure to the
        # client instead of letting the exception propagate.
        # NOTE(review): the HTTP status is left at the JSONResponse default;
        # only the body's "status_code" field signals the error. Confirm
        # whether clients rely on that before changing it.
        print(f"An error occurred: {str(e)}")  # TODO: use the logging module
        return JSONResponse(content={
            "error": "An error occurred during processing.",
            "details": str(e),
            "Model_response_space": "prithvida/Splade_PP_en_v1",
            "status_code": 500,
        })