# SciAssist / dataset_extraction.py
# wing-nus's picture
# Update dataset_extraction.py
# 516a961
from typing import List, Tuple
import torch
import nltk
from SciAssist import DatasetExtraction
# Choose the device string passed to the pipeline from CUDA availability.
if torch.cuda.is_available():
    device = "gpu"
else:
    device = "cpu"

# Single shared pipeline instance, constructed once when the module is imported.
de_pipeline = DatasetExtraction(device=device, os_name="nt")
def de_for_str(input):
    """Extract dataset mentions from a raw text string.

    The text is split into sentences with ``nltk.sent_tokenize`` and the
    sentence list is passed to the shared ``de_pipeline``.

    Args:
        input: The text to analyse.

    Returns:
        A flat list of 2-tuples. Each entry of
        ``results["dataset_mentions"]`` contributes a tuple of its first
        two elements, followed by a separator tuple (two newlines, None).
    """
    sentences = nltk.sent_tokenize(input)
    extraction = de_pipeline.extract(sentences, type="str", save_results=False)

    separator = ("\n\n", None)
    highlighted = []
    for pair in extraction["dataset_mentions"]:
        # Emit the pair itself, then the separator that follows every pair.
        highlighted.extend(((pair[0], pair[1]), separator))
    return highlighted
def de_for_file(input):
    """Extract dataset mentions from a ``.txt`` or ``.pdf`` file.

    Args:
        input: An object whose ``name`` attribute holds the file path, or
            ``None`` when no file was supplied.

    Returns:
        ``None`` if ``input`` is ``None``. A one-element list containing
        the "File Format Error !" message if the file name ends in neither
        ``.txt`` nor ``.pdf``. Otherwise a flat list of 2-tuples in which
        each entry of ``results["dataset_mentions"]`` contributes a tuple
        of its first two elements, followed by a separator tuple
        (two newlines, None).
    """
    # Identity check: ``== None`` would dispatch to a custom ``__eq__``.
    if input is None:
        return None

    filename = input.name

    # Derive the pipeline input type from the (case-sensitive) extension.
    if filename.endswith(".txt"):
        file_type = "txt"
    elif filename.endswith(".pdf"):
        file_type = "pdf"
    else:
        return [("File Format Error !", None)]

    results = de_pipeline.extract(filename, type=file_type, save_results=False)

    output = []
    for mention_pair in results["dataset_mentions"]:
        output.append((mention_pair[0], mention_pair[1]))
        # Separator entry after every mention pair.
        output.append(("\n\n", None))
    return output
# Sample sentence for dataset extraction.
# NOTE(review): presumably the demo input for de_for_str — confirm against the caller.
de_str_example = "BAKIS incorporates information derived from the bank balance sheets and supervisory reports of all German banks ."