import os, tempfile
import re
import json
from pathlib import Path

import streamlit as st

st.set_page_config(page_title="QA_AIボット+要約")
st.title("Q&A-AIボット & 資料要約\n Demo running on just 4Core CPU/4GB RAM")
st.text("【取扱説明書等をアップロードすればその文書に関する質問にAIが回答します。資料の要約もできます。】")
st.text("""※アップロードされたファイルはサーバーには保存されず、ブラウザを閉じるとバイナリデータも自動的に消去されます。""")

from typing import Any, List
from datetime import datetime

from llama_index import (
    download_loader,
    VectorStoreIndex,
    ServiceContext,
    StorageContext,
    SimpleDirectoryReader,
)
from llama_index.postprocessor import SentenceEmbeddingOptimizer
from llama_index.prompts.prompts import QuestionAnswerPrompt
from llama_index.readers import WikipediaReader, Document

from langchain_community.chat_models import ChatOllama
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.callbacks.base import BaseCallbackHandler
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter

from pysummarization.nlpbase.auto_abstractor import AutoAbstractor
from pysummarization.tokenizabledoc.mecab_tokenizer import MeCabTokenizer
from pysummarization.abstractabledoc.top_n_rank_abstractor import TopNRankAbstractor
from pysummarization.abstractabledoc.std_abstractor import StdAbstractor
from pysummarization.nlp_base import NlpBase
from pysummarization.similarityfilter.tfidf_cosine import TfIdfCosine
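
# Chat history is kept in st.session_state. StreamHandler is a LangChain callback
# that opens an assistant chat bubble when the LLM starts, appends each generated
# token to it, and stores the finished answer back into the session-state history.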

if "messages" not in st.session_state:
    st.session_state.messages = []


class StreamHandler(BaseCallbackHandler):
    def __init__(self, initial_text="お調べしますので少々お待ち下さい。\n\n"):
        self.initial_text = initial_text
        self.text = initial_text
        self.flag = True

    def on_llm_start(self, *args: Any, **kwargs: Any):
        self.text = self.initial_text
        with st.chat_message("assistant"):
            self.container = st.empty()
            self.container.markdown(self.text + " " + " ")
        print("LLM start: ", datetime.now())

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if self.flag:
            print("Stream start: ", datetime.now())
            self.flag = False
        self.text += token
        self.container.markdown(self.text)

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        st.session_state.messages.append({
            "role": "assistant",
            "content": self.text
        })
        print("LLM end: ", datetime.now())
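
# Japanese RAKE keyword extraction: used to pull search keywords out of the
# question text for the Wikipedia / DuckDuckGo fallback searches.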

from rake_ja import JapaneseRake
from rake_ja.tokenizer import Tokenizer

tok = Tokenizer()
ja_rake = JapaneseRake()

import chromadb
from llama_index.vector_stores import ChromaVectorStore
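
# WikipediaReader subclass that forces Japanese Wikipedia and returns the raw
# page text for each requested page title.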

import wikipedia


class JaWikipediaReader(WikipediaReader):
    def load_wiki(self, pages: List[str], **load_kwargs: Any) -> List[str]:
        """Load page content from Japanese Wikipedia.

        Args:
            pages (List[str]): List of page titles to read.
        """
        wikipedia.set_lang('ja')
        results = []
        for page in pages:
            page_content = wikipedia.page(page, **load_kwargs).content
            results.append(page_content)
        return results
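
# DuckDuckGo text search (region jp-jp, past year) used as the last-resort
# source when neither the uploaded document nor Wikipedia yields an answer.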

from duckduckgo_search import DDGS

maxsearch_results = 3


def search_general(input_text):
    with DDGS() as ddgs:
        results = [r for r in ddgs.text(f"{input_text}", region="jp-jp", timelimit="y", max_results=maxsearch_results, safesearch="off")]
    print(results)
    return results


STORAGE_DIR = "./storage/"
TEMP_DIR = "./temp_data/"
HISTORY_DIR = "./history/"
SUMMARY_DIR = "./summary/"

os.makedirs(STORAGE_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(HISTORY_DIR, exist_ok=True)
os.makedirs(SUMMARY_DIR, exist_ok=True)
# The Wikipedia / DuckDuckGo fallbacks below write into these subfolders.
os.makedirs(os.path.join(TEMP_DIR, "wiki"), exist_ok=True)
os.makedirs(os.path.join(TEMP_DIR, "duck"), exist_ok=True)
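
# Thin wrapper around the llama_index "PDFReader" loader fetched with
# download_loader(); load_data() returns llama_index Documents (one per page).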


class PDFReader:
    def __init__(self):
        self.pdf_reader = download_loader("PDFReader", custom_path="local_dir")()

    def load_data(self, file_name):
        return self.pdf_reader.load_data(file=Path(file_name))

ollama_url = "http://localhost:11434"
ollama_remote = "https://ai.pib.co.jp"
llamacpp_url = "http://localhost:8000/v1"
LM_Studio_url = "http://localhost:1234/v1"
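
# QAResponseGenerator wires a locally served Ollama chat model (streamed through
# StreamHandler) to a HuggingFace sentence-embedding model. Each supported model
# tag gets its own prompt template; {query_str} and, where present, {context_str}
# are filled in via llama_index's QuestionAnswerPrompt.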

class QAResponseGenerator:
    def __init__(self, selected_model, pdf_reader, device=None):
        stream_handler = StreamHandler()

        self.llm = ChatOllama(base_url=ollama_url, model=selected_model, streaming=True, callbacks=[stream_handler], verbose=True)
        self.pdf_reader = pdf_reader

        if selected_model == "llama3":
            self.QA_PROMPT_TMPL = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nあなたは日本人のコールセンター管理者です。次の質問に日本語で回答してください。<|eot_id|><|start_header_id|>user<|end_header_id|>\n{query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|><|eot_id|>"

        if selected_model == "Elyza":
            self.QA_PROMPT_TMPL = ("""<s>[INST] <<SYS>>#あなたは誠実で優秀な日本人のコールセンター管理者です。<</SYS>>{query_str} [/INST]""")
        # Refine template shared by the Wikipedia / web-search fallback engines.
        self.CHAT_REFINE_PROMPT_TMPL_MSGS = ("""<s>[INST] <<SYS>>
あなたは、既存の回答を改良する際に2つのモードで厳密に動作するQAシステムのエキスパートです。
1. 新しいコンテキストを使用して元の回答を**書き直す**。
2. 新しいコンテキストが役に立たない場合は、元の回答を**繰り返す**。
回答内で元の回答やコンテキストを直接参照しないでください。
疑問がある場合は、元の答えを繰り返してください。
New Context: {context_msg}
<</SYS>>
Query: {query_str}
Original Answer: {existing_answer}
New Answer:
[/INST]""")

        if selected_model == "swallow":
            self.QA_PROMPT_TMPL = '### 指示:{query_str}\n ### 応答:'

        if selected_model == "mist300":
            self.QA_PROMPT_TMPL = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{query_str}\n[SEP]\n応答:\n'

        if selected_model == "TinyLlama":
            self.QA_PROMPT_TMPL = "<|system|>\nあなたは日本人です。</s>\n<|user|>\n{query_str}</s>\n<|assistant|>"

        if selected_model == "tinycodellamajp":
            self.QA_PROMPT_TMPL = "<|im_start|>user\n{query_str}<|im_end|>\n<|im_start|>assistant\n"

        if selected_model == "tinyllamamoe":
            self.QA_PROMPT_TMPL = "<|システム|>\nあなたは日本人のコールセンターの管理者です。参考情報を元に、日本語で質問に回答してください。</s>\n<|ユーザー|>\n{query_str}</s>\n<|アシスタント|>"

        if selected_model == "phillama":
            self.QA_PROMPT_TMPL = "<|system|>\nあなたは日本人のコールセンターの管理者です。参考情報を元に質問に日本語で回答してください。<|end|><|user|>{query_str}<|end|><|assistant|>"

        if selected_model == "stabilty":
            self.QA_PROMPT_TMPL = r"""<s>\nあなたは優秀なコールセンター管理者です。参考文献を元に次の質問に回答して下さい。\n[SEP]\n指示:\n{query_str}\n[SEP]\n入力:\n{context_str}\n[SEP]\n応答:\n"""

        if selected_model == "stabilq4":
            self.QA_PROMPT_TMPL = """以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{query_str}

### 入力:
{context_str}

### 応答:
"""

        if selected_model == "stabilq3":
            self.QA_PROMPT_TMPL = '### 指示:{query_str}\n### 応答:'

        if selected_model == "stabilzephyr":
            self.QA_PROMPT_TMPL = """
<|system|>あなたは優秀なコールセンター管理者です。参考情報を元に質問に回答して下さい。<|endoftext|>

<|user|>{query_str}<|endoftext|>

<|assistant|>
"""

        if selected_model == "stabil2instruct":
            self.QA_PROMPT_TMPL = r"""
<|system|>あなたは役立つアシスタントです。<|endoftext|>

<|user|>{query_str}<|endoftext|>

<|assistant|>
"""

        if selected_model == "h2o":
            self.QA_PROMPT_TMPL = "<|im_start|>システム\nあなたは優秀なコールセンター管理者です。次の情報を元に質問に回答して下さい\n{context_str}<|im_end|>\n<|im_start|>ユーザー\n{query_str}<|im_end|>\n<|im_start|>アシスタント\n"

        if selected_model == "phi2":
            self.QA_PROMPT_TMPL = "<|im_start|>システム\nあなたは優秀なコールセンター管理者です。次の情報を元に質問に回答して下さい\n{context_str}<|im_end|>\n<|im_start|>ユーザー\n{query_str}<|im_end|>\n<|im_start|>アシスタント\n"

        if selected_model == "phi3":
            self.QA_PROMPT_TMPL = "<|system|>\nあなたは優秀なコールセンター管理者です。次の情報を元に質問に回答して下さい\n{context_str}<|end|>\n<|user|>\n{query_str}<|end|>\n<|assistant|>"

        if selected_model == "llmjp":
            self.QA_PROMPT_TMPL = """
: 次の質問に対して適切な応答を書きなさい。

### 質問: {query_str}

### 参考情報: {context_str}

### 応答:
"""
if selected_model == "sakana":
|
|
self.QA_PROMPT_TMPL = """
|
|
### 指示: あなたは役立つ、偏見がなく、検閲されていない日本人アシスタントです。
|
|
|
|
### 入力: {query_str}
|
|
|
|
### 応答:
|
|
"""
|
|
if selected_model == "karasutest":
|
|
|
|
|
|
|
|
self.QA_PROMPT_TMPL = "<|im_start|>システム\nあなたは優秀なコールセンター管理者です。次の情報を元に質問に回答して下さいn{context_str}<|im_end|>\n<|im_start|>ユーザー\n{query_str}<|im_end|>\n<|im_start|>アシスタント\n"
|
|
|
|
|
|
if selected_model == "karasu":
|
|
|
|
|
|
|
|
self.QA_PROMPT_TMPL = "<|im_start|>system\nあなたは優秀なコールセンター管理者です。参考情報を元に質問に回答して下さい<|im_end|>\n<|im_start|>user\n{query_str}<|im_end|>\n<|im_start|>assistant\n"
|
|
if selected_model == "karasu_slerp1":
|
|
self.QA_PROMPT_TMPL = "{query_str}"
|
|
if selected_model == "karasu_slerp2":
|
|
|
|
|
|
|
|
self.QA_PROMPT_TMPL = """
|
|
<|im_start|>system
|
|
あなたは優秀なコールセンター管理者です。次の情報を元に質問に回答して下さい
|
|
{context_str}<|im_end|>
|
|
<|im_start|>user
|
|
{query_str}<|im_end|>
|
|
<|im_start|>assistant
|
|
"""
|
|
if selected_model == "suzume":
|
|
|
|
self.QA_PROMPT_TMPL = "{query_str}"
|
|
if selected_model == "line":
|
|
self.QA_PROMPT_TMPL = 'ユーザー: {query_str} システム: '
|
|
if selected_model == "mist-merge":
|
|
self.QA_PROMPT_TMPL = r"""<s>\nあなたは優秀なコールセンター管理者です。参考文献を元に次の質問に回答して下さい。\n[SEP]\n指示:\n{query_str}\n[SEP]\n入力:\n{context_str}\n[SEP]\n応答:\n"""
|
|
|
|
if selected_model == "rinna1":
|
|
self.QA_PROMPT_TMPL = r"""<s>\nあなたは優秀なコールセンター管理者です。参考文献を元に次の質問に回答して下さい。\n[SEP]\n指示:\n{query_str}\n[SEP]\n入力:\n{context_str}\n[SEP]\n応答:\n"""
|
|
|
|
if selected_model == "stockmark":
|
|
|
|
self.QA_PROMPT_TMPL ="""
|
|
#「コンテキスト情報」を元に、質問に回答してください。
|
|
|
|
###質問:{query_str}
|
|
|
|
###「コンテキスト情報」:{context_str}
|
|
|
|
###回答:
|
|
"""
|
|
if selected_model == "rinna":
|
|
self.QA_PROMPT_TMPL ='ユーザー: {query_str} システム: '
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

        if device is None:
            device = "cpu"
        self.device = device

        self.embed_model = HuggingFaceEmbeddings(model_name="cheonboy/sentence_embedding_japanese", model_kwargs={"device": self.device})

        self.service_context = ServiceContext.from_defaults(llm=self.llm, embed_model=self.embed_model)
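
    # generate() answers a question in three stages:
    #   1) RAKE-extract keywords from the question,
    #   2) query the uploaded document through a Chroma-backed vector index,
    #   3) on failure fall back to Japanese Wikipedia, then to DuckDuckGo web search.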

    def generate(self, question, file_name, uploaded_file, maxsearch_results):
        print(question)
        text = question

        tokens = tok.tokenize(text)
        ja_rake.extract_keywords_from_text(tokens)
        ranked_phrases = ja_rake.get_ranked_phrases_with_scores()
        print("ranked_phrases: ", ranked_phrases)
        # ranked_phrases is a list of (score, phrase) tuples.
        keyphrases = [x[0] for x in ranked_phrases], [x[1] for x in ranked_phrases]
        keywords = [x[1] for x in ranked_phrases]
        print("keyphrases: ", keyphrases, "keyphrases[0]: ", keyphrases[0])
        print("keywords: ", keywords)
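
        # Try to reuse an existing Chroma collection for this file; if none exists,
        # parse the uploaded PDF and build a new collection-backed vector index.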

        try:
            db2 = chromadb.PersistentClient(path="./chroma_db")
            chroma_collection = db2.get_collection(file_name)
            print("chroma collection count: ", chroma_collection.count())

            vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
            pdf_index = VectorStoreIndex.from_vector_store(
                vector_store,
                service_context=self.service_context,
            )
            print("load existing file..")
        except Exception:
            print("load data from pdf file")
            pdf_documents = self.pdf_reader.load_data(file_name)

            db = chromadb.PersistentClient(path="./chroma_db")
            chroma_collection = db.get_or_create_collection(
                name=uploaded_file.name,
                metadata={"hnsw:space": "cosine"}
            )
            print("chroma collection count: ", chroma_collection.count())
            vector_store = ChromaVectorStore(
                chroma_collection=chroma_collection,
            )
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            pdf_index = VectorStoreIndex.from_documents(
                pdf_documents, storage_context=storage_context, service_context=self.service_context
            )
            print("save new file..")

        pdf_engine = pdf_index.as_query_engine(
            similarity_top_k=2,
            text_qa_template=QuestionAnswerPrompt(self.QA_PROMPT_TMPL),
            required_keywords=[keywords[0]] if keyphrases[0] and keyphrases[0][0] > 0.5 else [],
            node_postprocessors=[SentenceEmbeddingOptimizer(embed_model=self.service_context.embed_model, threshold_cutoff=0.4)],
        )

        try:
            pdf_result = pdf_engine.query(question)

            print("Num of nodes:", len(pdf_result.source_nodes))
            if len(pdf_result.source_nodes) == 1:
                print("source_nodes.SCORE: ", pdf_result.source_nodes[0].score)
            else:
                print("source_nodes.SCORE: ", pdf_result.source_nodes[0].score, pdf_result.source_nodes[1].score)

            # Drop the collection contents so the uploaded data is not kept on the server.
            print("count before", chroma_collection.count())
            doc_to_update = chroma_collection.get()
            chroma_collection.delete(ids=[id for id in doc_to_update["ids"]])

            return pdf_result.response, pdf_result.source_nodes[0].text.strip().replace(" ", "")

        except ValueError as e:
            print("PDF Error: ", e)
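
        # Fallback 1: join the extracted keywords into a single search string and
        # query Japanese Wikipedia, indexing the fetched page text on the fly.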

        count = len(keywords)
        search_words = ""
        for index in range(count):
            search_words += " " + keywords[index]
        print("search_words: ", search_words)
        try:
            wikidocuments = JaWikipediaReader().load_wiki(pages=[search_words])
            strwikidocuments = "".join(map(str, wikidocuments))

            with open("./temp_data/wiki/wikisearch.txt", "w") as fp:
                fp.write(strwikidocuments)
            wiki_documents = SimpleDirectoryReader("./temp_data/wiki").load_data()
            wiki_index = VectorStoreIndex.from_documents(wiki_documents, service_context=self.service_context)
            wiki_index.storage_context.persist(persist_dir=f"{STORAGE_DIR}" + "wikisearch.txt")
            wiki_engine = wiki_index.as_query_engine(
                similarity_top_k=2,
                refine_template=QuestionAnswerPrompt(self.CHAT_REFINE_PROMPT_TMPL_MSGS),
            )

            wiki_result = wiki_engine.query(question)

            return wiki_result.response, wiki_result.get_formatted_sources(1000)
        except Exception as e:
            print("Wiki Error: ", e)
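
        # Fallback 2: DuckDuckGo web search; the raw JSON results are written to a
        # text file, indexed, and queried the same way as the Wikipedia text.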

        try:
            search_results = search_general(search_words)

            with open("./temp_data/duck/ducksearch.txt", "w") as fp:
                contents = ""
                for index in range(min(maxsearch_results, len(search_results))):
                    contents += json.dumps(search_results[index]).encode().decode('unicode-escape')
                fp.write(contents)
            web_documents = SimpleDirectoryReader("./temp_data/duck").load_data()
            web_index = VectorStoreIndex.from_documents(web_documents, service_context=self.service_context)
            web_index.storage_context.persist(persist_dir=f"{STORAGE_DIR}" + "ducksearch.txt")
            web_engine = web_index.as_query_engine(
                similarity_top_k=3,
                refine_template=QuestionAnswerPrompt(self.CHAT_REFINE_PROMPT_TMPL_MSGS),
                required_keywords=[search_words],
            )

            web_result = web_engine.query(question)

            return web_result.response, web_result.get_formatted_sources(1000)
        except Exception as e:
            print("Duck Error: ", e)

        return "現在Web検索が停止中。", "Wikipediaの検索結果で代替します。"


def save_uploaded_file(uploaded_file, save_dir):
    try:
        with open(os.path.join(save_dir, uploaded_file.name), "wb") as f:
            f.write(uploaded_file.getvalue())
        return True
    except Exception as e:
        st.error(f"Error: {e}")
        return False


def upload_pdf_file():
    uploaded_file = st.sidebar.file_uploader("ファイルアップロード", type=["pdf", "txt"])
    print("uploaded_file", uploaded_file)
    if uploaded_file is not None:
        st.success(f"ファイル {uploaded_file.name} のアップロードが完了しました.")
    return uploaded_file
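
# main() drives the Streamlit UI: upload a PDF/TXT file, pick an Ollama model tag,
# then either ask questions against the document (with Wikipedia/DuckDuckGo
# fallback) or run the pysummarization-based summary of the whole file.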


def main():
    pdf_reader = PDFReader()

    uploaded_file = st.sidebar.file_uploader("ファイルアップロード", type=["pdf", "txt"])
    if uploaded_file is not None:
        st.sidebar.success(f"{uploaded_file.name} のアップロード完了")
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_file.write(uploaded_file.read())

    selected_model = st.sidebar.selectbox("生成AIモデルの選択", ["標準", "次候補", "試行版", "llmjp"])

    if selected_model == "標準":
        selected_model = "line"
    elif selected_model == "次候補":
        selected_model = "karasu_slerp2"
    elif selected_model == "試行版":
        selected_model = "mist300"
    elif selected_model == "llmjp":
        selected_model = "llmjp"

    choice = st.radio("参照情報を表示:", ["表示する", "表示しない"])
    question = st.text_input("質問入力")
    response_generator = QAResponseGenerator(selected_model, pdf_reader)

    submit_question = st.button("質問")
    clear_chat = st.sidebar.button("履歴消去")
    st.session_state.last_updated = ""
    st.session_state.last_updated_json = []

    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []

    if clear_chat:
        st.session_state["chat_history"] = []
        st.session_state.last_updated = ""
        st.session_state.last_updated_json = []

    if submit_question:
        print("pushed question button!")
        if question:
            response, source = response_generator.generate(question, tmp_file.name, uploaded_file, maxsearch_results)

            st.session_state["chat_history"].append({"user": question})
            st.session_state["chat_history"].append({"assistant": response})
            if choice == "表示する":
                source = source.replace("page_label", "該当ページ番号").replace("file_name", " ファイル名:" + uploaded_file.name)
                response = f"\n\n参照した情報は次の通りです:\n\n{source}"

            with st.chat_message("assistant"):
                st.markdown(response)

            st.session_state.last_updated += json.dumps(st.session_state["chat_history"], indent=2, ensure_ascii=False)
            with open("./history/chat_history.txt", "w") as o:
                print(st.session_state.last_updated, sep="", file=o)
            with open("./history/chat_history.txt", "r") as f:
                history_str = f.read()
            history_json = json.loads(history_str)

            with open("./history/chat_history.json", "w") as o:
                json.dump(history_json, o, ensure_ascii=False)
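
    # Summarization: extract the full text (per page for PDF, per chunk for TXT),
    # then produce three pysummarization variants: a plain extractive summary,
    # a TF-IDF-cosine similarity-filtered summary, and a top-N-rank summary.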

    elif st.button("要約開始"):
        stream_handler = StreamHandler()
        try:
            with st.spinner('要約中...'):

                if uploaded_file.name.endswith(".pdf"):
                    pdf_reader = download_loader("PDFReader", custom_path="local_dir")()

                    docs = pdf_reader.load_data(file=tmp_file.name)
                    alltext = ""
                    all_count = len(docs)
                    for count in range(all_count):
                        text = docs[count].text
                        alltext += text

                elif uploaded_file.name.endswith(".txt"):
                    loader = TextLoader(tmp_file.name)
                    documents = loader.load()
                    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
                    docs = text_splitter.split_documents(documents)
                    allcount = len(docs)
                    alltext = ""
                    for n in range(allcount):
                        text = docs[n].page_content
                        alltext += text

                # Plain extractive summary
                auto_abstractor = AutoAbstractor()
                auto_abstractor.tokenizable_doc = MeCabTokenizer()
                auto_abstractor.delimiter_list = ["。", "\n"]
                abstractable_doc = StdAbstractor()
                result_dict = auto_abstractor.summarize(alltext, abstractable_doc)
                summary1 = "Original word count(原文文字数): " + str(len(alltext)) + "\n\n"
                for x in result_dict["summarize_result"]:
                    summary1 += x + "\n"
                summary1 = summary1 + "\nNormal Summary word count(標準要約文字数): " + str(len(summary1)) + " compression ratio(圧縮率): " + '{:.0%}'.format(len(summary1) / len(alltext))

                # Similarity-filtered summary (drops near-duplicate sentences)
                similarity_limit = 0.8
                nlp_base = NlpBase()
                nlp_base.tokenizable_doc = MeCabTokenizer()
                similarity_filter = TfIdfCosine()
                similarity_filter.nlp_base = nlp_base
                similarity_filter.similarity_limit = similarity_limit
                result_dict2 = auto_abstractor.summarize(alltext, abstractable_doc, similarity_filter)
                summary2 = ''
                for sentence in result_dict2["summarize_result"]:
                    summary2 += sentence + "\n"
                summary2 = summary2 + "\nFiltered Summary word count(フィルタ要約文字数): " + str(len(summary2)) + " compression ratio(圧縮率): " + '{:.0%}'.format(len(summary2) / len(alltext))

                # Top-N rank summary
                auto_abstractor = AutoAbstractor()
                auto_abstractor.tokenizable_doc = MeCabTokenizer()
                auto_abstractor.delimiter_list = ["。", "\n"]
                abstractable_doc = TopNRankAbstractor()
                result_dict = auto_abstractor.summarize(alltext, abstractable_doc)
                summary3 = ''
                for x in result_dict["summarize_result"]:
                    summary3 += x + "\n"
                summary3 = summary3 + "\nTopNRank word count(トップNランク要約文字数): " + str(len(summary3)) + " compression ratio(圧縮率): " + '{:.0%}'.format(len(summary3) / len(alltext))

            st.success("[pysum_normal]\n" + summary1)
            st.success("[pysum_filtered]\n" + summary2)
            st.success("[pysum_topnrank]\n" + summary3)
            st.download_button('要約1をダウンロード', summary1, file_name=uploaded_file.name[:uploaded_file.name.find('.')] + '_summary1.txt')
            st.download_button('要約2をダウンロード', summary2, file_name=uploaded_file.name[:uploaded_file.name.find('.')] + '_summary2.txt')
            st.download_button('要約3をダウンロード', summary3, file_name=uploaded_file.name[:uploaded_file.name.find('.')] + '_summary3.txt')

        except Exception as e:
            if 'NoneType' in str(e):
                st.success("ファイルがアップロードされてません。先にアップロードして下さい。")
            else:
                st.exception(f"An error occurred: {e}")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        if "'tmp_file' referenced before assignment" in str(e):
            st.success("ファイルがアップロードされてません。先にアップロードして下さい。")
        else:
            st.exception(f"An error occurred: {e}")