lhoestq HF staff commited on
Commit
79ad082
1 Parent(s): 68aadb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -6,6 +6,7 @@ import requests
6
  import uvicorn
7
  from fastapi import FastAPI
8
  from gliner import GLiNER
 
9
 
10
 
11
  model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
@@ -155,14 +156,14 @@ app = FastAPI()
155
 
156
  @app.head("/predict")
157
  def predict_head():
158
- return {}
159
 
160
  @app.get("/predict")
161
  def predict_get(text: str = "", labels: str = "", threshold: float = 0.3, nested_ner: bool = False):
162
  predict_response = requests.post('http://localhost:7860/call/predict', json={'data': [text, labels, threshold, nested_ner]}).json()
163
  if "event_id" not in predict_response:
164
  return predict_response
165
- return json.loads(requests.get(f'http://localhost:7860/call/predict/{predict_response["event_id"]}').text.split("data: ", 1)[-1])
166
 
167
  if __name__ == "__main__":
168
  app = gr.mount_gradio_app(app, demo, path="/")
 
6
  import uvicorn
7
  from fastapi import FastAPI
8
  from gliner import GLiNER
9
+ from starlette.responses import StreamingResponse, JSONResponse
10
 
11
 
12
  model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
 
156
 
157
  @app.head("/predict")
158
  def predict_head():
159
+ return StreamingResponse("", media_type="application/json")
160
 
161
  @app.get("/predict")
162
  def predict_get(text: str = "", labels: str = "", threshold: float = 0.3, nested_ner: bool = False):
163
  predict_response = requests.post('http://localhost:7860/call/predict', json={'data': [text, labels, threshold, nested_ner]}).json()
164
  if "event_id" not in predict_response:
165
  return predict_response
166
+ return JSONResponse(json.loads(requests.get(f'http://localhost:7860/call/predict/{predict_response["event_id"]}').text.split("data: ", 1)[-1]))
167
 
168
  if __name__ == "__main__":
169
  app = gr.mount_gradio_app(app, demo, path="/")