|
import os
from urllib.parse import quote_plus

import streamlit as st
from transformers import AutoTokenizer
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from langchain_community.utilities import SQLDatabase
from langchain.llms import HuggingFaceHub
from langchain.chains import create_sql_query_chain
|
|
|
@st.cache_resource(show_spinner="Connecting...")
def getSession():
    """Connect to Snowflake and build the natural-language-to-SQL chain.

    Decorated with ``st.cache_resource`` so the returned objects are created
    once and reused across Streamlit reruns instead of reconnecting each time.

    Connection parameters start from the dict returned by
    ``SnowflakeLoginOptions("test_conn")``.
    - The non-secret fields can be overridden with environment variables
      named ``SNOWFLAKE_<FIELD>``, e.g. ``SNOWFLAKE_ACCOUNT`` or
      ``SNOWFLAKE_ROLE``. They fall back to the values this app previously
      hard-coded.
    - The password is never hard-coded. It comes from ``SNOWFLAKE_PASSWORD``
      or, failing that, from the ``password`` entry of that dict.

    Returns:
        tuple: ``(session, db, chain)``.
        - ``session``: a Snowpark ``Session`` used to run the generated SQL.
        - ``db``: a LangChain ``SQLDatabase`` over the same schema, used to
          describe the tables to the LLM.
        - ``chain``: the chain produced by ``create_sql_query_chain``.

    Raises:
        RuntimeError: if the Hugging Face Hub token or the Snowflake password
            is not available.
    """
    # Check secrets first so a misconfiguration fails before any connection
    # is opened. The old code overwrote this variable with a placeholder
    # string, which discarded any real token.
    if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
        raise RuntimeError(
            "Set the HUGGINGFACEHUB_API_TOKEN environment variable."
        )

    pars = SnowflakeLoginOptions("test_conn")

    # Non-secret settings: an environment override wins, else the old defaults.
    defaults = {
        "account": "ap20346.ap-south-1",
        "user": "Vassist",
        # NOTE(review): ACCOUNTADMIN is far more privilege than needed to run
        # LLM-generated queries; point SNOWFLAKE_ROLE at a read-only role.
        "role": "ACCOUNTADMIN",
        "warehouse": "COMPUTE_WH",
        "database": "SNOWFLAKE_SAMPLE_DATA",
        "schema": "TPCH_SF1",
    }
    for key, default in defaults.items():
        pars[key] = os.environ.get(f"SNOWFLAKE_{key.upper()}", default)

    # Secrets must not live in source control. The password that used to be
    # committed here should be rotated.
    password = os.environ.get("SNOWFLAKE_PASSWORD") or pars.get("password")
    if not password:
        raise RuntimeError(
            "Set SNOWFLAKE_PASSWORD or add a password to the 'test_conn' "
            "connection configuration."
        )
    pars["password"] = password

    session = Session.builder.configs(pars).create()

    # Percent-encode credentials. A raw '@', ':' or '/' in the password
    # (e.g. 'x@123') would otherwise be parsed as part of the host.
    url = (
        f"snowflake://{quote_plus(pars['user'])}:{quote_plus(pars['password'])}"
        f"@{pars['account']}/{pars['database']}/{pars['schema']}"
        f"?warehouse={pars['warehouse']}&role={pars['role']}"
    )
    db = SQLDatabase.from_uri(url)

    llm = HuggingFaceHub(repo_id="ravithejakandi/Sf-Arctic-Demo-Enu")
    chain = create_sql_query_chain(llm, db)

    return session, db, chain
|
|
|
|
|
st.title("SQL Query Generator")
st.write("Returns and runs queries from questions in natural language.")

session, db, chain = getSession()

question = st.sidebar.text_area(
    "Ask a question:",
    value="Show me the total number of entries in the first table",
)

# A blank question has nothing to translate; skip the LLM call for this rerun.
if not question.strip():
    st.info("Enter a question in the sidebar.")
    st.stop()

# Trim whitespace around the generated text so a trailing ';' followed by a
# newline is still removed, as the original rstrip(';') intended.
sql = chain.invoke({"question": question}).strip().rstrip(";").strip()

# Only execute statements whose leading keyword is a read-only query.
# The SQL comes from an LLM driven by free-text user input. This check is a
# guard, not a substitute for connecting with a least-privileged role.
READ_ONLY_KEYWORDS = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
first_keyword = sql.split(maxsplit=1)[0].upper() if sql else ""
is_read_only = first_keyword in READ_ONLY_KEYWORDS

tabQuery, tabData, tabLog = st.tabs(["Query", "Data", "Log"])

tabQuery.code(sql, language="sql")

if not is_read_only:
    tabData.error("Refusing to run a statement that is not a read-only query.")
else:
    try:
        tabData.dataframe(session.sql(sql))
    except Exception as exc:
        # UI boundary: generated SQL is frequently invalid. Report the error
        # in the Data tab instead of crashing the page, so the Log tab still
        # renders.
        tabData.error(f"Query failed: {exc}")

tabLog.code(db.table_info, language="sql")