# Spaces: Running on Zero
import os
import subprocess

import docx
import gradio as gr
import PyPDF2
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
def convert_to_txt(file):
    """Extract the plain text of a document stored on disk.

    Args:
        file: Filesystem path (``str`` or ``os.PathLike``) of the document.
            The extension selects the parser: ``txt``/``md``/``py`` are
            decoded as UTF-8, ``pdf`` is read with PyPDF2 and ``docx`` with
            python-docx.

    Returns:
        The document text. PDF pages and DOCX paragraphs are joined with a
        blank line between them.

    Raises:
        gr.Error: If the extension is not one of the supported types.
    """
    path = os.fspath(file)
    # Lower-case the extension so "REPORT.PDF" is handled like "report.pdf".
    doc_type = os.path.splitext(path)[1].lstrip(".").strip().lower()
    if doc_type in ("txt", "md", "py"):
        # The argument is a path, not an open handle, so it has to be
        # opened here before it can be read.
        with open(path, "rb") as handle:
            data = [handle.read().decode("utf-8")]
    elif doc_type == "pdf":
        pdf_reader = PyPDF2.PdfReader(path)
        # `or ""` guards against a page that yields no text at all.
        data = [page.extract_text() or "" for page in pdf_reader.pages]
    elif doc_type == "docx":
        data = [paragraph.text for paragraph in docx.Document(path).paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {doc_type}")
    return "\n\n".join(data)
# Checkpoint identifier; tokenizer and model are loaded once at import time
# and reused by run_llm for every request.
model_name = "THUDM/LongCite-glm4-9b"
# NOTE: trust_remote_code=True lets transformers execute Python code shipped
# with the checkpoint repository; only use it with a repository you trust.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto',
)
# CSS prepended to every HTML fragment produced by convert_to_html.
html_styles = """<style>
    .reference {
        color: blue;
        text-decoration: underline;
    }
    .highlight {
        background-color: yellow;
    }
    .label {
        font-family: sans-serif;
        font-size: 16px;
        font-weight: bold;
    }
    .Bold {
        font-weight: bold;
    }
    .statement {
        background-color: lightgrey;
    }
</style>\n"""
# Translation table mapping HTML-significant characters to their entities and
# newlines to <br>. A single-pass translate cannot double-escape the '&' that
# it introduces itself.
_HTML_ESCAPES = str.maketrans({
    '&': '&amp;',
    '\'': '&#39;',
    '"': '&quot;',
    '<': '&lt;',
    '>': '&gt;',
    '\n': '<br>',
})


def process_text(text):
    """Escape *text* for safe embedding in HTML.

    Args:
        text: Raw text (model output or a document excerpt).

    Returns:
        The text with ``& ' " < >`` replaced by HTML entities and each
        newline replaced by ``<br>``.
    """
    return text.translate(_HTML_ESCAPES)
def convert_to_html(statements, clicked=-1):
    """Render answer statements and their citations as HTML fragments.

    Args:
        statements: Iterable of dicts, each with a ``'statement'`` string and
            a ``'citation'`` list. Every citation dict provides
            ``'start_sentence_idx'``, ``'end_sentence_idx'``,
            ``'start_char_idx'``, ``'end_char_idx'`` and ``'cite'``.
        clicked: Index of the statement to emphasise, or ``-1`` for none.

    Returns:
        A 4-tuple ``(html, all_cite_html, clicked_cite_html, cite_num2idx)``:
        the answer HTML, the HTML listing every citation, the HTML listing
        the citations of the clicked statement (``None`` when no statement
        matches ``clicked``), and a dict mapping a citation label such as
        ``"[1,2]"`` to the index of the statement it belongs to.
    """
    panel = (
        '<div style="overflow-y: auto; padding: 20px; border: 0px dashed black; '
        'border-radius: 6px; background-color: #EFF2F6;">{}</div>'
    )
    html = html_styles + '<br><span class="label">Answer:</span><br>\n'
    all_cite_html = []
    clicked_cite_html = None
    cite_num2idx = {}
    idx = 0  # running snippet number across the whole answer
    for i, js in enumerate(statements):
        statement, citations = process_text(js['statement']), js['citation']
        is_clicked = clicked == i
        if is_clicked:
            html += f"""<span class="statement">{statement}</span>"""
        else:
            html += f"<span>{statement}</span>"
        # Reset per statement so a clicked statement without citations never
        # reads an undefined list or the previous statement's snippets.
        cite_html = []
        if citations:
            idxs = []
            for c in citations:
                idx += 1
                idxs.append(str(idx))
                cite = '[Sentence: {}-{}\t|\tChar: {}-{}]<br>\n<span {}>{}</span>'.format(
                    c['start_sentence_idx'],
                    c['end_sentence_idx'],
                    c['start_char_idx'],
                    c['end_char_idx'],
                    'class="highlight"' if is_clicked else "",
                    process_text(c['cite'].strip()),
                )
                cite_html.append(f"""<span><span class="Bold">Snippet [{idx}]:</span><br>{cite}</span>""")
            all_cite_html.extend(cite_html)
            cite_num = '[{}]'.format(','.join(idxs))
            cite_num2idx[cite_num] = i
            html += """ <span class="reference" style="color: blue" id={}>{}</span>""".format(i, cite_num)
        html += '\n'
        if is_clicked:
            clicked_cite_html = (
                html_styles
                + '<br><span class="label">Citations of current statement:</span><br>'
                + panel.format("<br><br>\n".join(cite_html))
            )
    if all_cite_html:
        # The overview never highlights, so drop the clicked-statement class.
        all_body = "<br><br>\n".join(all_cite_html).replace('<span class="highlight">', '<span>')
    else:
        all_body = "No citation in the answer"
    all_cite_html = (
        html_styles
        + '<br><span class="label">All citations:</span><br>\n'
        + panel.format(all_body)
    )
    return html, all_cite_html, clicked_cite_html, cite_num2idx
