|
import functools

import streamlit as st
import tiktoken
from gtts import gTTS
from IPython.display import Audio, display
from loguru import logger
from pydub import AudioSegment

from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.pdf import (PyPDFLoader, PyMuPDFLoader)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.memory import StreamlitChatMessageHistory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
|
|
|
|
|
def main():
    """Streamlit entry point for the car-manual Q&A chatbot.

    Sidebar: upload manual PDFs + OpenAI key, then press 실행 to build the
    retrieval chain. Main pane: chat UI; each answer is also synthesized to
    speech (gTTS) and played back.
    """
    st.set_page_config(
        page_title="차량용 Q&A 챗봇",
        page_icon=":car:")

    st.title("차량용 Q&A 챗봇 :car:")

    # Per-session state: the retrieval chain, its history, and a "built" flag.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    if "processComplete" not in st.session_state:
        st.session_state.processComplete = None

    with st.sidebar:
        uploaded_files = st.file_uploader("차량 메뉴얼 PDF 파일을 넣어주세요.", type=['pdf'], accept_multiple_files=True)
        openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
        process = st.button("실행")

    if process:
        if not openai_api_key:
            st.info("Open AI키를 입력해주세요.")
            st.stop()
        files_text = get_text(uploaded_files)
        text_chunks = get_text_chunks(files_text)
        vectorstore = get_vectorstore(text_chunks)

        st.session_state.conversation = get_conversation_chain(vectorstore, openai_api_key)
        st.session_state.processComplete = True

    if 'messages' not in st.session_state:
        st.session_state['messages'] = [{"role": "assistant",
                                         "content": "안녕하세요! 주어진 문서에 대해 궁금하신 것이 있으면 언제든 물어봐주세요!"}]

    # Replay the conversation so far on each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    history = StreamlitChatMessageHistory(key="chat_messages")

    if query := st.chat_input("질문을 입력해주세요."):
        st.session_state.messages.append({"role": "user", "content": query})

        with st.chat_message("user"):
            st.markdown(query)

        with st.chat_message("assistant"):
            chain = st.session_state.conversation

            # BUGFIX: guard against chatting before the chain is built
            # (previously `chain({...})` raised TypeError on None).
            if chain is None:
                st.info("먼저 사이드바에서 PDF를 업로드하고 '실행'을 눌러주세요.")
                st.stop()

            with st.spinner("Thinking..."):
                # BUGFIX: the chain call must run *inside* get_openai_callback()
                # — the original entered the callback after the call, so token
                # usage was never captured.
                with get_openai_callback() as cb:
                    result = chain({"question": query})

                st.session_state.chat_history = result['chat_history']
                response = result['answer']
                source_documents = result['source_documents']

                # Synthesize the answer to Korean speech and play it inline.
                tts = gTTS(text=response, lang='ko')
                tts.save('output.mp3')
                audio = AudioSegment.from_file("output.mp3", format="mp3")
                st.audio(audio.export(format='mp3').read(), start_time=0)

                st.markdown(response)
                with st.expander("참고 문서 확인"):
                    # BUGFIX: iterate up to three sources instead of hard
                    # indexing [0..2], which raised IndexError when the
                    # retriever returned fewer documents.
                    for doc in source_documents[:3]:
                        st.markdown(doc.metadata['source'], help=doc.page_content)

        st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_cl100k_tokenizer():
    # Building an encoding is relatively expensive; cache the single instance.
    return tiktoken.get_encoding("cl100k_base")


def tiktoken_len(text):
    """Return the number of cl100k_base tokens in *text*.

    Used as ``length_function`` by the text splitter, so it runs once per
    chunk — the tokenizer is therefore constructed once and reused instead of
    being rebuilt on every call as before.
    """
    return len(_get_cl100k_tokenizer().encode(text))
|
|
|
|
|
def get_text(docs):
    """Persist uploaded files to disk and load the PDFs into documents.

    Args:
        docs: iterable of Streamlit ``UploadedFile`` objects.

    Returns:
        list of langchain documents produced by ``PyMuPDFLoader`` for each
        PDF upload; non-PDF uploads are written to disk but skipped.
    """
    doc_list = []

    for doc in docs:
        file_name = doc.name
        # NOTE(review): writes into the CWD using the client-supplied
        # filename — consider sanitizing or using a temp directory.
        with open(file_name, "wb") as file:
            file.write(doc.getvalue())
        logger.info(f"Uploaded {file_name}")

        # BUGFIX: the original used a substring test ("'.pdf' in name") and
        # extended doc_list outside this branch, which raised NameError when
        # the first file was not a PDF and double-counted documents when a
        # later file was not a PDF.
        if file_name.lower().endswith('.pdf'):
            loader = PyMuPDFLoader(file_name)
            doc_list.extend(loader.load_and_split())

    return doc_list
|
|
|
|
|
def get_text_chunks(text):
    """Split loaded documents into ~1000-token chunks with 100-token overlap.

    Args:
        text: list of langchain documents (output of ``get_text``).

    Returns:
        list of chunked documents, sized by ``tiktoken_len``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        length_function=tiktoken_len,
    )
    return splitter.split_documents(text)
|
|
|
|
|
def get_vectorstore(text_chunks):
    """Embed document chunks and index them in an in-memory FAISS store.

    Uses the Korean sentence-transformer ``jhgan/ko-sroberta-multitask`` on
    CPU with normalized embeddings.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name="jhgan/ko-sroberta-multitask",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    return FAISS.from_documents(text_chunks, embedding_model)
|
|
|
|
|
def get_conversation_chain(vetorestore, openai_api_key):
    """Build a ConversationalRetrievalChain over the given vector store.

    Args:
        vetorestore: FAISS vector store (parameter name kept, sic, for
            backward compatibility with keyword callers).
        openai_api_key: OpenAI API key for the ``gpt-3.5-turbo`` model.

    Returns:
        a chain with MMR retrieval, buffered chat memory, and source
        documents included in each result.
    """
    llm = ChatOpenAI(openai_api_key=openai_api_key, model_name='gpt-3.5-turbo', temperature=0)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        # BUGFIX: was misspelled `vervose=True`, a silently ignored kwarg.
        retriever=vetorestore.as_retriever(search_type='mmr', verbose=True),
        memory=ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer'),
        get_chat_history=lambda h: h,
        return_source_documents=True,
        verbose=True
    )

    return conversation_chain
|
|
|
|
|
# Entry point when executed directly (e.g. `streamlit run <this file>`).
if __name__ == '__main__':
    main()