Update README.md
README.md
CHANGED
---
license: bigscience-openrail-m
widget:
- text: >-
    wnt signalling orchestrates a number of developmental programs in response
    to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized
    enabling downstream transcriptional activation by members of the lef/tcf
    family
datasets:
- bigbio/drugprot
- bigbio/ncbi_disease
language:
- en
pipeline_tag: token-classification
tags:
- biology
- medical
---

# DistilBERT base model for restoring punctuation of medical/biotech speech-to-text transcripts
E.g.:
```
wnt signalling orchestrates a number of developmental programs in response to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized enabling downstream transcriptional activation by members of the lef/tcf family
```
will be punctuated as follows:
```
Wnt signalling orchestrates a number of developmental programs. In response to this stimulus, cytoplasmic beta catenin (encoded by Ctnnb1) is stabilized, enabling downstream transcriptional activation by members of the Lef/Tcf family.
```

## How to use it in your code:
```python
import torch
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification

checkpoint = "unikei/distilbert-base-re-punctuate"
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
model = DistilBertForTokenClassification.from_pretrained(checkpoint)
encoder_max_length = 256

#
# Split a word list into segments of `length` words, each extended by
# the next `overlap` words so that consecutive segments share context
#
def split_to_segments(wrds, length, overlap):
    resp = []
    i = 0
    while True:
        wrds_split = wrds[(length * i):((length * (i + 1)) + overlap)]
        if not wrds_split:
            break

        resp_obj = {
            "text": wrds_split,
            "start_idx": length * i,
            "end_idx": (length * (i + 1)) + overlap,
        }

        resp.append(resp_obj)
        i += 1
    return resp
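
# For example (illustrative numbers, not the settings used below):
# split_to_segments(list(range(10)), length=4, overlap=2) returns segments
# covering words 0-5, 4-9 and 8-9, so consecutive segments share
# `overlap` words.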


#
# Restore a wordpiece's capitalization and trailing punctuation from its
# predicted label
#
def punctuate_wordpiece(wordpiece, label):
    if label.startswith('UPPER'):
        wordpiece = wordpiece.upper()
    elif label.startswith('Upper'):
        wordpiece = wordpiece[0].upper() + wordpiece[1:]
    if label[-1] != '_' and label[-1] != wordpiece[-1]:
        wordpiece += label[-1]
    return wordpiece
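
# Labels are assumed to pair a case prefix with a trailing punctuation
# character, where '_' means "no punctuation" (hypothetical label names,
# for illustration only):
#   punctuate_wordpiece('wnt', 'UPPER_')    -> 'WNT'
#   punctuate_wordpiece('beta', 'Upper,')   -> 'Beta,'
#   punctuate_wordpiece('family', 'lower.') -> 'family.'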


#
# Reassemble one segment: merge wordpieces back into words, apply the
# predicted labels, and skip the first `start_word` words of the segment
#
def punctuate_segment(wordpieces, word_ids, labels, start_word):
    result = ''
    for idx in range(0, len(wordpieces)):
        if word_ids[idx] is None:
            continue
        if word_ids[idx] < start_word:
            continue
        wordpiece = punctuate_wordpiece(wordpieces[idx][2:] if wordpieces[idx].startswith('##') else wordpieces[idx],
                                        labels[idx])
        if idx > 0 and len(result) > 0 and word_ids[idx] != word_ids[idx - 1] and result[-1] != '-':
            result += ' '
        result += wordpiece
    return result
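
# For example (hypothetical prediction): wordpieces ['ct', '##nn', '##b1']
# with word_ids [0, 0, 0] and labels ['UPPER_', 'UPPER_', 'UPPER_']
# reassemble to 'CTNNB1'.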


#
# Tokenize one segment, predict a label for every wordpiece, and
# punctuate it
#
def process_segment(words, tokenizer, model, start_word):

    tokens = tokenizer(words['text'],
                       padding="max_length",
                       # truncation=True,
                       max_length=encoder_max_length,
                       is_split_into_words=True, return_tensors='pt')

    with torch.no_grad():
        logits = model(**tokens).logits
    logits = logits.cpu()
    predictions = np.argmax(logits, axis=-1)

    wordpieces = tokens.tokens()
    word_ids = tokens.word_ids()
    id2label = model.config.id2label
    labels = [id2label[p.item()] for p in predictions[0]]

    return punctuate_segment(wordpieces, word_ids, labels, start_word)
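
# Note: tokens.word_ids() maps every wordpiece to the index of the input
# word it came from (None for special tokens like [CLS] and [SEP]);
# punctuate_segment relies on this both to skip the overlap words and to
# glue '##' continuation pieces back onto their word.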


#
# Punctuate text of any length
#
def punctuate(text, tokenizer, model):
    text = text.lower()
    text = text.replace('\n', ' ')
    words = text.split(' ')

    overlap = 50
    slices = split_to_segments(words, 150, overlap)

    result = ""
    start_word = 0
    for segment in slices:
        corrected = process_segment(segment, tokenizer, model, start_word)
        result += corrected + ' '
        start_word = overlap
    return result

#
# Example
#
text = ("wnt signalling orchestrates a number of developmental programs in response "
        "to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized "
        "enabling downstream transcriptional activation by members of the lef/tcf family")
result = punctuate(text, tokenizer, model)
print(result)
```
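
To inspect the raw token-level predictions rather than the reassembled text, the generic `transformers` token-classification pipeline can be pointed at the same checkpoint. A minimal sketch, assuming only what is shown above (the printed label names come from the checkpoint's `id2label` mapping):
```python
from transformers import pipeline

# Token-classification pipeline over the same checkpoint; each prediction
# pairs a wordpiece with its predicted label and a confidence score.
punctuator = pipeline("token-classification",
                      model="unikei/distilbert-base-re-punctuate")

for pred in punctuator("wnt signalling orchestrates a number of developmental programs"):
    print(pred["word"], pred["entity"], round(float(pred["score"]), 3))
```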
|