def render_context(file):
    """Show the text of the uploaded document in the context textbox.

    Args:
        file: Value delivered by the upload component; it must expose a
            ``name`` attribute, which is passed to ``convert_to_txt``.

    Returns:
        A visible ``gr.Textbox`` holding the extracted document text.

    Raises:
        gr.Error: If ``file`` has no ``name`` attribute (nothing uploaded).
    """
    if not hasattr(file, "name"):
        raise gr.Error("ERROR: no uploaded document")
    context = convert_to_txt(file.name)
    return gr.Textbox(context, visible=True)
def run_llm(context, query):
    """Answer *query* over *context* and build the component updates.

    Args:
        context: Document text taken from the context textbox.
        query: The user's question.

    Returns:
        A dict keyed by the output components (``statements``, ``answer``,
        ``all_citations``, ``cite_num2idx``, ``citation_choices``,
        ``clicked_citations``) carrying their new values and visibility.

    Raises:
        gr.Error: If ``context`` or ``query`` is empty.
    """
    if not context:
        raise gr.Error("Error: no uploaded document")
    if not query:
        raise gr.Error("Error: no query")
    result = model.query_longcite(
        context,
        query,
        tokenizer=tokenizer,
        max_input_length=128000,
        max_new_tokens=1024,
    )
    all_statements = result['all_statements']
    answer_html, all_cite_html, _, cite_num2idx_dict = convert_to_html(all_statements)
    cite_nums = list(cite_num2idx_dict)
    return {
        statements: gr.JSON(all_statements),
        answer: gr.HTML(answer_html, visible=True),
        all_citations: gr.HTML(all_cite_html, visible=True),
        cite_num2idx: gr.JSON(cite_num2idx_dict),
        # Show the selector only when there is something to select.
        citation_choices: gr.Radio(cite_nums, visible=len(cite_nums) > 0),
        # A fresh answer has no clicked statement yet.
        clicked_citations: gr.HTML(visible=False),
    }
def chose_citation(statements, cite_num2idx, clicked_cite_num):
    """Highlight the statement that owns the selected citation label.

    Args:
        statements: Statement dicts as accepted by ``convert_to_html``.
        cite_num2idx: Dict mapping a citation label to a statement index.
        clicked_cite_num: Label chosen in the radio group; may be ``None``
            or unknown to ``cite_num2idx``.

    Returns:
        A dict with updates for ``answer`` and ``clicked_citations``. The
        citation panel is hidden when no statement matches the selection.
    """
    # The change event may deliver a value that is not a key of the mapping
    # (e.g. None when the selection is cleared); fall back to "nothing
    # clicked" instead of raising KeyError.
    clicked = (cite_num2idx or {}).get(clicked_cite_num, -1)
    answer_html, _, clicked_cite_html, _ = convert_to_html(statements or [], clicked=clicked)
    if clicked_cite_html is None:
        clicked_panel = gr.HTML(visible=False)
    else:
        clicked_panel = gr.HTML(clicked_cite_html, visible=True)
    return {
        answer: gr.HTML(answer_html, visible=True),
        clicked_citations: clicked_panel,
    }
# Page header shown at the top of the demo. Kept flush-left so that no line
# can be mistaken for an indented Markdown code block.
# TODO(review): the arxiv link has no paper id; confirm the intended URL.
_HEADER_MARKDOWN = """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
LongCite-glm4-9b Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/LongCite-glm4-9b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/LongCite">🌐 Github</a> |
<a href="https://arxiv.org/pdf/">📜 arxiv </a>
</div>
<br>
<div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
If you plan to use it long-term, please consider deploying the model or forking this space yourself.
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    # Top row: upload + question on the left, extracted document text on the right.
    with gr.Row():
        with gr.Column(scale=4):
            file = gr.File(label="Upload a document (supported type: pdf, docx, txt, md, py)")
            query = gr.Textbox(label='Question')
            submit_btn = gr.Button("Submit")
        with gr.Column(scale=4):
            context = gr.Textbox(
                label="Document content",
                autoscroll=False,
                placeholder="No uploaded document.",
                max_lines=10,
                visible=False,
            )
    file.upload(render_context, [file], [context])

    # Bottom row: answer and citation selector on the left, citation panels on the right.
    with gr.Row():
        with gr.Column(scale=4):
            # Hidden JSON components carry state between the two callbacks.
            statements = gr.JSON(label="statements", visible=False)
            answer = gr.HTML(label="Answer", visible=True)
            cite_num2idx = gr.JSON(label="cite_num2idx", visible=False)
            citation_choices = gr.Radio(label="Chose citations for details", visible=False, interactive=True)
        with gr.Column(scale=4):
            clicked_citations = gr.HTML(label="Citations of the chosen statement", visible=False)
            all_citations = gr.HTML(label="All citations", visible=False)

    submit_btn.click(
        run_llm,
        [context, query],
        [statements, answer, all_citations, cite_num2idx, citation_choices, clicked_citations],
    )
    citation_choices.change(
        chose_citation,
        [statements, cite_num2idx, citation_choices],
        [answer, clicked_citations],
    )

demo.queue()
demo.launch()