# SciAssist / reference_string_parsing.py
# (Header recovered from a web-page copy of this file: last commit f870b18,
#  "Update reference_string_parsing.py", by wing-nus.)
from typing import List, Optional, Tuple

import torch

from SciAssist import ReferenceStringParsing
# Pick the device the pipeline runs on: the GPU when CUDA is available,
# otherwise the CPU.
# NOTE(review): torch itself calls the GPU device "cuda"; the string "gpu" is
# handed to SciAssist unchanged — confirm ReferenceStringParsing accepts it.
if torch.cuda.is_available():
    device = "gpu"
else:
    device = "cpu"

# Single pipeline instance, built once at import time and used by
# rsp_for_str and rsp_for_file.
# NOTE(review): os_name is hard-coded to "nt" (Python's name for Windows) —
# confirm this matches the platform the app is deployed on.
rsp_pipeline = ReferenceStringParsing(
    os_name="nt",
    device=device,
)
def rsp_for_str(input, dehyphen=False) -> List[Tuple[str, Optional[str]]]:
    """Parse reference strings given as raw text into ``(token, tag)`` pairs.

    The result is a flat list in the ``(text, label)`` form that
    ``gradio.HighlightedText`` accepts: for every result returned by the
    pipeline, its tokens are paired with their predicted tags, followed by a
    ``("\\n\\n", None)`` separator entry.

    Args:
        input: Text containing the reference string(s) to parse.
        dehyphen: Forwarded unchanged to ``rsp_pipeline.predict``.

    Returns:
        List of ``(token, tag)`` tuples. Separator entries carry ``None`` as
        their tag, which is why the tag type is ``Optional[str]``.
    """
    results = rsp_pipeline.predict(input, type="str", dehyphen=dehyphen)
    output: List[Tuple[str, Optional[str]]] = []
    for res in results:
        # zip() stops at the shorter of tokens/tags, exactly as before.
        output.extend(zip(res["tokens"], res["tags"]))
        # Untagged blank-line separator between consecutive results.
        output.append(("\n\n", None))
    return output
def rsp_for_file(input, dehyphen=False) -> Optional[List[Tuple[str, Optional[str]]]]:
    """Parse reference strings from an uploaded ``.txt`` or ``.pdf`` file.

    Args:
        input: The uploaded file — either an object whose ``.name`` is the
            path on disk, or a plain path string. ``None`` means no file.
        dehyphen: Forwarded unchanged to ``rsp_pipeline.predict``.

    Returns:
        ``None`` when ``input`` is ``None``;
        ``[("File Format Error !", None)]`` when the extension is neither
        ``.txt`` nor ``.pdf`` (compared case-insensitively);
        otherwise a flat list of ``(token, tag)`` tuples in the form
        ``gradio.HighlightedText`` accepts, with an untagged
        ``("\\n\\n", None)`` separator after each parsed result.
    """
    if input is None:
        return None
    # Accept both file wrappers (path stored in ``.name``) and plain paths.
    filename = input if isinstance(input, str) else input.name
    # Case-insensitive so that e.g. "paper.PDF" is not rejected.
    lowered = filename.lower()
    # Identify the input format and parse the reference strings.
    if lowered.endswith(".txt"):
        results = rsp_pipeline.predict(filename, type="txt", dehyphen=dehyphen, save_results=False)
    elif lowered.endswith(".pdf"):
        # No ``type`` argument: presumably the pipeline's default input type
        # is PDF — confirm against SciAssist.
        results = rsp_pipeline.predict(filename, dehyphen=dehyphen, save_results=False)
    else:
        return [("File Format Error !", None)]
    # Flatten the results into the (text, label) pairs gradio.HighlightedText accepts.
    output: List[Tuple[str, Optional[str]]] = []
    for res in results:
        # zip() stops at the shorter of tokens/tags, exactly as before.
        output.extend(zip(res["tokens"], res["tags"]))
        # Untagged blank-line separator between consecutive results.
        output.append(("\n\n", None))
    return output