from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline


class BaselineCommaFixer:
    def __init__(self):
        self._ner = _create_baseline_pipeline()

    def fix_commas(self, s: str) -> str:
        return _fix_commas_based_on_pipeline_output(
            self._ner(_remove_punctuation(s)),
            s
        )


def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline('ner', model=model, tokenizer=tokenizer)
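
# Illustrative only: the helpers below assume the 'ner' pipeline returns a list of dicts
# with (at least) 'word' and 'entity' keys, where the SentencePiece tokenizer marks the
# first piece of each word with a leading '▁'. A hypothetical (not real) output:
#   [{'word': '▁Hello', 'entity': ','}, {'word': '▁world', 'entity': '.'}]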


def _remove_punctuation(s: str) -> str:
    to_remove = ".,?-:"
    for char in to_remove:
        s = s.replace(char, '')
    return s
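
# For example, _remove_punctuation("Hello, world. How are you?") returns
# "Hello world How are you": only the characters in ".,?-:" are stripped.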


def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s: str) -> str:
    result = original_s.replace(',', '')  # We will fix the commas, but keep everything else intact
    current_offset = 0
    for i in range(1, len(pipeline_json)):
        # Locate the end of the (i - 1)-th token within the comma-free result string.
        current_offset = _find_current_token(current_offset, i, pipeline_json, result)
        if _should_insert_comma(i, pipeline_json):
            result = result[:current_offset] + ',' + result[current_offset:]
            current_offset += 1  # Account for the comma just inserted.
    return result


def _should_insert_comma(i, pipeline_json, new_word_indicator='▁') -> bool:
    # Only insert commas for the final token of a word
    return pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith(new_word_indicator)
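
# Illustrative only (hypothetical tokens and labels): given consecutive entries
# {'word': 'lo', 'entity': ','} and {'word': '▁world', 'entity': '0'}, a comma is inserted,
# because 'lo' ends a word (the next token starts a new one) and it was predicted
# to be followed by a comma.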


def _find_current_token(current_offset, i, pipeline_json, result, new_word_indicator='▁') -> int:
    current_word = pipeline_json[i - 1]['word'].replace(new_word_indicator, '')
    # Find the current word in the result string, starting the search at the current offset
    current_offset = result.find(current_word, current_offset) + len(current_word)
    return current_offset


if __name__ == "__main__":
    BaselineCommaFixer()  # to pre-download the model and tokenizer
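
# A minimal usage sketch (not from the original file; the exact output depends on the
# model's predictions):
#   fixer = BaselineCommaFixer()
#   fixer.fix_commas("One two three and four")
# The model sees the input with punctuation stripped, while the returned string keeps
# everything except commas, which are re-inserted wherever the model predicts them.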