import gradio as gr
from transformers import pipeline
import torch
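
# theme is assumed to be a local theme.py module shipped with the Space; its
# Theme class customizes Gradio's look, and the module name is rebound to the
# instance below.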
import theme
theme = theme.Theme()

import sys
sys.path.append('../..')

# LangChain components
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain_community.llms import HuggingFaceHub
import shutil

custom_title = "<span style='color: rgb(18, 13, 5);'>Green Greta</span>"

# Cell 1: Image Classification Model
image_pipeline = pipeline(task="image-classification", model="guillen/vit-basura-test1")

def predict_image(input_img):
    # Map each predicted label to its confidence score, the format gr.Label expects
    predictions = image_pipeline(input_img)
    return {p["label"]: p["score"] for p in predictions}
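
# Illustrative result for one image (the class labels depend on the fine-tuned
# model; these names and scores are hypothetical):
# {"plastic": 0.93, "glass": 0.04, "metal": 0.02, "paper": 0.01}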

image_gradio_app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(label="Image", sources=['upload', 'webcam'], type="pil"),
    outputs=[gr.Label(label="Result")],
    title=custom_title,
    theme=theme
)

# Cell 2: Chatbot Model
# Load every PDF in the local 'pdfs' directory
loader = PyPDFDirectoryLoader('pdfs')
data = loader.load()

# Split documents into overlapping chunks so retrieval can return focused passages
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=70,
    length_function=len
)
docs = text_splitter.split_documents(data)
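
# chunk_overlap=70 makes consecutive chunks share 70 characters, which reduces
# the chance that a sentence carrying the answer is cut in half at a chunk
# boundary.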

# Define the embedding model
embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-small')

# Create the vector database from the chunks
persist_directory = 'docs/chroma/'
# Remove old database files, if any, so the index is rebuilt from scratch
shutil.rmtree(persist_directory, ignore_errors=True)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory=persist_directory
)
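
# Because the old index is deleted on every startup, persist_directory acts
# only as Chroma's on-disk working directory here; the persisted files are
# rebuilt rather than reused across runs.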

# Define the retriever
retriever = vectordb.as_retriever(search_type="mmr")
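
# search_type="mmr" applies maximal marginal relevance: rather than returning
# the most similar chunks outright, it re-ranks candidates to balance
# similarity to the query against diversity among the results, cutting
# near-duplicate context.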
template = """ | |
Your name is Greta and you are a recycling chatbot with the objective to anwer questions from user in English or Spanish / | |
Use the following pieces of context to answer the question if the question is related with recycling / | |
No more than two chunks of context / | |
Answer in the same language of the question / | |
Always say "thanks for asking!" at the end of the answer / | |
If the context is not relevant, please answer the question by using your own knowledge about the topic. | |
context: {context} | |
question: {question} | |
""" | |

# Create the chat prompt templates
system_prompt = SystemMessagePromptTemplate.from_template(template)
qa_prompt = ChatPromptTemplate.from_messages([
    system_prompt,
    MessagesPlaceholder(variable_name="chat_history"),
    HumanMessagePromptTemplate.from_template("{question}")
])
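
# MessagesPlaceholder splices the buffered chat history between the system
# instructions and the latest question, so the model sees the whole
# conversation on every turn.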

# Hosted Mixtral endpoint; HuggingFaceHub authenticates via the
# HUGGINGFACEHUB_API_TOKEN environment variable
llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 1024,
        "top_k": 30,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
    },
)

# ConversationBufferMemory stores raw turns verbatim and needs no LLM
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key='question',
    output_key='answer',
    return_messages=True
)

qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=retriever,
    verbose=True,
    combine_docs_chain_kwargs={'prompt': qa_prompt},
    get_chat_history=lambda h: h,
    rephrase_question=False,
    output_key='answer'
)
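
# Per call, the chain condenses the question with the chat history, retrieves
# matching chunks from Chroma, fills qa_prompt's {context} and {question},
# and writes the exchange back into memory. rephrase_question=False keeps the
# user's original wording in the final prompt; the condensed standalone
# question is still the one used for retrieval.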

def chat_interface(question, history):
    # gr.ChatInterface passes (message, history); the chain keeps its own
    # memory, so history is unused here
    result = qa_chain.invoke({"question": question})
    return result['answer']  # the chain returns a dict keyed by output_key

chatbot_gradio_app = gr.ChatInterface(
    fn=chat_interface,
    title='Green Greta'
)

# Combine both interfaces into a single tabbed app
gr.TabbedInterface(
    [image_gradio_app, chatbot_gradio_app],
    tab_names=["Green Greta Image Classification", "Green Greta Chat"],
    theme=theme
).launch()