|
import gradio as gr |
|
from mistralai.client import MistralClient |
|
from mistralai.models.chat_completion import ChatMessage |
|
import os |
|
import pandas as pd |
|
import numpy as np |
|
|
|
def _ask_mistral(client, model, content, prompt):
    """Send *content* as context followed by *prompt* and return the reply text."""
    messages = [
        ChatMessage(role="user", content=f"Using the following content: {content}"),
        ChatMessage(role="user", content=prompt),
    ]
    chat_response = client.chat(model=model, messages=messages)
    return chat_response.choices[0].message.content


def _format_row(row, source_columns):
    """Render the selected cells of *row* as blank-line-separated 'name: value' entries."""
    return "\n\n".join(f"{column_name}: {row[column_name]}" for column_name in source_columns)


def chat_with_mistral(source_cols, dest_col, prompt, tdoc_name, excel_file, url):
    """Fill ``dest_col`` of an Excel sheet with Mistral replies and save the result.

    For each processed row, the values of ``source_cols`` are sent to the
    ``mistral-small`` model as context, followed by ``prompt``. The reply is
    stored in ``dest_col``.

    Args:
        source_cols: Column names whose cell values are sent as context.
        dest_col: Column to (re)create. It is reset to "" for every row
            before any reply is written.
        prompt: Instruction sent after the context message.
        tdoc_name: If not ``''``, only the first row whose ``'File'`` column
            equals this value is processed. Otherwise every row is processed
            except rows whose source cells are all empty.
        excel_file: Anything ``pd.read_excel`` accepts (presumably the path of
            the uploaded file - confirm against the caller).
        url: Its second-to-last ``/`` segment plus ``.xlsx`` names the output
            file. If it cannot be split that way, ``excel_file`` is
            overwritten instead.

    Returns:
        ``(file_name, df.head(5))``: the path written to (or that would have
        been written, if no matching row was found) and a preview of the
        first five rows.

    Raises:
        KeyError: If ``MISTRAL_API_KEY`` is not set in the environment.
    """
    df = pd.read_excel(excel_file)

    api_key = os.environ["MISTRAL_API_KEY"]
    model = "mistral-small"
    client = MistralClient(api_key=api_key)

    # A list (not a tuple) so that row[source_columns] selects several labels.
    source_columns = list(source_cols)
    df[dest_col] = ""

    try:
        file_name = url.split("/")[-2] + ".xlsx"
    except (AttributeError, IndexError):
        # url is not a string (e.g. None) or has no "/" separator.
        file_name = excel_file

    if tdoc_name != '':
        matches = df.index[df['File'] == tdoc_name]
        if len(matches) == 0:
            # Nothing to do; the file is intentionally not written.
            return file_name, df.head(5)

        first = matches[0]
        content = _format_row(df.loc[first], source_columns)
        # Write straight into df instead of a filtered copy + df.update().
        df.at[first, dest_col] = _ask_mistral(client, model, content, prompt)

        df.to_excel(file_name, index=False)
        return file_name, df.head(5)

    for index, row in df.iterrows():
        # Skip rows where every source cell is missing (NaN/NaT): no context to send.
        if row[source_columns].isna().all():
            continue
        content = _format_row(row, source_columns)
        df.at[index, dest_col] = _ask_mistral(client, model, content, prompt)

    df.to_excel(file_name, index=False)
    return file_name, df.head(5)
|
|
|
|
|
def get_columns(file):
    """Build choice updates and a preview table from an uploaded Excel file.

    Returns a 5-tuple: four ``gr.update`` objects carrying the sheet's column
    names as choices (the fourth also offers an empty-string choice), then the
    first five rows of the sheet. When ``file`` is None, all four updates have
    empty choices and the preview is an empty DataFrame.
    """
    if file is None:
        blank_updates = [gr.update(choices=[]) for _ in range(4)]
        return (*blank_updates, pd.DataFrame())

    frame = pd.read_excel(file)
    names = list(frame.columns)
    return (
        gr.update(choices=names),
        gr.update(choices=names),
        gr.update(choices=names),
        gr.update(choices=names + [""]),  # extra "" entry allows a blank selection
        frame.head(5),
    )