root commited on
Commit
05dfd9d
1 Parent(s): c16e0ec

api key added

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from utils import retrive_context, generate_response
@@ -9,7 +10,8 @@ app = FastAPI()
9
 
10
  class QueryRequest(BaseModel):
11
  # Asked query should be in string format
12
- query: str
 
13
 
14
 
15
  class QueryResponse(BaseModel):
@@ -20,17 +22,21 @@ class QueryResponse(BaseModel):
20
 
21
  @app.post("/infer", response_model=QueryResponse)
22
  def infer(query_request: QueryRequest):
 
23
  query = query_request.query
24
- context = retrive_context(query)
25
- if context == 500:
26
- raise HTTPException(status_code=500, detail="Error retrieving context")
27
 
28
- response = generate_response(query, context)
29
- if response == 500:
30
- raise HTTPException(status_code=500, detail="Error generating response")
 
31
 
32
- return QueryResponse(response=response)
33
-
 
 
 
 
 
34
 
35
  # Root endpoint for testing
36
  @app.get("/")
 
1
+ import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from utils import retrive_context, generate_response
 
10
 
11
  class QueryRequest(BaseModel):
12
  # Asked query should be in string format
13
+ query: str
14
+ api_key: str
15
 
16
 
17
  class QueryResponse(BaseModel):
 
22
 
23
  @app.post("/infer", response_model=QueryResponse)
24
  def infer(query_request: QueryRequest):
25
+ key = query_request.api_key
26
  query = query_request.query
 
 
 
27
 
28
+ if (key == os.getenv("API_KEY")):
29
+ context = retrive_context(query)
30
+ if context == 500:
31
+ raise HTTPException(status_code=500, detail="Error retrieving context")
32
 
33
+ response = generate_response(query, context)
34
+ if response == 500:
35
+ raise HTTPException(status_code=500, detail="Error generating response")
36
+
37
+ return QueryResponse(response=response)
38
+ else:
39
+ raise HTTPException(status_code=401, detail="Invalid api key")
40
 
41
  # Root endpoint for testing
42
  @app.get("/")