"""Gradio Space for THUDM/LongCite-glm4-9b.

Upload a document (pdf, docx, txt, md, py), ask a question, and the model
answers with statements annotated by numbered citation snippets taken from the
document. Selecting a citation label shows the snippets behind that statement.
"""
import os
import subprocess
import sys

# flash-attn is installed at startup, before the heavy imports below.
# The override is merged into the current environment: passing only
# FLASH_ATTENTION_SKIP_CUDA_BUILD would hide PATH, HOME, etc. from the child
# process. Using the running interpreter's pip (no shell) installs into the
# same environment this script runs in.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import docx
import PyPDF2
import spaces

# Extensions that are read as plain UTF-8 text.
TEXT_DOC_TYPES = ("txt", "md", "py")


def convert_to_txt(file):
    """Extract plain text from a document on disk.

    Args:
        file: Filesystem path of the document; its extension selects the parser.

    Returns:
        The document text. PDF pages and DOCX paragraphs are joined with blank
        lines.

    Raises:
        gr.Error: if the extension is not a supported document type.
    """
    # splitext ignores dots in directory names; lower() accepts e.g. ".PDF".
    doc_type = os.path.splitext(file)[1][1:].strip().lower()
    if doc_type in TEXT_DOC_TYPES:
        # `file` is a path string, so it must be opened (it has no .read()).
        with open(file, encoding="utf-8", errors="replace") as f:
            data = [f.read()]
    elif doc_type == "pdf":
        pdf_reader = PyPDF2.PdfReader(file)
        data = [page.extract_text() or "" for page in pdf_reader.pages]
    elif doc_type == "docx":
        doc = docx.Document(file)
        data = [p.text for p in doc.paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {doc_type}")
    return "\n\n".join(data)


model_name = "THUDM/LongCite-glm4-9b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # NOTE(review): not a standard from_pretrained option; presumably forwarded
    # to the remote model constructor to place weights on the GPU -- verify.
    device="cuda",
    attn_implementation="flash_attention_2",
)

# CSS shared by every HTML fragment produced by convert_to_html.
html_styles = """<style>
    .label { font-family: sans-serif; font-size: 16px; font-weight: bold; }
    .bold { font-weight: bold; }
    .statement { background-color: lightgrey; }
    .reference { color: blue; text-decoration: underline; }
    .highlight { background-color: yellow; }
    .citations {
        overflow-y: auto;
        padding: 20px;
        border-radius: 6px;
        background-color: #EFF2F6;
    }
</style>\n"""

# One-pass translation table: escapes HTML-special characters and renders
# newlines as <br>. A single str.translate pass cannot double-escape the "&"
# introduced by another replacement, so no ordering concerns arise.
_HTML_TRANSLATION = str.maketrans(
    {
        "&": "&amp;",
        "'": "&#39;",
        '"': "&quot;",
        "<": "&lt;",
        ">": "&gt;",
        "\n": "<br>",
    }
)


def process_text(text):
    """Return *text* escaped for embedding in HTML, with newlines as <br>."""
    return text.translate(_HTML_TRANSLATION)


def _citation_panel(title, body):
    """Return styled HTML showing a bold *title* above a boxed *body*."""
    return (
        html_styles
        + f'<br><span class="label">{title}</span><br>\n'
        + f'<div class="citations">{body}</div>'
    )


def convert_to_html(statements, clicked=-1):
    """Render the model's answer and its citations as HTML.

    Args:
        statements: List of dicts, each with "statement" (text) and "citation"
            (list of dicts with "start_sentence_idx", "end_sentence_idx",
            "start_char_idx", "end_char_idx" and "cite").
        clicked: Index of the statement whose text and snippets are
            highlighted; -1 highlights nothing.

    Returns:
        (answer_html, all_cite_html, clicked_cite_html, cite_num2idx), where
        clicked_cite_html is None unless `clicked` names a statement that has
        citations, and cite_num2idx maps a label such as "[1,2]" to the index
        of the statement it annotates.
    """
    answer_html = html_styles + '<br><span class="label">Answer:</span><br>\n'
    all_cite_html = []
    clicked_cite_html = None
    cite_num2idx = {}
    idx = 0  # running snippet number across the whole answer
    for i, js in enumerate(statements):
        statement, citations = process_text(js["statement"]), js["citation"]
        is_clicked = clicked == i
        if is_clicked:
            answer_html += f'<span class="statement">{statement}</span>'
        else:
            answer_html += f"<span>{statement}</span>"
        if citations:
            cite_html = []
            idxs = []
            for c in citations:
                idx += 1
                idxs.append(str(idx))
                cite = "[Sentence: {}-{}\t|\tChar: {}-{}]<br>\n<span {}>{}</span>".format(
                    c["start_sentence_idx"],
                    c["end_sentence_idx"],
                    c["start_char_idx"],
                    c["end_char_idx"],
                    'class="highlight"' if is_clicked else "",
                    process_text(c["cite"].strip()),
                )
                cite_html.append(
                    f'<span><span class="bold">Snippet [{idx}]:</span><br>{cite}</span>'
                )
            all_cite_html.extend(cite_html)
            cite_num = "[{}]".format(",".join(idxs))
            cite_num2idx[cite_num] = i
            answer_html += ' <span class="reference" id="{}">{}</span>'.format(
                i, cite_num
            )
            # Only statements with citations can be clicked (they are the only
            # ones in cite_num2idx), so cite_html is always this statement's.
            if is_clicked:
                clicked_cite_html = _citation_panel(
                    "Citations of current statement:", "<br><br>\n".join(cite_html)
                )
        answer_html += "\n"

    if all_cite_html:
        # The overview never highlights: strip the per-statement highlight.
        body = "<br><br>\n".join(all_cite_html).replace(
            '<span class="highlight">', "<span>"
        )
    else:
        body = "No citation in the answer"
    all_cite_html = _citation_panel("All citations:", body)
    return answer_html, all_cite_html, clicked_cite_html, cite_num2idx


