# LexAIcon - app.py
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from PyPDF2 import PdfReader
from docx import Document
import csv
import json
import torch
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from huggingface_hub import login
# Authenticate with Hugging Face
huggingface_token = st.secrets["HUGGINGFACE_TOKEN"]
login(huggingface_token)
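# The token is read from Streamlit secrets; a minimal .streamlit/secrets.toml
# entry (placeholder value) would look like:
#   HUGGINGFACE_TOKEN = "hf_..."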
# Configure the model and tokenizer
model_name = 'Qwen/Qwen2-1.5B'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
text_generation_pipeline = pipeline(
    model=model_name,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=1000,
)
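# Note: the [INST] markers in the template below follow the Mistral instruction
# format; Qwen2 has no special handling for them and treats them as plain text.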
prompt_template = """
### [INST]
Instruction: Answer the question based on your knowledge. Here is context to help:
{context}
### QUESTION:
{question}
[/INST]
"""
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# Build the prompt from the template
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
# Create the LLM chain
llm_chain = LLMChain(llm=llm, prompt=prompt)
# Handle uploaded files
def handle_uploaded_file(uploaded_file):
    try:
        if uploaded_file.name.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif uploaded_file.name.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # extract_text() can return None for empty pages, so guard with "or ''"
            text = "".join(page.extract_text() or "" for page in reader.pages)
        elif uploaded_file.name.endswith(".docx"):
            doc = Document(uploaded_file)
            text = "\n".join(para.text for para in doc.paragraphs)
        elif uploaded_file.name.endswith(".csv"):
            content = uploaded_file.read().decode("utf-8").splitlines()
            reader = csv.reader(content)
            text = " ".join(" ".join(row) for row in reader)
        elif uploaded_file.name.endswith(".json"):
            data = json.load(uploaded_file)
            text = json.dumps(data, indent=4)
        else:
            text = "Tipo de archivo no soportado."
        return text
    except Exception as e:
        return str(e)
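# Note: on failure, handle_uploaded_file returns the exception message as the
# document text instead of raising, so callers always receive a string.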
# Translate text
def translate(text, target_language):
    context = ""
    question = f"Por favor, traduzca el siguiente documento al {target_language}:\n{text}\nAsegúrese de que la traducción sea precisa y conserve el significado original del documento."
    response = llm_chain.run(context=context, question=question)
    return response
# Summarize text
def summarize(text, length):
    context = ""
    question = f"Por favor, haga un resumen {length} del siguiente documento:\n{text}\nAsegúrese de que el resumen sea conciso y conserve el significado original del documento."
    response = llm_chain.run(context=context, question=question)
    return response
# Set up the classification model
@st.cache_resource
def load_classification_model():
    tokenizer_cls = AutoTokenizer.from_pretrained("mrm8488/legal-longformer-base-8192-spanish")
    model_cls = AutoModelForSequenceClassification.from_pretrained("mrm8488/legal-longformer-base-8192-spanish")
    return model_cls, tokenizer_cls
classification_model, classification_tokenizer = load_classification_model()
id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}
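# Assumption: this id-to-label mapping matches the label order used when the
# classifier head was fine-tuned; a mismatch would silently mislabel documents.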
def classify_text(text):
    inputs = classification_tokenizer(text, return_tensors="pt", max_length=4096, truncation=True, padding="max_length")
    classification_model.eval()
    with torch.no_grad():
        outputs = classification_model(**inputs)
    logits = outputs.logits
    predicted_class_id = logits.argmax(dim=-1).item()
    predicted_label = id2label[predicted_class_id]
    return predicted_label
# Load the JSON documents for a category
def load_json_documents(category):
    try:
        with open(f"./{category}.json", "r", encoding="utf-8") as f:
            data = json.load(f)["questions_and_answers"]
        documents = [entry["question"] + " " + entry["answer"] for entry in data]
        return documents
    except FileNotFoundError:
        return []
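# Expected layout of ./<category>.json, inferred from the code above:
#   {"questions_and_answers": [{"question": "...", "answer": "..."}, ...]}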
# FAISS and embeddings setup
@st.cache_resource
def create_vector_store(docs):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    # split_text expects a single string, so join the document list first
    split_docs = text_splitter.split_text(" ".join(docs))
    vector_store = FAISS.from_texts(split_docs, embeddings)
    return vector_store
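# Note: chunk_size and chunk_overlap are measured in characters, not tokens;
# the MiniLM embedder truncates anything beyond its own token limit.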
def explain_text(user_input, document_text):
    classification = classify_text(document_text)
    if classification in ["multas", "politicas_de_privacidad", "contratos", "denuncias"]:
        docs = load_json_documents(classification)
        if docs:
            vector_store = create_vector_store(docs)
            search_docs = vector_store.similarity_search(user_input)
            context = " ".join([doc.page_content for doc in search_docs])
        else:
            context = ""
    else:
        context = ""
    question = user_input
    response = llm_chain.run(context=context, question=question)
    return response
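# Usage sketch (hypothetical inputs):
#   answer = explain_text("¿Puedo recurrir esta multa?", document_text)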
def main():
    st.title("LexAIcon")
    st.write("Puedes conversar con este chatbot basado en Qwen2-1.5B y subir archivos para que el chatbot los procese.")
    with st.sidebar:
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")
        operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])
    if operation == "Explicar":
        user_input = st.text_area("Introduce tu pregunta:", "")
        uploaded_file = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"])
        if uploaded_file and user_input:
            document_text = handle_uploaded_file(uploaded_file)
            bot_response = explain_text(user_input, document_text)
            st.write(f"**Assistant:** {bot_response}")
    else:
        uploaded_file = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"])
        if uploaded_file:
            document_text = handle_uploaded_file(uploaded_file)
            if operation == "Traducir":
                target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
                if target_language:
                    bot_response = translate(document_text, target_language)
                    st.write(f"**Assistant:** {bot_response}")
            elif operation == "Resumir":
                summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])
                if summary_length:
                    if summary_length == "corto":
                        length = "de aproximadamente 50 palabras"
                    elif summary_length == "medio":
                        length = "de aproximadamente 100 palabras"
                    elif summary_length == "largo":
                        length = "de aproximadamente 500 palabras"
                    bot_response = summarize(document_text, length)
                    st.write(f"**Assistant:** {bot_response}")

if __name__ == "__main__":
    main()