wing-nus committed on
Commit
f870b18
1 Parent(s): 056e35b

Update reference_string_parsing.py

Browse files
Files changed (1) hide show
  1. reference_string_parsing.py +36 -36
reference_string_parsing.py CHANGED
@@ -1,36 +1,36 @@
1
- from typing import List, Tuple
2
- import torch
3
- from SciAssist import ReferenceStringParsing
4
-
5
- device = "gpu" if torch.cuda.is_available() else "cpu"
6
- rsp_pipeline = ReferenceStringParsing(os_name="nt")
7
-
8
-
9
- def rsp_for_str(input, dehyphen=False) -> List[Tuple[str, str]]:
10
- results = rsp_pipeline.predict(input, type="str", dehyphen=dehyphen)
11
- output = []
12
- for res in results:
13
- for token, tag in zip(res["tokens"], res["tags"]):
14
- output.append((token, tag))
15
- output.append(("\n\n", None))
16
- return output
17
-
18
-
19
- def rsp_for_file(input, dehyphen=False) -> List[Tuple[str, str]]:
20
- if input == None:
21
- return None
22
- filename = input.name
23
- # Identify the format of input and parse reference strings
24
- if filename[-4:] == ".txt":
25
- results = rsp_pipeline.predict(filename, type="txt", dehyphen=dehyphen, save_results=False)
26
- elif filename[-4:] == ".pdf":
27
- results = rsp_pipeline.predict(filename, dehyphen=dehyphen, save_results=False)
28
- else:
29
- return [("File Format Error !", None)]
30
- # Prepare for the input gradio.HighlightedText accepts.
31
- output = []
32
- for res in results:
33
- for token, tag in zip(res["tokens"], res["tags"]):
34
- output.append((token, tag))
35
- output.append(("\n\n", None))
36
- return output
 
1
+ from typing import List, Tuple
2
+ import torch
3
+ from SciAssist import ReferenceStringParsing
4
+
5
+ device = "gpu" if torch.cuda.is_available() else "cpu"
6
+ rsp_pipeline = ReferenceStringParsing(os_name="nt", device=device)
7
+
8
+
9
+ def rsp_for_str(input, dehyphen=False) -> List[Tuple[str, str]]:
10
+ results = rsp_pipeline.predict(input, type="str", dehyphen=dehyphen)
11
+ output = []
12
+ for res in results:
13
+ for token, tag in zip(res["tokens"], res["tags"]):
14
+ output.append((token, tag))
15
+ output.append(("\n\n", None))
16
+ return output
17
+
18
+
19
+ def rsp_for_file(input, dehyphen=False) -> List[Tuple[str, str]]:
20
+ if input == None:
21
+ return None
22
+ filename = input.name
23
+ # Identify the format of input and parse reference strings
24
+ if filename[-4:] == ".txt":
25
+ results = rsp_pipeline.predict(filename, type="txt", dehyphen=dehyphen, save_results=False)
26
+ elif filename[-4:] == ".pdf":
27
+ results = rsp_pipeline.predict(filename, dehyphen=dehyphen, save_results=False)
28
+ else:
29
+ return [("File Format Error !", None)]
30
+ # Prepare for the input gradio.HighlightedText accepts.
31
+ output = []
32
+ for res in results:
33
+ for token, tag in zip(res["tokens"], res["tags"]):
34
+ output.append((token, tag))
35
+ output.append(("\n\n", None))
36
+ return output