def render_context(file):
    """Extract the uploaded document's text and reveal it in the context box.

    Accepts either a path string (Gradio 4 `gr.File` default) or an object
    exposing the path as `.name` (older Gradio tempfile wrappers).

    Raises:
        gr.Error: if nothing was uploaded or the type is unsupported.
    """
    path = getattr(file, "name", file)
    if not isinstance(path, str) or not path:
        raise gr.Error("ERROR: no uploaded document")
    context = convert_to_txt(path)
    return gr.Textbox(context, visible=True)


@spaces.GPU(duration=120)
def infer(context, query):
    """Ask LongCite *query* about *context*; runs inside a ZeroGPU allocation."""
    return model.query_longcite(
        context=context,
        query=query,
        tokenizer=tokenizer,
        max_input_length=128000,
        max_new_tokens=1024,
    )


def run_llm(context, query):
    """Answer *query* over *context* and populate every result component."""
    if not context:
        raise gr.Error("Error: no uploaded document")
    if not query:
        raise gr.Error("Error: no query")
    result = infer(context=context, query=query)
    all_statements = result["all_statements"]
    answer_html, all_cite_html, clicked_cite_html, cite_num2idx_dict = convert_to_html(
        all_statements
    )
    cite_nums = list(cite_num2idx_dict.keys())
    return {
        statements: gr.JSON(all_statements),
        answer: gr.HTML(answer_html, visible=True),
        all_citations: gr.HTML(all_cite_html, visible=True),
        cite_num2idx: gr.JSON(cite_num2idx_dict),
        citation_choices: gr.Radio(cite_nums, visible=len(cite_nums) > 0),
        clicked_citations: gr.HTML(visible=False),
    }


def chose_citation(statements, cite_num2idx, clicked_cite_num):
    """Highlight the statement for the selected citation label and show its snippets."""
    # The radio also fires `change` when run_llm resets it (value None) or
    # before any answer exists; there is nothing to highlight then.
    if not cite_num2idx or clicked_cite_num not in cite_num2idx:
        return {clicked_citations: gr.HTML(visible=False)}
    clicked = cite_num2idx[clicked_cite_num]
    answer_html, _, clicked_cite_html, _ = convert_to_html(statements, clicked=clicked)
    return {
        answer: gr.HTML(answer_html, visible=True),
        clicked_citations: gr.HTML(clicked_cite_html, visible=True),
    }


# NOTE(review): the header markup and link targets were reconstructed after
# the original HTML was lost; verify the URLs.
HEADER_MD = """
<div style="text-align: center; font-size: 32px;">LongCite-glm4-9b Huggingface Space🤗</div>
<div style="text-align: center;">🤗 <a href="https://huggingface.co/THUDM/LongCite-glm4-9b">Model Hub</a> | 🌐 <a href="https://github.com/THUDM/LongCite">Github</a> | 📜 <a href="https://arxiv.org/abs/2409.02897">arxiv</a></div>
<br>
<div style="text-align: center;">If you plan to use it long-term, please consider deploying the model or forking this space yourself.</div>
"""

with gr.Blocks() as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        with gr.Column(scale=4):
            file = gr.File(
                label="Upload a document (supported type: pdf, docx, txt, md, py)"
            )
            query = gr.Textbox(label="Question")
            submit_btn = gr.Button("Submit")
        with gr.Column(scale=4):
            context = gr.Textbox(
                label="Document content",
                autoscroll=False,
                placeholder="No uploaded document.",
                max_lines=10,
                visible=False,
            )
    file.upload(render_context, [file], [context])

    with gr.Row():
        with gr.Column(scale=4):
            # Hidden components that carry state between the two callbacks.
            statements = gr.JSON(label="statements", visible=False)
            answer = gr.HTML(label="Answer", visible=True)
            cite_num2idx = gr.JSON(label="cite_num2idx", visible=False)
            citation_choices = gr.Radio(
                label="Choose citations for details", visible=False, interactive=True
            )
        with gr.Column(scale=4):
            clicked_citations = gr.HTML(
                label="Citations of the chosen statement", visible=False
            )
            all_citations = gr.HTML(label="All citations", visible=False)

    submit_btn.click(
        run_llm,
        [context, query],
        [
            statements,
            answer,
            all_citations,
            cite_num2idx,
            citation_choices,
            clicked_citations,
        ],
    )
    citation_choices.change(
        chose_citation,
        [statements, cite_num2idx, citation_choices],
        [answer, clicked_citations],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()