File size: 1,776 Bytes
17f3a9b
 
 
 
d794995
d4b1508
 
 
62f31c8
 
 
17f3a9b
2ae9fb3
17f3a9b
 
 
2ae9fb3
17f3a9b
 
d4b1508
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
2ae9fb3
17f3a9b
 
 
 
8d84024
d4b1508
62f31c8
 
 
 
d4b1508
62f31c8
 
 
d4b1508
62f31c8
 
 
 
 
 
 
 
d4b1508
efad2c7
 
 
a65e7e5
62f31c8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.wsgi import WSGIMiddleware
from transformers import pipeline

from RequestModel import PredictRequest
from us_stock import fetch_symbols

# The FastAPI application object; every route in this module is registered on it.
app: FastAPI = FastAPI()


# Request body shared by the text routes: a JSON object with a single string
# field, e.g. {"text": "hello"}.
class TextRequest(BaseModel):
    text: str

# POST /api/aaa: respond with the submitted text with the suffix "aaa" appended.
@app.post("/api/aaa")
async def api_aaa_post(request: TextRequest):
    text_with_suffix = "".join([request.text, 'aaa'])
    return {"result": text_with_suffix}

# POST /aaa: same transformation as POST /api/aaa, without the "/api" prefix;
# responds with the submitted text with "aaa" appended.
@app.post("/aaa")
async def aaa(request: TextRequest):
    suffixed = f"{request.text}aaa"
    return dict(result=suffixed)


# GET /aaa: same transformation as POST /aaa (append "aaa" to the text).
#
# Previously ``text`` could only arrive in a JSON request body.  A body on a
# GET request has undefined semantics in HTTP; FastAPI's Swagger UI does not
# offer it, proxies may drop it, and browser fetch() refuses to send it.  The
# text may therefore also be supplied as a query parameter:
#     GET /aaa?text=hello
# A JSON body is still accepted, and takes precedence, for existing callers.
@app.get("/aaa")
async def api_aaa_get(
    request: Optional[TextRequest] = None,
    text: Optional[str] = None,
):
    if request is not None:
        source_text = request.text
    elif text is not None:
        source_text = text
    else:
        # Neither input was provided; use the same status FastAPI itself
        # returns for a missing required input.
        raise HTTPException(
            status_code=422,
            detail="Provide 'text' as a query parameter or in a JSON body.",
        )
    result = source_text + 'aaa'
    return {"result": result}

# POST /api/bbb: respond with the submitted text with the suffix "bbb" appended.
@app.post("/api/bbb")
async def api_bbb(request: TextRequest):
    payload = {}
    payload["result"] = request.text + 'bbb'
    return payload


@app.on_event("startup")
async def initialize_symbols():
    # Startup hook: runs ``us_stock.fetch_symbols`` once when the application
    # starts, so whatever state it initializes exists before requests arrive.
    startup_task = fetch_symbols()
    await startup_task

# POST /api/predict: pass the request's ``text`` and ``stock_codes`` to
# ``blkeras.predict`` and return whatever it returns.
#
# Any exception raised while doing so is reported to the client as
# {"error": "<message>"} (with FastAPI's default status code, as before).
# Its traceback is now logged server-side instead of being discarded.
#
# NOTE(review): ``blkeras.predict`` is called synchronously inside an
# ``async def`` endpoint.  If it is CPU-heavy it blocks the event loop for the
# whole call; consider ``run_in_threadpool`` once its thread-safety is confirmed.
@app.post("/api/predict")
async def predict(request: PredictRequest):
    import logging

    # Imported inside the handler, as before (presumably so the model module is
    # only loaded on first use -- TODO confirm).  Aliased so it no longer
    # shadows this endpoint's own name.
    from blkeras import predict as run_prediction

    try:
        input_text = request.text  # FastAPI has already parsed the body into PredictRequest
        affected_stock_codes = request.stock_codes
        print("Input text:", input_text)
        print("Affected stock codes:", affected_stock_codes)
        return run_prediction(input_text, affected_stock_codes)
    except Exception as e:
        # Keep the best-effort JSON error response that clients may rely on,
        # but record the full traceback so failures are diagnosable.
        logging.getLogger(__name__).exception("Prediction failed")
        return {"error": str(e)}

# GET /: landing endpoint that points clients at the text-processing routes.
@app.get("/")
async def root():
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return dict(message=welcome)


if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn, listening on every
    # interface at port 7860.
    import uvicorn

    bind_host = "0.0.0.0"
    bind_port = 7860
    uvicorn.run(app, host=bind_host, port=bind_port)