File size: 1,253 Bytes
517f5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
import re

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Web application object; the route decorators in this module register on it.
app = FastAPI()

# Build the text-generation pipeline once at import time so that request
# handlers share a single pipeline object instead of creating one per call.
# NOTE(review): pad_token_id=2 is hard-coded — confirm it matches the
# tokenizer of this model.
pipe = pipeline(
    "text-generation",
    model="defog/llama-3-sqlcoder-8b",
    pad_token_id=2,
)

# Request body accepted by the POST /generate endpoint.
class QueryRequest(BaseModel):
    # Free-form description of the desired query; the /generate handler
    # interpolates it verbatim into the prompt sent to the model.
    text: str

@app.get("/")
def home():
    # Root endpoint: answers with a fixed status message so a client can
    # confirm that the server process is up.
    status = {"message": "SQL Generation Server is running"}
    return status


_logger = logging.getLogger(__name__)

# Marker that ends the prompt; the model's answer is whatever follows it.
_SQL_MARKER = "SQL query:"

# Accept only text that begins with a read-only statement keyword. The
# trailing \b rejects prose such as "Selecting ..." or "Showing ...", which a
# bare prefix comparison would wrongly accept as SQL.
_READ_ONLY_SQL = re.compile(r"(?:select|show|describe)\b", re.IGNORECASE)


def _build_prompt(text: str) -> str:
    """Wrap the user's request in the instruction prompt sent to the model."""
    return (
        "Generate a valid SQL query for the following request. "
        "Only return the SQL query, nothing else:"
        f"\n\n{text}\n\n{_SQL_MARKER}"
    )


def _extract_sql(generated_text: str) -> str:
    """Return the stripped text that follows the last prompt marker."""
    return generated_text.rpartition(_SQL_MARKER)[2].strip()


@app.post("/generate")
def generate(request: QueryRequest):
    """Generate a read-only SQL query from a natural-language request.

    Args:
        request: Body whose ``text`` field describes the desired query.

    Returns:
        A dict ``{"output": <sql>}`` holding the generated query.

    Raises:
        HTTPException: 400 when ``text`` is empty or whitespace only; 500
            when the model call fails or its output does not start with
            SELECT, SHOW or DESCRIBE.
    """
    text = request.text
    if not text.strip():
        # Nothing to translate; avoid a pointless (and slow) model call.
        raise HTTPException(
            status_code=400, detail="Request text must not be empty"
        )

    prompt = _build_prompt(text)
    try:
        output = pipe(prompt, max_new_tokens=100)
        sql_query = _extract_sql(output[0]["generated_text"])
    except Exception as e:
        # Boundary handler: record the traceback server-side and keep the
        # original exception chained to the HTTP error.
        _logger.exception("SQL generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Basic validation of the model output.
    if not _READ_ONLY_SQL.match(sql_query):
        raise HTTPException(
            status_code=500, detail="Generated text is not a valid SQL query"
        )

    return {"output": sql_query}

if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on every network
    # interface at port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)