File size: 5,122 Bytes
035476a
5473610
 
283a0f0
 
5473610
 
 
 
 
 
 
 
 
 
 
 
3de53f6
5473610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
035476a
 
5473610
035476a
 
 
 
 
5473610
035476a
 
 
 
 
5473610
 
283a0f0
 
5473610
 
283a0f0
 
3de53f6
283a0f0
 
5473610
035476a
 
 
 
5473610
035476a
3de53f6
035476a
 
 
 
 
 
283a0f0
 
5473610
 
283a0f0
 
5473610
 
 
 
 
283a0f0
3de53f6
 
 
 
 
 
 
 
 
5473610
283a0f0
5473610
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import pipeline
import mysql.connector 
import json
import os
from dotenv import load_dotenv

# Load environment variables from the .env file so the DB_* settings read via
# os.getenv() in get_db_connection() are available.
load_dotenv()

# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()

# Initialize the text generation pipeline once at import time so every request
# reuses the same pipeline object.
# NOTE(review): pad_token_id=2 is passed through as-is; confirm it matches the
# tokenizer of the configured model.
pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b", pad_token_id=2)

class QueryRequest(BaseModel):
    """Pydantic model with a single string field, ``query``.

    NOTE(review): no handler visible in this file references this model;
    ``handle_query`` reads the raw request instead. Confirm whether it is
    still needed.
    """
    query: str

def get_db_connection():
    """Open a fresh MySQL connection configured from environment variables.

    Reads ``DB_HOST``, ``DB_USER``, ``DB_PASSWORD`` and ``DB_NAME`` and
    connects with ``raise_on_warnings`` enabled.

    Returns:
        The connection object, or ``None`` when the connector reports an
        error (the error is printed, not raised).
    """
    settings = {
        "host": os.getenv("DB_HOST"),
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "database": os.getenv("DB_NAME"),
        "raise_on_warnings": True,
    }
    try:
        return mysql.connector.connect(**settings)
    except mysql.connector.Error as err:
        # Callers are expected to check for None rather than catch here.
        print(f"Error: {err}")
    return None

def get_database_schema():
    """Retrieve the database schema dynamically.

    Lists the tables with ``SHOW TABLES`` and runs ``DESCRIBE`` on each one,
    keeping the first field of every returned row as the column name.

    Returns:
        dict: ``{table_name: [column_name, ...]}``. An empty dict is returned
        when the connection cannot be opened or any statement fails; the
        error is printed rather than raised.
    """
    # Local import keeps this block self-contained within the file.
    from contextlib import closing

    schema = {}
    try:
        conn = get_db_connection()
        if conn is None:
            raise Exception("Failed to connect to the database.")

        # closing() guarantees the cursor and then the connection are closed
        # even when a statement below raises, so no handle is leaked.
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("SHOW TABLES")
            tables = cursor.fetchall()

            for table in tables:
                table_name = table[0]
                # Backtick-quote the identifier (doubling embedded backticks)
                # so reserved words or special characters in a table name do
                # not break the statement.
                quoted_name = "`" + table_name.replace("`", "``") + "`"
                cursor.execute(f"DESCRIBE {quoted_name}")
                columns = cursor.fetchall()
                schema[table_name] = [column[0] for column in columns]
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return {}
    except Exception as e:
        print(f"An error occurred: {e}")
        return {}

    return schema

@app.get("/")
def home():
    """Return a fixed JSON message indicating the server is running."""
    status_payload = {"message": "SQL Generation Server is running"}
    return status_payload

def _build_prompt(schema_str, text):
    """Assemble the full model prompt from the schema JSON and the user's request."""
    system_message = f"""

        You are a helpful, cheerful database assistant. 

        Use the following dynamically retrieved database schema when creating your answers:



        {schema_str}

         When creating your answers, consider the following:



        1. If a query involves a column or value that is not present in the provided database schema, correct it and mention the correction in the summary. If a column or value is missing, provide an explanation of the issue and adjust the query accordingly.

        2. If there is a spelling mistake in the column name or value, attempt to correct it by matching the closest possible column or value from the schema. Mention this correction in the summary to clarify any changes made.

        3. Ensure that the correct columns and values are used based on the schema provided. Verify the query against the schema to confirm accuracy.

        4. Include column name headers in the query results for clarity.



        Always provide your answer in the JSON format below:



        {{ "summary": "your-summary", "query":  "your-query" }}

        

        Output ONLY JSON.

        In the preceding JSON response, substitute "your-query" with a MariaDB query to retrieve the requested data.

        In the preceding JSON response, substitute "your-summary" with a summary of the query and any corrections or clarifications made.

        Always include all columns in the table.

        """
    return f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"


