Spaces:
Sleeping
Sleeping
import arrow | |
import gradio as gr | |
import os | |
import re | |
import pandas as pd | |
from pathlib import Path | |
from time import sleep | |
from tqdm import tqdm | |
from api_calls import * | |
# Directory that contains this script.
ROOT_DIR = Path(__file__).resolve().parent

# Default option sets offered by the "Target Value" dropdowns, one list per
# target type.
# NOTE(review): the ids and the company names look like independent option
# lists rather than index-aligned pairs — confirm before zipping them together.
default_co_ids = [
    "2330",
    "2317",
    "1301",
    "2303",
    "1101",
    "2311",
    "2002",
    "2412",
]
default_company_names = [
    "台泥",
    "聯電",
    "裕融",
    "大同",
    "台積電",
    "鴻海",
    "中鋼",
    "中華電信",
]
default_industries = [
    "半導體業",
    "水泥工業",
    "電子零組件業",
    "電子通路業",
    "電腦及週邊設備業",
    "其他電子業",
    "金融保險業",
    "文化創意業",
    "鋼鐵工業",
    "通信網路業",
    "電子商務業",
]
def load_default_filter_data(filter_type):
    """Return a dropdown update listing the default options for a filter type.

    Args:
        filter_type: One of ``"co_id"``, ``"company_name"`` or ``"industry"``.

    Returns:
        A Gradio update payload that replaces the output dropdown's choices.

    Raises:
        KeyError: If ``filter_type`` is not one of the supported keys.
    """
    choices_by_type = {
        "co_id": default_co_ids,
        "company_name": default_company_names,
        "industry": default_industries,
    }
    # gr.update() is the version-independent spelling: the per-component
    # Dropdown.update() classmethod was deprecated in Gradio 3.x and removed
    # in 4.0, while gr.update() is available in both.
    # NOTE(review): only the choices are replaced; the currently selected
    # value is left untouched and may not be among the new choices — consider
    # resetting it here.
    return gr.update(choices=choices_by_type[filter_type])
def markdown2html(md: str) -> str:
    """Render Markdown source text to HTML.

    Args:
        md: Markdown text.

    Returns:
        The HTML produced by ``markdown.markdown``.
    """
    # The markdown package is imported at call time rather than at module load.
    from markdown import markdown as render_markdown

    html = render_markdown(md)
    return html
def export_to_txt(output):
    """Write the summary text to a timestamped ``.txt`` file.

    The file is created relative to the current working directory and is named
    ``esg_report_summary-<YYYYMMDDTHHmmss>.txt`` using the current time in the
    Asia/Taipei time zone.

    Args:
        output: Text to save. ``None`` is written as an empty file.

    Returns:
        The name of the file that was written.
    """
    timestamp = arrow.now(tz="Asia/Taipei").format("YYYYMMDDTHHmmss")
    # Build the name once so the written file and the returned name cannot
    # drift apart.
    filename = f"esg_report_summary-{timestamp}.txt"
    # Explicit UTF-8: the summary is Traditional Chinese text, which the
    # platform default encoding is not guaranteed to be able to encode.
    with open(filename, "w", encoding="utf-8") as f:
        f.write("" if output is None else output)
    return filename
