Standard_Intelligence_Dev / excel_chat.py
MaksG's picture
Update excel_chat.py
e949ec2 verified
raw
history blame
4.96 kB
import gradio as gr
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import os
import pandas as pd
import numpy as np
from groq import Groq
import anthropic
from users_management import update_json, users
#users = ['maksG', 'Alma', 'YchK']
def ask_llm(query, input, client_index):
    """Send one piece of content to the selected LLM backend and return its reply text.

    The instruction (*query*) is embedded in the system prompt and *input* is
    sent as the user message.

    Args:
        query: Instruction the model should apply to *input*.
        input: Content to process; formatted with ``f"{input}"``.
        client_index: Backend label. ``"Groq"`` (Mixtral 8x7B),
            ``"Mistral Small"``, ``"Mistral Tiny"``, ``"Mistral Medium"`` and
            ``"Claude Opus"`` are recognised; any other value falls back to
            Claude 3 Sonnet.

    Returns:
        The text of the model's first reply.

    Raises:
        KeyError: If the API-key environment variable for the chosen backend
            (``GROQ_API_KEY``, ``MISTRAL_API_KEY`` or ``CLAUDE_API_KEY``) is unset.
    """
    # Single definition of the system prompt so every backend receives exactly
    # the same instruction text.
    system_prompt = f"You are a helpful assistant. Only show your final response to the **User Query**! Do not provide any explanations or details: \n# User Query:\n{query}."
    user_content = f"{input}"

    # UI label -> Mistral model id; these backends share one request shape.
    mistral_models = {
        "Mistral Small": "mistral-small-latest",
        "Mistral Tiny": "mistral-tiny",
        "Mistral Medium": "mistral-medium",
    }

    if client_index == "Groq" or client_index in mistral_models:
        # OpenAI-style payload: the system prompt travels as the first message.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]
        if client_index == "Groq":
            client = Groq(api_key=os.environ["GROQ_API_KEY"])
            chat_completion = client.chat.completions.create(
                messages=messages,
                model="mixtral-8x7b-32768",
            )
        else:
            client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])
            chat_completion = client.chat(
                messages=messages,
                model=mistral_models[client_index],
            )
        return chat_completion.choices[0].message.content

    # Anthropic takes the system prompt as a separate argument and the user
    # content as a list of typed content blocks.
    if client_index == "Claude Opus":
        claude_model = "claude-3-opus-20240229"
    else:
        claude_model = "claude-3-sonnet-20240229"
    client = anthropic.Anthropic(api_key=os.environ["CLAUDE_API_KEY"])
    response = client.messages.create(
        model=claude_model,
        max_tokens=350,
        temperature=0,
        system=system_prompt,
        messages=[
            {
                "role": "user",
                "content": [{"type": "text", "text": user_content}],
            }
        ],
    )
    return response.content[0].text
def filter_df(df, column_name, keywords):
    """Keep only the rows of *df* that contain at least one of *keywords*.

    Matching is a case-insensitive substring test applied to string cells
    only. Non-string cells (numbers, NaN, None, dates) never match.

    Args:
        df: DataFrame to filter.
        column_name: Column to search. When it is not one of ``df.columns``,
            every cell of each row is searched instead.
        keywords: Sequence of keyword strings. ``None`` or an empty sequence
            disables filtering. NOTE(review): a plain string would be iterated
            character by character; confirm callers pass a list.

    Returns:
        *df* itself when there is nothing to filter on. Otherwise, the matching
        rows with their original index labels preserved, so they can be mapped
        back onto *df*.
    """
    if keywords is None or len(keywords) == 0:
        return df

    # Lower-case each keyword once rather than once per cell.
    needles = [keyword.lower() for keyword in keywords]

    def cell_matches(cell):
        haystack = cell.lower() if isinstance(cell, str) else ""
        return any(needle in haystack for needle in needles)

    if column_name in df.columns:
        mask = df[column_name].apply(cell_matches)
    else:
        mask = df.apply(lambda row: any(cell_matches(cell) for cell in row), axis=1)
    # On an empty frame apply() can return a non-bool dtype, which pandas would
    # interpret as a list of column labels instead of a row mask.
    return df[mask.astype(bool)]
