import os
import subprocess
import sys
# Best-effort runtime install of flash-attn, performed before transformers is
# imported below. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is presumably read by
# flash-attn's setup.py to skip compiling the CUDA extension — confirm against
# the flash-attn version in use.
#
# The child process inherits the current environment (PATH, HOME, virtualenv
# variables, ...) with only the flag added on top; passing a bare dict as
# ``env`` would discard all of those. pip is invoked through the running
# interpreter, without a shell, so the package is installed for *this* Python.
# check=False keeps a failed install non-fatal.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
import docx
import PyPDF2
import spaces
def convert_to_txt(file):
    """Extract the plain text of an uploaded document.

    The document type is taken from the file extension (case-insensitive):
    ``txt``/``md``/``py`` are decoded as UTF-8, ``pdf`` pages are extracted
    with PyPDF2, and ``docx`` paragraphs are read with python-docx.

    Args:
        file: Either a filesystem path (``str``) or a file-like object whose
            ``name`` attribute carries the filename and which supports
            ``read()``.

    Returns:
        The document text, with PDF pages / DOCX paragraphs joined by a
        blank line.

    Raises:
        gr.Error: If the extension is not one of the supported types.
    """
    # A path string has no ``read()`` and a file object has no ``split()``,
    # so the filename and the content source are resolved separately.
    if isinstance(file, str):
        path = file
    else:
        path = getattr(file, "name", "") or ""
    extension = path.split(".")[-1].strip()
    doc_type = extension.lower()

    if doc_type in ("txt", "md", "py"):
        if hasattr(file, "read"):
            raw = file.read()
        else:
            with open(path, "rb") as handle:
                raw = handle.read()
        # Text-mode file objects already yield str; only bytes need decoding.
        data = [raw.decode("utf-8") if isinstance(raw, bytes) else raw]
    elif doc_type == "pdf":
        # PdfReader accepts either a path or a binary stream.
        pdf_reader = PyPDF2.PdfReader(file)
        data = [page.extract_text() for page in pdf_reader.pages]
    elif doc_type == "docx":
        doc = docx.Document(file)
        data = [p.text for p in doc.paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {extension}")
    return "\n\n".join(data)
# Checkpoint served by this app. Its tokenizer and model classes are loaded
# from code shipped in the repository, hence trust_remote_code=True on both.
model_name = "THUDM/LongCite-glm4-9b"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Weights in bfloat16; attention via FlashAttention-2 (needs the flash-attn
# package to be importable).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    # NOTE(review): `device` is not a generic from_pretrained option; it is
    # presumably forwarded to the remote-code model constructor — confirm.
    device="cuda",
)

# NOTE(review): prepended to the HTML built by convert_to_html, but as written
# it is only a newline with no styling rules — confirm whether CSS was intended.
html_styles = """\n"""
# Character -> replacement table for HTML-escaping free text. Applied with
# str.translate, i.e. in one pass, so the "&" inside an emitted entity is never
# escaped a second time (order of entries does not matter).
_HTML_ESCAPE_TABLE = str.maketrans(
    {
        "&": "&amp;",
        "'": "&#39;",
        '"': "&quot;",
        "<": "&lt;",
        ">": "&gt;",
        "\n": "<br>",
    }
)


def process_text(text):
    """Escape *text* for embedding in HTML and render line breaks.

    The characters ``&``, ``'``, ``"``, ``<`` and ``>`` are replaced by their
    HTML entities, and every newline becomes ``<br>``.

    Args:
        text: Arbitrary string (e.g. a model statement or cited snippet).

    Returns:
        The escaped string; an empty input yields an empty string.
    """
    return text.translate(_HTML_ESCAPE_TABLE)
def convert_to_html(statements, clicked=-1):
html = html_styles + '
Answer:
\n'
all_cite_html = []
clicked_cite_html = None
cite_num2idx = {}
idx = 0
for i, js in enumerate(statements):
statement, citations = process_text(js["statement"]), js["citation"]
if clicked == i:
html += f"""{statement}"""
else:
html += f"{statement}"
if citations:
cite_html = []
idxs = []
for c in citations:
idx += 1
idxs.append(str(idx))
cite = (
"[Sentence: {}-{}\t|\tChar: {}-{}]
\n{}".format(
c["start_sentence_idx"],
c["end_sentence_idx"],
c["start_char_idx"],
c["end_char_idx"],
'class="highlight"' if clicked == i else "",
process_text(c["cite"].strip()),
)
)
cite_html.append(
f"""Snippet [{idx}]:
{cite}"""
)
all_cite_html.extend(cite_html)
cite_num = "[{}]".format(",".join(idxs))
cite_num2idx[cite_num] = i
cite_num_html = """ {}""".format(
i, cite_num
)
html += cite_num_html
html += "\n"
if clicked == i:
clicked_cite_html = (
html_styles
+ """
Citations of current statement: