import io
import json
import os
import subprocess
import sys
from operator import itemgetter
from typing import List

import gradio as gr
import numpy as np
import pandas as pd
import soundfile
# import sounddevice as sd
from transformers import pipeline

from langchain.chains import create_sql_query_chain
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain.memory import ChatMessageHistory
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
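
# Speech-to-text front end: a fine-tuned Whisper checkpoint served through the
# transformers ASR pipeline, used to turn spoken questions into plain text.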
model_id = "avnishkanungo/whisper-small-dv" # update with your model id | |
pipe = pipeline("automatic-speech-recognition", model=model_id) | |
def sql_translator(filepath, key):
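    """Answer a natural-language database question.

    The question arrives either as an audio file (transcribed with Whisper) or as
    plain text; it is translated to MySQL with a few-shot LangChain pipeline,
    executed against the database, and the result is rephrased for the user.
    """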
    def select_table(desc_path):
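        """Build a runnable that extracts the table names relevant to a question,
        using the table descriptions stored in the CSV at desc_path."""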
        def get_table_details():
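            """Concatenate every table name and description into one prompt-ready string."""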
            # Read the CSV file of table descriptions into a DataFrame.
            table_description = pd.read_csv(desc_path)  # e.g. "/teamspace/studios/this_studio/database_table_descriptions.csv"
            # Build a single string listing each table's name and description.
            table_details = ""
            for index, row in table_description.iterrows():
                table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"
            return table_details

        class Table(BaseModel):
            """Table in SQL database."""
            name: str = Field(description="Name of table in SQL database.")

        table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:
{get_table_details()}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

        table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)

        def get_tables(tables: List[Table]) -> List[str]:
            tables = [table.name for table in tables]
            return tables
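        # Pipe the user's question into the extraction chain, then flatten the
        # extracted Table objects into a plain list of table names.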
        table_selection_chain = {"input": itemgetter("question")} | table_chain | get_tables
        return table_selection_chain
    def prompt_creation(example_path):
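        """Assemble the few-shot chat prompt: example question/SQL pairs are loaded
        from a JSON file and the two most semantically similar examples are selected
        per question via embeddings stored in a Chroma vector store."""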
        with open(example_path, 'r') as file:  # e.g. '/teamspace/studios/this_studio/few_shot_samples.json'
            data = json.load(file)
        examples = data["examples"]

        example_prompt = ChatPromptTemplate.from_messages(
            [
                ("human", "{input}\nSQLQuery:"),
                ("ai", "{query}"),
            ]
        )

        vectorstore = Chroma()
        vectorstore.delete_collection()
        example_selector = SemanticSimilarityExampleSelector.from_examples(
            examples,
            OpenAIEmbeddings(),
            vectorstore,
            k=2,
            input_keys=["input"],
        )
        few_shot_prompt = FewShotChatMessagePromptTemplate(
            example_prompt=example_prompt,
            example_selector=example_selector,
            input_variables=["input", "top_k"],
        )
        final_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
                few_shot_prompt,
                MessagesPlaceholder(variable_name="messages"),
                ("human", "{input}"),
            ]
        )
        print(few_shot_prompt.format(input="How many products are there?"))
        return final_prompt
    def rephrase_answer():
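        """Return a chain that turns a question, its SQL query, and the raw SQL
        result into a natural-language answer."""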
        answer_prompt = PromptTemplate.from_template(
            """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
        )
        rephrase_answer = answer_prompt | llm | StrOutputParser()
        return rephrase_answer
    def is_ffmpeg_installed():
        try:
            # Run `ffmpeg -version` to check if ffmpeg is installed.
            subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            return True
        except (subprocess.CalledProcessError, FileNotFoundError):
            return False
    def install_ffmpeg():
        try:
            if sys.platform.startswith('linux'):
                subprocess.run(['sudo', 'apt-get', 'update'], check=True)
                subprocess.run(['sudo', 'apt-get', 'install', '-y', 'ffmpeg'], check=True)
            elif sys.platform == 'darwin':  # macOS
                subprocess.run(['/bin/bash', '-c', 'brew install ffmpeg'], check=True)
            elif sys.platform == 'win32':
                print("Please download ffmpeg from https://ffmpeg.org/download.html and install it manually.")
                return False
            else:
                print("Unsupported OS. Please install ffmpeg manually.")
                return False
        except subprocess.CalledProcessError as e:
            print(f"Failed to install ffmpeg: {e}")
            return False
        return True
    def transcribe_speech(filepath):
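        """Transcribe an audio file to English text with the Whisper ASR pipeline."""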
        output = pipe(
            filepath,
            max_new_tokens=256,
            generate_kwargs={
                "task": "transcribe",
                "language": "english",
            },  # update with the language you've fine-tuned on
            chunk_length_s=30,
            batch_size=8,
        )
        return output["text"]
    # def record_command():
    #     sample_rate = 16000  # Sample rate in Hz
    #     duration = 8  # Duration in seconds
    #     print("Recording...")
    #     # Record audio
    #     audio = sd.rec(int(sample_rate * duration), samplerate=sample_rate, channels=1, dtype='float32')
    #     sd.wait()  # Wait until recording is finished
    #     print("Recording finished")
    #     # Convert the audio to a binary stream and save it to a variable
    #     audio_buffer = io.BytesIO()
    #     soundfile.write(audio_buffer, audio, sample_rate, format='WAV')
    #     audio_buffer.seek(0)  # Reset buffer position to the beginning
    #     # The audio file is now saved in audio_buffer
    #     # You can read it again using soundfile or any other audio library
    #     audio_data, sample_rate = soundfile.read(audio_buffer)
    #     # Optional: Save the audio to a file for verification
    #     # with open('recorded_audio.wav', 'wb') as f:
    #     #     f.write(audio_buffer.getbuffer())
    #     print("Audio saved to variable")
    #     return audio_data
    def check_libportaudio_installed():
        # PortAudio is a shared library rather than a command-line tool, so look the
        # library up instead of shelling out to a nonexistent executable.
        from ctypes.util import find_library
        return find_library("portaudio") is not None
    def install_libportaudio():
        try:
            if sys.platform.startswith('linux'):
                subprocess.run(['sudo', 'apt-get', 'update'], check=True)
                subprocess.run(['sudo', 'apt-get', 'install', '-y', 'libportaudio2'], check=True)
            elif sys.platform == 'darwin':  # macOS
                subprocess.run(['/bin/bash', '-c', 'brew install portaudio'], check=True)
            elif sys.platform == 'win32':
                print("Please download PortAudio from http://www.portaudio.com/ and install it manually.")
                return False
            else:
                print("Unsupported OS. Please install PortAudio manually.")
                return False
        except subprocess.CalledProcessError as e:
            print(f"Failed to install libportaudio: {e}")
            return False
        return True
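
    # Connection details for the MySQL "classicmodels" database served from Amazon RDS.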
db_user = "admin" | |
db_password = "avnishk96" | |
db_host = "demo-db.cdm44iseol25.us-east-1.rds.amazonaws.com" | |
db_name = "classicmodels" | |
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}") | |
# print(db.dialect) | |
# print(db.get_usable_table_names()) | |
# print(db.table_info) | |
os.environ["OPENAI_API_KEY"] = key | |
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) | |
history = ChatMessageHistory() | |
final_prompt = prompt_creation(os.getcwd()+"/few_shot_samples.json") | |
generate_query = create_sql_query_chain(llm, db, final_prompt) | |
execute_query = QuerySQLDataBaseTool(db=db) | |
    # if is_ffmpeg_installed():
    #     print("ffmpeg is already installed.")
    # else:
    #     print("ffmpeg is not installed. Installing ffmpeg...")
    #     if install_ffmpeg():
    #         print("ffmpeg installation successful.")
    #     else:
    #         print("ffmpeg installation failed. Please install it manually.")

    # if check_libportaudio_installed():
    #     print("libportaudio is already installed.")
    # else:
    #     print("libportaudio is not installed. Installing libportaudio...")
    #     if install_libportaudio():
    #         print("libportaudio installation successful.")
    #     else:
    #         print("libportaudio installation failed. Please install it manually.")
    if os.path.isfile(filepath):
        question = transcribe_speech(filepath)
    else:
        question = filepath
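    # Full pipeline: choose relevant tables -> generate the SQL query -> execute it
    # against the database -> rephrase the raw result as a conversational answer.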
    chain = (
        RunnablePassthrough.assign(table_names_to_use=select_table(os.getcwd() + "/database_table_descriptions.csv"))
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer()
    )

    output = chain.invoke({"question": question, "messages": history.messages})
    history.add_user_message(question)
    history.add_ai_message(output)

    return output
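

# Gradio front end: one tab records a spoken query from the microphone, the other
# accepts a typed query; both tabs also take the OpenAI API key as input.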
def create_interface():
    demo = gr.Blocks()

    voice_query = gr.Interface(
        fn=sql_translator,
        inputs=[
            gr.Audio(sources="microphone", type="filepath"),
            gr.Textbox(lines=2, placeholder="Enter text here...", label="OpenAI Key"),
        ],
        outputs=gr.components.Textbox(),
    )
    text_query = gr.Interface(
        fn=sql_translator,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter text here...", label="Input Text..."),
            gr.Textbox(lines=2, placeholder="Enter text here...", label="OpenAI Key"),
        ],
        outputs=gr.components.Textbox(),
    )
    with demo:
        gr.TabbedInterface(
            [voice_query, text_query],
            ["Audio Query", "Text Query"],
        )

    demo.launch(share=True)


if __name__ == "__main__":
    create_interface()