projectfoodie / retriever.py
Bunpheng's picture
test
639d8f7
raw
history blame contribute delete
No virus
3.39 kB
import pinecone
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_pinecone import PineconeVectorStore
from tqdm import tqdm
from pinecone import Pinecone
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_pinecone import PineconeVectorStore
from concurrent.futures import ThreadPoolExecutor
import gradio as gr
from langchain_nomic import NomicEmbeddings
from transformers import AutoModel
from gradio_app.backend.embedders import EmbedderFactory
# import requests
class Retrieving:
    """Retrieval-augmented QA over a Pinecone index of restaurant documents.

    Wires a Pinecone vector store (Nomic embeddings) to a Hugging
    Face-hosted Mistral model loaded through Gradio, and answers user
    questions with a LangChain retrieval chain (LCEL).
    """

    def __init__(self, apikey, index_name):
        """Connect to Pinecone and assemble the retrieval chain.

        Args:
            apikey: Pinecone API key.
            index_name: Name of the existing Pinecone index to query.
        """
        self.prompts = self.setup_prompt()
        self.pinecone_api_key = apikey
        self.index_name = index_name
        # Mistral-7B-Instruct served via the Hugging Face Inference API,
        # wrapped as a callable by Gradio.
        self.llm = gr.load("models/mistralai/Mistral-7B-Instruct-v0.3")
        self.embeddings = NomicEmbeddings()
        self.vectorstore = self.load_vectorstore()
        self.retriever = self.create_retriever()
        # The chain is stateless, so build it once here instead of
        # reconstructing it on every get_response() call.
        self._chain = (
            RunnableParallel(
                {"context": self.retriever, "question": RunnablePassthrough()}
            )
            | self.prompts
            | self.llm
            | StrOutputParser()
        )

    def load_vectorstore(self):
        """Open the existing Pinecone index as a LangChain vector store."""
        tqdm.write("Established Pinecone Connection")
        return PineconeVectorStore(
            index_name=self.index_name,
            embedding=self.embeddings,
            pinecone_api_key=self.pinecone_api_key,
        )

    def create_retriever(self):
        """Return a retriever view over the vector store (default search params)."""
        tqdm.write("Creating retriever")
        return self.vectorstore.as_retriever()

    def setup_prompt(self):
        """Build the chat prompt template used for answering questions.

        Returns:
            A ChatPromptTemplate expecting ``context`` and ``question``.
        """
        # Prompt wording fixed from the original ("restaruant", lowercase
        # "florida", broken grammar) so the instruction reads cleanly.
        template = """
        You are an assistant specialized in providing restaurant information from the retrieved documents. You are only allowed to answer about restaurants in the state of Florida.
        Use the retrieved information to answer the question accurately. If you don't know the answer based on the provided context,
        simply respond that you don't know. Do not make up any information. You should act on behalf of my company and say the word "I" instead of "based on the provided documentation".
        Context: {context}
        Question: {question}
        """
        tqdm.write("Prompt setup completed")
        return ChatPromptTemplate.from_template(template)

    def get_response(self, user_input):
        """Answer a user question using retrieved context.

        Args:
            user_input: The user's question as a plain string.

        Returns:
            The model's answer, parsed to a string by StrOutputParser.
        """
        return self._chain.invoke(user_input)