def _extract_sql(generated_text):
    """Return the SQL statement contained in the raw model output.

    Takes the text after the last ``SQL query:`` marker. Because the prompt
    asks the model to answer with ``{"summary": ..., "query": ...}``, a JSON
    object with a string ``query`` field is unwrapped when one is present;
    otherwise the text is returned as-is so plain SQL output keeps working.
    """
    candidate = generated_text.split("SQL query:")[-1].strip()

    brace_at = candidate.find("{")
    if brace_at != -1:
        try:
            # raw_decode parses the first JSON value and ignores any text
            # the model appended after it.
            payload, _ = json.JSONDecoder().raw_decode(candidate, brace_at)
        except ValueError:
            payload = None
        if isinstance(payload, dict) and isinstance(payload.get("query"), str):
            candidate = payload["query"].strip()

    return candidate


def _run_read_only_query(sql_query):
    """Execute ``sql_query`` on a fresh connection and return all fetched rows."""
    # Local import keeps this block self-contained within the file.
    from contextlib import closing

    conn = get_db_connection()
    if conn is None:
        # get_db_connection() signals failure with None; fail with a clear
        # message instead of an AttributeError on `.cursor()`.
        raise RuntimeError("Failed to connect to the database.")

    # The cursor and then the connection are closed even if execution fails.
    with closing(conn), closing(conn.cursor()) as cursor:
        cursor.execute(sql_query)
        return cursor.fetchall()


@app.api_route("/query", methods=["GET", "POST"])
async def handle_query(request: Request):
    """Turn a natural-language request into SQL, run it, and return the rows.

    The request text is read from the ``query`` key of the JSON body (POST)
    or from the ``query`` query-string parameter (GET). The current database
    schema is embedded in the prompt sent to the text-generation pipeline,
    and the generated statement is executed against the database.

    Args:
        request: The incoming HTTP request.

    Returns:
        dict: ``{"sql": <executed statement>, "results": <fetched rows>}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or no query is
            supplied; 500 for any other failure (invalid model output,
            database errors), with the error text as ``detail``.
    """
    try:
        if request.method == "POST":
            try:
                request_data = await request.json()
            except ValueError as err:
                raise HTTPException(
                    status_code=400, detail="Request body must be valid JSON."
                ) from err
            if not isinstance(request_data, dict):
                raise HTTPException(
                    status_code=400, detail="Request body must be a JSON object."
                )
            text = request_data.get("query", "")
        else:
            # The route is registered for GET and POST only, so this is GET.
            text = request.query_params.get("query", "")

        print("Received query:", text)  # Debugging: Print the received query

        if not text:
            # A missing query is the caller's mistake, not a server fault.
            raise HTTPException(status_code=400, detail="No query provided.")

        # Fetch the database schema
        schema = get_database_schema()
        schema_str = json.dumps(schema, indent=4)
        print("Fetched schema:", schema)  # Debugging: Print the fetched schema

        # NOTE(review): pipe() and the database calls below are synchronous
        # and run inside an async handler; confirm that is acceptable.
        prompt = _build_prompt(schema_str, text)
        output = pipe(prompt, max_new_tokens=100)
        print("Generated output:", output)  # Debugging: Print the generated output

        sql_query = _extract_sql(output[0]['generated_text'])

        # SECURITY: this executes model-generated SQL derived from untrusted
        # input. The prefix check is only a coarse filter; the database
        # account should be restricted to read-only privileges.
        if not sql_query.lower().startswith(('select', 'show', 'describe')):
            raise ValueError("Generated text is not a valid SQL query")

        results = _run_read_only_query(sql_query)

        return {"sql": sql_query, "results": results}
    except HTTPException:
        # Preserve deliberate status codes raised above.
        raise
    except Exception as e:
        print("Error occurred:", str(e))  # Debugging: Print the error
        raise HTTPException(status_code=500, detail=str(e)) from e

# Start a uvicorn server for the app when this file is executed directly as a
# script (not when it is imported as a module).
if __name__ == "__main__":
    import uvicorn
    # host "0.0.0.0" listens on all network interfaces; the port is fixed at 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)