def _output_file_name(url, excel_file):
    """Build ``"<segment>.xlsx"`` from the second-to-last ``/``-separated part of *url*.

    Falls back to *excel_file* when *url* cannot be split that way (``None``,
    not a string, contains no ``/``) or when that segment is empty.
    """
    try:
        stem = url.split("/")[-2]
    except (AttributeError, IndexError, TypeError):
        return excel_file
    return f"{stem}.xlsx" if stem else excel_file


def _row_content(row, source_cols):
    """Render the *source_cols* cells of *row* as ``"col: value"`` entries separated by blank lines."""
    return "\n\n".join(f"{column_name}: {str(row[column_name])}" for column_name in source_cols)


def chat_with_mistral(source_cols, dest_col, prompt, excel_file, url, search_col, keywords, client, user):
    """Run the selected LLM over the rows of an Excel sheet and save the answers.

    Each row kept by ``filter_df(df, search_col, keywords)`` that has at least
    one non-missing value in *source_cols* is sent to ``ask_llm``. The reply is
    stored in column *dest_col*. Rows that are filtered out or skipped keep
    ``""`` in that column.

    Args:
        source_cols: Column names whose values are sent to the model.
        dest_col: Column receiving the answers. It is created, or overwritten,
            with ``""`` first. NOTE(review): if it is also listed in
            *source_cols*, those values are blanked before being read.
        prompt: ``prompt[0]`` is used as the instruction. The whole value is
            passed to ``update_json``. NOTE(review): if *prompt* is a plain
            string, ``prompt[0]`` is only its first character; confirm callers
            pass a list.
        excel_file: Anything ``pd.read_excel`` accepts.
        url: Used to name the output file (see ``_output_file_name``).
        search_col: Column searched by ``filter_df``.
        keywords: Keywords passed to ``filter_df``.
        client: Backend label forwarded to ``ask_llm``.
        user: Passed to ``update_json``.

    Returns:
        ``(file_name, df.head(5))``: the path the full sheet was written to and
        a five-row preview. When no name can be derived from *url*,
        *file_name* is *excel_file* itself, so the input file is overwritten.
    """
    update_json(user, prompt, keywords)
    print(f'xlsxfile = {excel_file}')
    df = pd.read_excel(excel_file)
    df[dest_col] = ""
    file_name = _output_file_name(url, excel_file)
    print(f"Keywords: {keywords}")
    for index, row in filter_df(df, search_col, keywords).iterrows():
        # Nothing to send when every selected cell is missing (NaN/None/NaT).
        if all(pd.isna(row[column_name]) for column_name in source_cols):
            continue
        content = _row_content(row, source_cols)
        llm_answer = ask_llm(prompt[0], content, client)
        print(f"QUERY:\n{prompt[0]}\nCONTENT:\n{content[:200]}...\n\nANSWER:\n{llm_answer}")
        # filter_df preserves index labels, so this writes back to the right row.
        df.at[index, dest_col] = llm_answer
    df.to_excel(file_name, index=False)
    return file_name, df.head(5)
def get_columns(file, progress=gr.Progress()):
    """Populate the column dropdowns and preview from an uploaded Excel file.

    Returns a 6-tuple of five ``gr.update`` objects followed by a preview
    DataFrame:

    - The first three updates offer the sheet's columns as-is.
    - The fourth offers the columns plus ``""``.
    - The fifth offers the columns plus ``"[ALL]"``.
    - The preview is the first five rows.

    When *file* is ``None``, all five choice lists are cleared and the
    preview is an empty DataFrame.

    ``progress`` is not used in the body. NOTE(review): presumably kept so
    Gradio injects a progress tracker; confirm before removing.
    """
    if file is None:
        cleared = tuple(gr.update(choices=[]) for _ in range(5))
        return cleared + (pd.DataFrame(),)

    sheet = pd.read_excel(file)
    headers = list(sheet.columns)
    # A separate update object per dropdown, as in a hand-written tuple.
    plain_updates = tuple(gr.update(choices=headers) for _ in range(3))
    return plain_updates + (
        gr.update(choices=headers + [""]),
        gr.update(choices=headers + ["[ALL]"]),
        sheet.head(5),
    )