PCFISH's picture
Update app.py
af69459
raw
history blame
No virus
7.75 kB
import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS, Chroma
from langchain.embeddings import HuggingFaceEmbeddings # General embeddings from HuggingFace models.
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from htmlTemplates import css, bot_template, user_template
from langchain.llms import HuggingFaceHub, LlamaCpp, CTransformers # For loading transformer models.
from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
import tempfile # μž„μ‹œ νŒŒμΌμ„ μƒμ„±ν•˜κΈ° μœ„ν•œ λΌμ΄λΈŒλŸ¬λ¦¬μž…λ‹ˆλ‹€.
import os
# PDF λ¬Έμ„œλ‘œλΆ€ν„° ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
def get_pdf_text(pdf_docs):
temp_dir = tempfile.TemporaryDirectory() # μž„μ‹œ 디렉토리λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€.
temp_filepath = os.path.join(temp_dir.name, pdf_docs.name) # μž„μ‹œ 파일 경둜λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€.
with open(temp_filepath, "wb") as f: # μž„μ‹œ νŒŒμΌμ„ λ°”μ΄λ„ˆλ¦¬ μ“°κΈ° λͺ¨λ“œλ‘œ μ—½λ‹ˆλ‹€.
f.write(pdf_docs.getvalue()) # PDF λ¬Έμ„œμ˜ λ‚΄μš©μ„ μž„μ‹œ νŒŒμΌμ— μ”λ‹ˆλ‹€.
pdf_loader = PyPDFLoader(temp_filepath) # PyPDFLoaderλ₯Ό μ‚¬μš©ν•΄ PDFλ₯Ό λ‘œλ“œν•©λ‹ˆλ‹€.
pdf_doc = pdf_loader.load() # ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•©λ‹ˆλ‹€.
return pdf_doc # μΆ”μΆœν•œ ν…μŠ€νŠΈλ₯Ό λ°˜ν™˜ν•©λ‹ˆλ‹€.
# 과제
# μ•„λž˜ ν…μŠ€νŠΈ μΆ”μΆœ ν•¨μˆ˜λ₯Ό μž‘μ„±
def get_text_file(docs):
if docs.type == 'text/plain':
# ν…μŠ€νŠΈ 파일 (.txt)μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜
return [docs.getvalue().decode('utf-8')]
else:
st.warning("Unsupported file type for get_text_file")
def get_csv_file(docs):
if docs.type == 'text/csv':
# CSV 파일 (.csv)μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜
csv_loader = CSVLoader(docs)
csv_data = csv_loader.load()
# CSV 파일의 각 행을 λ¬Έμžμ—΄λ‘œ λ³€ν™˜ν•˜μ—¬ λ°˜ν™˜
return [' '.join(map(str, row)) for row in csv_data]
else:
st.warning("Unsupported file type for get_csv_file")
def get_json_file(docs):
if docs.type == 'application/json':
# JSON 파일 (.json)μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜
json_loader = JSONLoader(docs)
json_data = json_loader.load()
# JSON 파일의 각 ν•­λͺ©μ„ λ¬Έμžμ—΄λ‘œ λ³€ν™˜ν•˜μ—¬ λ°˜ν™˜
return [json.dumps(item) for item in json_data]
else:
st.warning("Unsupported file type for get_json_file")
# λ¬Έμ„œλ“€μ„ μ²˜λ¦¬ν•˜μ—¬ ν…μŠ€νŠΈ 청크둜 λ‚˜λˆ„λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
def get_text_chunks(documents):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len
)
# 각 λ¬Έμ„œμ˜ λ‚΄μš©μ„ λ¦¬μŠ€νŠΈμ— μΆ”κ°€
texts = []
for doc in documents:
if hasattr(doc, 'page_content'):
# λ¬Έμ„œ 객체인 κ²½μš°μ—λ§Œ μΆ”κ°€
texts.append(doc.page_content)
elif isinstance(doc, str):
# λ¬Έμžμ—΄μΈ 경우 κ·ΈλŒ€λ‘œ μΆ”κ°€
texts.append(doc)
# λ‚˜λˆˆ 청크λ₯Ό λ°˜ν™˜
return text_splitter.split_documents(texts)
# ν…μŠ€νŠΈ μ²­ν¬λ“€λ‘œλΆ€ν„° 벑터 μŠ€ν† μ–΄λ₯Ό μƒμ„±ν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
def get_vectorstore(text_chunks):
# OpenAI μž„λ² λ”© λͺ¨λΈμ„ λ‘œλ“œν•©λ‹ˆλ‹€. (Embedding models - Ada v2)
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_documents(text_chunks, embeddings) # FAISS 벑터 μŠ€ν† μ–΄λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€.
return vectorstore # μƒμ„±λœ 벑터 μŠ€ν† μ–΄λ₯Ό λ°˜ν™˜ν•©λ‹ˆλ‹€.
def get_conversation_chain(vectorstore):
print(f"DEBUG: session_state.conversation before initialization: {st.session_state.conversation}")
try:
if st.session_state.conversation is None:
gpt_model_name = 'gpt-3.5-turbo'
llm = ChatOpenAI(model_name=gpt_model_name)
# λŒ€ν™” 기둝을 μ €μž₯ν•˜κΈ° μœ„ν•œ λ©”λͺ¨λ¦¬λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€.
memory = ConversationBufferMemory(
memory_key='chat_history', return_messages=True)
# λŒ€ν™” 검색 체인을 μƒμ„±ν•©λ‹ˆλ‹€.
conversation_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=vectorstore.as_retriever(),
memory=memory
)
st.session_state.conversation = conversation_chain
except Exception as e:
print(f"Error during conversation initialization: {e}")
print(f"DEBUG: session_state.conversation after initialization: {st.session_state.conversation}")
return st.session_state.conversation if st.session_state.conversation else ConversationalRetrievalChain()
# μ‚¬μš©μž μž…λ ₯을 μ²˜λ¦¬ν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
def handle_userinput(user_question):
# λŒ€ν™” 체인을 μ‚¬μš©ν•˜μ—¬ μ‚¬μš©μž μ§ˆλ¬Έμ— λŒ€ν•œ 응닡을 μƒμ„±ν•©λ‹ˆλ‹€.
response = st.session_state.conversation({'question': user_question})
# λŒ€ν™” 기둝을 μ €μž₯ν•©λ‹ˆλ‹€.
st.session_state.chat_history = response['chat_history']
for i, message in enumerate(st.session_state.chat_history):
if i % 2 == 0:
st.write(user_template.replace(
"{{MSG}}", message.content), unsafe_allow_html=True)
else:
st.write(bot_template.replace(
"{{MSG}}", message.content), unsafe_allow_html=True)
def main():
load_dotenv()
st.set_page_config(page_title="Chat with multiple Files :)",
page_icon=":books:")
st.write(css, unsafe_allow_html=True)
if "conversation" not in st.session_state or st.session_state.conversation is None:
st.session_state.conversation = None
st.session_state.chat_history = None
st.header("Chat with multiple Files :")
user_question = st.text_input("Ask a question about your documents:")
if user_question:
handle_userinput(user_question)
with st.sidebar:
openai_key = st.text_input("Paste your OpenAI API key (sk-...)")
if openai_key:
os.environ["OPENAI_API_KEY"] = openai_key
st.subheader("Your documents")
docs = st.file_uploader(
"Upload your documents here and click on 'Process'", accept_multiple_files=True)
if st.button("Process"):
with st.spinner("Processing"):
# λ¬Έμ„œμ—μ„œ μΆ”μΆœν•œ ν…μŠ€νŠΈλ₯Ό 담을 리슀트
doc_list = []
for file in docs:
if file.type == 'text/plain':
# .txt 파일의 경우
doc_list.extend(get_text_file(file))
elif file.type == 'text/csv':
# .csv 파일의 경우
doc_list.extend(get_csv_file(file))
elif file.type == 'application/json':
# .json 파일의 경우
doc_list.extend(get_json_file(file))
elif file.type in ['application/octet-stream', 'application/pdf']:
# .pdf 파일의 경우
doc_list.extend(get_pdf_text(file))
# ν…μŠ€νŠΈ 청크둜 λ‚˜λˆ„κΈ°
text_chunks = get_text_chunks(doc_list)
# 벑터 μŠ€ν† μ–΄ 생성
vectorstore = get_vectorstore(text_chunks)
# λŒ€ν™” 체인 생성
st.session_state.conversation = get_conversation_chain(vectorstore)
if __name__ == '__main__':
main()