root commited on
Commit
71eb861
1 Parent(s): 05dfd9d

api key added

Browse files
Files changed (1) hide show
  1. app.py +22 -18
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
- from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
 
4
  from utils import retrive_context, generate_response
5
 
6
 
@@ -11,32 +12,35 @@ app = FastAPI()
11
  class QueryRequest(BaseModel):
12
  # Asked query should be in string format
13
  query: str
14
- api_key: str
15
-
16
 
17
  class QueryResponse(BaseModel):
18
  # Response should be in string format
19
  response: str
20
 
21
 
 
 
 
 
 
 
 
22
 
23
  @app.post("/infer", response_model=QueryResponse)
24
- def infer(query_request: QueryRequest):
25
- key = query_request.api_key
26
  query = query_request.query
27
-
28
- if (key == os.getenv("API_KEY")):
29
- context = retrive_context(query)
30
- if context == 500:
31
- raise HTTPException(status_code=500, detail="Error retrieving context")
32
-
33
- response = generate_response(query, context)
34
- if response == 500:
35
- raise HTTPException(status_code=500, detail="Error generating response")
36
-
37
- return QueryResponse(response=response)
38
- else:
39
- raise HTTPException(status_code=401, detail="Invalid api key")
40
 
41
  # Root endpoint for testing
42
  @app.get("/")
 
1
  import os
2
+ from fastapi import FastAPI, HTTPException, Depends
3
  from pydantic import BaseModel
4
+ from fastapi.security.api_key import APIKeyHeader
5
  from utils import retrive_context, generate_response
6
 
7
 
 
12
  class QueryRequest(BaseModel):
13
  # Asked query should be in string format
14
  query: str
15
+
 
16
 
17
  class QueryResponse(BaseModel):
18
  # Response should be in string format
19
  response: str
20
 
21
 
22
+ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
23
+
24
+ def get_api_key(api_key: str = Depends(api_key_header)):
25
+ if api_key == os.getenv("API_KEY"):
26
+ return api_key
27
+ else:
28
+ raise HTTPException(status_code=401, detail="Invalid API key")
29
 
30
  @app.post("/infer", response_model=QueryResponse)
31
+ def infer(query_request: QueryRequest, api_key: str = Depends(get_api_key)):
 
32
  query = query_request.query
33
+
34
+ context = retrive_context(query)
35
+ if context == 500:
36
+ raise HTTPException(status_code=500, detail="Error retrieving context")
37
+
38
+ response = generate_response(query, context)
39
+ if response == 500:
40
+ raise HTTPException(status_code=500, detail="Error generating response")
41
+
42
+ return QueryResponse(response=response)
43
+
 
 
44
 
45
  # Root endpoint for testing
46
  @app.get("/")