LongCite / app.py
crazyjames's picture
update
977fb98
raw
history blame
7.21 kB
import subprocess
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
import docx
import PyPDF2
def convert_to_txt(file):
    """Extract the plain text of an uploaded document.

    Args:
        file: filesystem path of the document (render_context passes the
            upload's path). The format is chosen from the text after the last
            "." in the path, compared case-insensitively.

    Returns:
        The document text. Text files are returned verbatim; PDF pages and
        DOCX paragraphs are joined with a blank line between consecutive items.

    Raises:
        gr.Error: if the extension is not one of txt, md, py, pdf, docx.
        UnicodeDecodeError: if a txt/md/py file is not valid UTF-8.
    """
    doc_type = file.split(".")[-1].strip().lower()
    if doc_type in ["txt", "md", "py"]:
        # ``file`` is a path string, so it must be opened first (calling
        # ``.read()`` on the str itself raises AttributeError). Reading raw
        # bytes and decoding keeps line endings exactly as stored.
        with open(file, "rb") as fh:
            data = [fh.read().decode("utf-8")]
    elif doc_type == "pdf":
        pdf_reader = PyPDF2.PdfReader(file)
        # Guard against a falsy extract_text() result (presumably pages with
        # no text layer) so the join below always receives strings.
        data = [page.extract_text() or "" for page in pdf_reader.pages]
    elif doc_type == "docx":
        doc = docx.Document(file)
        data = [p.text for p in doc.paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {doc_type}")
    return "\n\n".join(data)
# Load the checkpoint once, at import time, so every request reuses the same
# tokenizer and model. trust_remote_code=True lets the hub repo supply its own
# model code — presumably where the non-standard ``query_longcite`` method
# used in run_llm comes from (TODO confirm against the model repo).
model_name = "THUDM/LongCite-glm4-9b"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
# CSS prepended to every HTML fragment built by convert_to_html:
#   .reference -> clickable-looking "[n,m]" citation labels after a statement
#   .highlight -> snippets belonging to the currently clicked statement
#   .label     -> section headings ("Answer:", "All citations:", ...)
#   .Bold      -> "Snippet [n]:" captions
#   .statement -> background of the currently clicked statement
html_styles = """<style>
.reference {
color: blue;
text-decoration: underline;
}
.highlight {
background-color: yellow;
}
.label {
font-family: sans-serif;
font-size: 16px;
font-weight: bold;
}
.Bold {
font-weight: bold;
}
.statement {
background-color: lightgrey;
}
</style>\n"""
def process_text(text):
    """Escape ``text`` for embedding in HTML and render newlines as ``<br>``.

    ``&``, ``'``, ``"``, ``<`` and ``>`` become named entities; every newline
    becomes a ``<br>`` tag.
    """
    # translate() maps each source character independently in one pass, so
    # the '&' introduced by an entity is never escaped a second time.
    escapes = str.maketrans({
        '&': '&amp;',
        "'": '&apos;',
        '"': '&quot;',
        '<': '&lt;',
        '>': '&gt;',
        '\n': '<br>',
    })
    return text.translate(escapes)
def convert_to_html(statements, clicked=-1):
    """Render the answer statements and their citation snippets as HTML.

    Args:
        statements: sequence of dicts, each with a ``'statement'`` string and a
            ``'citation'`` list (possibly empty). Each citation dict is read
            for ``'start_sentence_idx'``, ``'end_sentence_idx'``,
            ``'start_char_idx'``, ``'end_char_idx'`` and ``'cite'``.
        clicked: index of the statement to highlight, or -1 for none.

    Returns:
        A 4-tuple ``(answer_html, all_cite_html, clicked_cite_html,
        cite_num2idx)``:
        - answer_html: every statement, each cited one followed by a
          reference label such as ``[1,2]``;
        - all_cite_html: all snippets (highlighting removed), or the message
          "No citation in the answer" when there are none;
        - clicked_cite_html: the snippets of statement ``clicked`` (an empty
          box if it has none), or None if ``clicked`` matches no statement;
        - cite_num2idx: maps each reference label to its statement index.
    """
    html = html_styles + '<br><span class="label">Answer:</span><br>\n'
    all_cite_html = []
    clicked_cite_html = None
    cite_num2idx = {}
    idx = 0  # running snippet number, continued across statements
    for i, js in enumerate(statements):
        statement, citations = process_text(js['statement']), js['citation']
        if clicked == i:
            html += f"""<span class="statement">{statement}</span>"""
        else:
            html += f"<span>{statement}</span>"
        # Reset for every statement: previously this list was only bound when
        # the statement had citations, so a clicked statement without any
        # reused the previous statement's snippets (or hit an unbound name).
        cite_html = []
        if citations:
            idxs = []
            for c in citations:
                idx += 1
                idxs.append(str(idx))
                cite = '[Sentence: {}-{}\t|\tChar: {}-{}]<br>\n<span {}>{}</span>'.format(
                    c['start_sentence_idx'],
                    c['end_sentence_idx'],
                    c['start_char_idx'],
                    c['end_char_idx'],
                    'class="highlight"' if clicked == i else "",
                    process_text(c['cite'].strip()),
                )
                cite_html.append(f"""<span><span class="Bold">Snippet [{idx}]:</span><br>{cite}</span>""")
            all_cite_html.extend(cite_html)
            cite_num = '[{}]'.format(','.join(idxs))
            cite_num2idx[cite_num] = i
            cite_num_html = """ <span class="reference" style="color: blue" id={}>{}</span>""".format(i, cite_num)
            html += cite_num_html
        html += '\n'
        if clicked == i:
            clicked_cite_html = html_styles + """<br><span class="label">Citations of current statement:</span><br><div style="overflow-y: auto; padding: 20px; border: 0px dashed black; border-radius: 6px; background-color: #EFF2F6;">{}</div>""".format("<br><br>\n".join(cite_html))
    # The all-citations panel never shows the clicked-statement highlight.
    all_cite_html = html_styles + """<br><span class="label">All citations:</span><br>\n<div style="overflow-y: auto; padding: 20px; border: 0px dashed black; border-radius: 6px; background-color: #EFF2F6;">{}</div>""".format("<br><br>\n".join(all_cite_html).replace('<span class="highlight">', '<span>') if len(all_cite_html) else "No citation in the answer")
    return html, all_cite_html, clicked_cite_html, cite_num2idx
def render_context(file):
    """Extract the uploaded document's text and show it in the context box.

    Args:
        file: the gr.File upload value. Depending on the Gradio version and
            the component's ``type``, this is a path string or an object
            carrying the path in ``.name``; both are accepted.

    Returns:
        A visible gr.Textbox holding the extracted text.

    Raises:
        gr.Error: if no file path can be obtained, or (via convert_to_txt)
            if the document type is unsupported.
    """
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if not path:
        raise gr.Error("ERROR: no uploaded document")
    context = convert_to_txt(path)
    return gr.Textbox(context, visible=True)
def run_llm(context, query):
    """Answer ``query`` over ``context`` with LongCite and build the result UI.

    Args:
        context: extracted document text (from the context Textbox).
        query: the user's question.

    Returns:
        A dict mapping the output components (statements, answer,
        all_citations, cite_num2idx, citation_choices, clicked_citations) to
        their updated values. The citation Radio is visible only when at
        least one statement carries citations.

    Raises:
        gr.Error: if the context or the query is empty or whitespace-only.
    """
    # Reject whitespace-only input too: e.g. a PDF whose pages yield no text
    # becomes just the "\n\n" page separators, which is truthy but useless,
    # and would otherwise still run the model.
    if not context or not context.strip():
        raise gr.Error("Error: no uploaded document")
    if not query or not query.strip():
        raise gr.Error("Error: no query")
    result = model.query_longcite(context, query, tokenizer=tokenizer, max_input_length=128000, max_new_tokens=1024)
    all_statements = result['all_statements']
    answer_html, all_cite_html, clicked_cite_html, cite_num2idx_dict = convert_to_html(all_statements)
    cite_nums = list(cite_num2idx_dict.keys())
    return {
        statements: gr.JSON(all_statements),
        answer: gr.HTML(answer_html, visible=True),
        all_citations: gr.HTML(all_cite_html, visible=True),
        cite_num2idx: gr.JSON(cite_num2idx_dict),
        citation_choices: gr.Radio(cite_nums, visible=len(cite_nums) > 0),
        # A new answer invalidates any previously shown per-statement panel.
        clicked_citations: gr.HTML(visible=False),
    }
def chose_citation(statements, cite_num2idx, clicked_cite_num):
    """Highlight the statement whose reference label was chosen in the Radio.

    Args:
        statements: the statement list stored by run_llm (gr.JSON value).
        cite_num2idx: mapping from reference label (e.g. "[1,2]") to
            statement index, as stored by run_llm.
        clicked_cite_num: the chosen reference label.

    Returns:
        A dict updating the answer and the clicked-citations panel, or only
        hiding that panel when the choice cannot be resolved.
    """
    # The Radio value may be None or no longer match the current mapping
    # (e.g. when its choices are replaced, or before any answer exists);
    # indexing blindly would raise KeyError/TypeError in the UI.
    if not statements or not cite_num2idx or clicked_cite_num not in cite_num2idx:
        return {clicked_citations: gr.HTML(visible=False)}
    clicked = cite_num2idx[clicked_cite_num]
    answer_html, _, clicked_cite_html, _ = convert_to_html(statements, clicked=clicked)
    return {
        answer: gr.HTML(answer_html, visible=True),
        clicked_citations: gr.HTML(clicked_cite_html, visible=True),
    }
# UI layout: header; upload + question on the left with the extracted text on
# the right; below, the answer and citation selector beside the citation panels.
with gr.Blocks() as demo:
    gr.Markdown(
        """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
LongCite-glm4-9b Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/LongCite-glm4-9b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/LongCite">🌐 GitHub</a> |
<a href="https://arxiv.org/abs/2409.02897">📜 arxiv </a>
</div>
<br>
<div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
If you plan to use it long-term, please consider deploying the model or forking this space yourself.
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=4):
            file = gr.File(label="Upload a document (supported type: pdf, docx, txt, md, py)")
            query = gr.Textbox(label='Question')
            submit_btn = gr.Button("Submit")
        with gr.Column(scale=4):
            context = gr.Textbox(label="Document content", autoscroll=False, placeholder="No uploaded document.", max_lines=10, visible=False)
    # Show the extracted text as soon as a file is uploaded.
    file.upload(render_context, [file], [context])
    with gr.Row():
        with gr.Column(scale=4):
            # Hidden state holders: raw statements and label -> statement index.
            statements = gr.JSON(label="statements", visible=False)
            answer = gr.HTML(label="Answer", visible=True)
            cite_num2idx = gr.JSON(label="cite_num2idx", visible=False)
            citation_choices = gr.Radio(label="Choose citations for details", visible=False, interactive=True)
        with gr.Column(scale=4):
            clicked_citations = gr.HTML(label="Citations of the chosen statement", visible=False)
            all_citations = gr.HTML(label="All citations", visible=False)
    submit_btn.click(run_llm, [context, query], [statements, answer, all_citations, cite_num2idx, citation_choices, clicked_citations])
    citation_choices.change(chose_citation, [statements, cite_num2idx, citation_choices], [answer, clicked_citations])
demo.queue()
demo.launch()