def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a chatbot like/dislike event."""
    event_fields = (x.index, x.value, x.liked)
    print(*event_fields)
def add_text(history, text):
    """Append the user's message to the chat history and lock the input box.

    Args:
        history: Current chat history, a list of (user, bot) pairs.
        text: The message the user just submitted.

    Returns:
        A tuple of the extended history (the new turn has no bot reply yet)
        and a cleared, non-interactive Textbox.
    """
    updated_history = history + [(text, None)]
    cleared_locked_box = gr.Textbox(value="", interactive=False)
    return updated_history, cleared_locked_box
def esgsumm_exe(openai_model_name, year, target_type, target_value, tone):
    """Stream an ESG report draft, yielding the text accumulated so far.

    Sends a fixed Traditional-Chinese query to ``api_rag_summ_chain_demo``
    and decodes the streamed response bytes as UTF-8.

    Args:
        openai_model_name: Model name forwarded to the API call.
        year: Report year forwarded to the API call.
        target_type: Filter type forwarded to the API call.
        target_value: Filter value forwarded to the API call.
        tone: Tone setting forwarded to the API call.

    Yields:
        The full answer text received so far, once per non-empty chunk, plus
        one final time if the stream ends with undecodable trailing bytes.
    """
    import codecs

    query = "根據您提供的相關資訊和偏好語氣,以繁體中文生成一份符合GRI標準的報告草稿。報告將包括每個GRI披露項目的標題、相關公司行為的概要,以及公司的具體措施和效果。"
    response = api_rag_summ_chain_demo(openai_model_name, query, year, target_type, target_value, tone)

    # Chunks are cut at fixed byte counts, so a multi-byte UTF-8 character can
    # be split across two chunks. An incremental decoder buffers the partial
    # bytes until the rest arrives instead of failing on (and discarding) the
    # whole chunk; genuinely invalid bytes become U+FFFD.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    full_answer = ""
    for chunk in response.iter_content(chunk_size=32):
        if chunk:
            full_answer += decoder.decode(chunk)
            yield full_answer

    # Flush whatever is still buffered (only non-empty for a truncated stream).
    tail = decoder.decode(b"", final=True)
    if tail:
        full_answer += tail
        yield full_answer
def esgqabot(history, openai_model_name, year, target_type, target_value):
    """Stream the bot's reply to the latest user message into the chat history.

    The last history entry supplies the query; all earlier entries are passed
    to ``api_rag_qa_chain_demo`` as the prior conversation. The streamed
    response bytes are decoded as UTF-8 and appended to the last entry's
    reply slot in place.

    Args:
        history: Chat history whose last entry is ``[user_message, None]``.
            Entries must be mutable (lists), since the reply is written back.
        openai_model_name: Model name forwarded to the API call.
        year: Report year forwarded to the API call.
        target_type: Filter type forwarded to the API call.
        target_value: Filter value forwarded to the API call.

    Yields:
        The same ``history`` list after each non-empty chunk has been appended,
        plus one final time if the stream ends with undecodable trailing bytes.
    """
    import codecs

    query = history[-1][0]
    response = api_rag_qa_chain_demo(openai_model_name, query, year, target_type, target_value, history[:-1])

    # Chunks are cut at fixed byte counts, so a multi-byte UTF-8 character can
    # be split across two chunks. An incremental decoder buffers the partial
    # bytes until the rest arrives instead of failing on (and discarding) the
    # whole chunk; genuinely invalid bytes become U+FFFD.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    history[-1][1] = ""
    for chunk in response.iter_content(chunk_size=32):
        if chunk:
            history[-1][1] += decoder.decode(chunk)
            yield history

    # Flush whatever is still buffered (only non-empty for a truncated stream).
    tail = decoder.decode(b"", final=True)
    if tail:
        history[-1][1] += tail
        yield history
# Custom stylesheet passed to gr.Blocks: centers elements with id "center",
# hides the page footer and recolors links.
css = (
    "\n"
    "#center {text-align: center}\n"
    "footer {visibility: hidden}\n"
    "a {color: rgb(255, 206, 10) !important}\n"
)
# Build the two-tab Gradio UI (report summarization and QA chat) and wire the
# event handlers.
# NOTE(review): the container nesting below was reconstructed from the order of
# the components; confirm it against the intended layout.
with gr.Blocks(css=css, theme=gr.themes.Monochrome(neutral_hue="green", primary_hue="slate")) as demo:
    gr.HTML("<h1>ESG RAG Playground</h1>", elem_id="center")
    gr.Markdown("Made by `Abao`", elem_id="center")
    gr.Markdown("---")

    # ---- Tab: report summarization ----
    with gr.Tab("ESG Report Summarization"):
        gr.HTML("<h2>Report Summarization</h2><p>Summarize report with tone & schema.</p>", elem_id="center")
        with gr.Row():
            with gr.Group():
                gr.Markdown("### Configuration", elem_id="center")
                esgsumm_report_tone = gr.Dropdown(
                    value="精確",
                    label="Tone",
                    choices=["富有創意", "中庸", "精確"],
                )
                esgsumm_openai_model_name = gr.Dropdown(
                    value="gpt-4-turbo-preview",
                    label="OpenAI Model",
                    choices=["gpt-4-turbo-preview", "gpt-3.5-turbo"],
                )
                esgsumm_year = gr.Dropdown(
                    value="111",
                    label="Year",
                    choices=["111", "110", "109"],
                )
                esgsumm_target_type = gr.Dropdown(
                    value="company_name",
                    label="Target Type",
                    choices=["company_name", "industry", "co_id"],
                )
                # Initial choices come from the same list that
                # load_default_filter_data serves for "company_name", so the
                # two cannot drift apart.
                esgsumm_target_value = gr.Dropdown(
                    value="台積電",
                    label="Target Value",
                    choices=list(default_company_names),
                )
                esgsumm_report_gen_button = gr.Button("Generate Report")
            with gr.Column():
                gr.Markdown("## Generate ESG Summarization", elem_id="center")
                # NOTE(review): the prompt-override checkbox and textbox are
                # not passed to any handler in this file.
                with gr.Accordion("Revise Your Prompt", open=False):
                    esgsumm_checkbox_replace = gr.Checkbox(label="Replace with new prompt")
                    esgsumm_prompt_tmpl = gr.Textbox(
                        label="希望用於本次問答的prompt",
                        info="必須使用到的變數:{filtered_data}、{query}",
                        value="",
                        interactive=True,
                    )
                esgsumm_report_output = gr.Textbox(
                    label="Report Output",
                    interactive=False,
                    scale=4,
                )
                esgsumm_report_output_html = gr.HTML()
                esgsumm_download_btn = gr.Button("Export Summary")
                esgsumm_download_file = gr.File(
                    label="Download Summary Text", file_types=[".txt"]
                )

    # ---- Tab: question answering ----
    with gr.Tab("ESG QA"):
        gr.HTML("<h2>ParallelQA (GPT-4 like)</h2><p>Test multiple LLMs at once.</p>", elem_id="center")
        with gr.Row():
            with gr.Group():
                gr.Markdown("### Configuration", elem_id="center")
                esgqa_openai_model_name = gr.Dropdown(
                    value="gpt-4-turbo-preview",
                    label="OpenAI Model",
                    choices=["gpt-4-turbo-preview", "gpt-3.5-turbo"],
                )
                esgqa_year = gr.Dropdown(
                    value="111",
                    label="Year",
                    choices=["111", "110", "109"],
                )
                esgqa_target_type = gr.Dropdown(
                    value="company_name",
                    label="Target Type",
                    choices=["company_name", "industry", "co_id"],
                )
                esgqa_target_value = gr.Dropdown(
                    value="台積電",
                    label="Target Value",
                    choices=list(default_company_names),
                )
            with gr.Column():
                gr.Markdown("## Chat with ESGQABot", elem_id="center")
                # NOTE(review): as in the summarization tab, these prompt
                # controls are not passed to any handler in this file.
                with gr.Accordion("Revise Your Prompt", open=False):
                    esgqa_checkbox_replace = gr.Checkbox(label="Replace with new prompt")
                    esgqa_prompt_tmpl = gr.Textbox(
                        label="希望用於本次問答的prompt",
                        info="必須使用到的變數:{filtered_data}、{query}",
                        value="",
                        interactive=True,
                    )
                esgqa_chatbot = gr.Chatbot(
                    [(None, "我是 ESGQABot\n有什麼能為您服務的嗎?")],
                    elem_id="chatbot",
                    scale=1,
                    height=700,
                    bubble_full_width=False,
                )
                with gr.Row():
                    # NOTE(review): the placeholder mentions uploading an
                    # image, but no upload component exists in this layout.
                    esgqa_chatbot_input = gr.Textbox(
                        scale=4,
                        show_label=False,
                        placeholder="Enter text and press enter, or upload an image",
                        container=False,
                    )
                    esgqa_chat_btn = gr.Button("💬")

    # ---- Events: report summarization ----
    # Changing the target type swaps the options offered for the target value.
    esgsumm_target_type.change(
        load_default_filter_data, [esgsumm_target_type], [esgsumm_target_value]
    )
    # Stream the report into the textbox, then render the final text as HTML.
    esgsumm_report_gen_button.click(
        esgsumm_exe,
        [esgsumm_openai_model_name, esgsumm_year, esgsumm_target_type, esgsumm_target_value, esgsumm_report_tone],
        [esgsumm_report_output],
    ).then(
        markdown2html, [esgsumm_report_output], [esgsumm_report_output_html]
    )
    esgsumm_download_btn.click(
        fn=export_to_txt,
        inputs=[esgsumm_report_output],
        outputs=esgsumm_download_file,
    )

    # ---- Events: question answering ----
    esgqa_target_type.change(
        load_default_filter_data, [esgqa_target_type], [esgqa_target_value]
    )
    # Pressing Enter and clicking the chat button run the same chain:
    # record the user turn and lock the input, stream the reply, unlock.
    # Each registration gets its own api_name so the two endpoints do not
    # collide. NOTE(review): "esgqa_response_1" is what Gradio is expected to
    # auto-assign to a duplicate "esgqa_response" — confirm for the pinned
    # Gradio version if external API clients rely on the second endpoint.
    for register_trigger, endpoint_name in (
        (esgqa_chatbot_input.submit, "esgqa_response"),
        (esgqa_chat_btn.click, "esgqa_response_1"),
    ):
        register_trigger(
            add_text, [esgqa_chatbot, esgqa_chatbot_input], [esgqa_chatbot, esgqa_chatbot_input], queue=False
        ).then(
            esgqabot,
            [esgqa_chatbot, esgqa_openai_model_name, esgqa_year, esgqa_target_type, esgqa_target_value],
            esgqa_chatbot,
            api_name=endpoint_name,
        ).then(
            lambda: gr.Textbox(interactive=True), None, [esgqa_chatbot_input], queue=False
        )
    esgqa_chatbot.like(print_like_dislike, None, None)
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(max_threads=10)