File size: 2,040 Bytes
5760b44
 
 
f42ec01
 
 
 
 
 
 
 
 
 
 
 
b1106e6
 
5760b44
 
 
 
 
 
 
 
 
 
b1106e6
 
35c0239
b1106e6
35c0239
b1106e6
 
35c0239
f42ec01
35c0239
5760b44
 
b1106e6
 
 
 
 
 
 
 
 
 
f42ec01
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline


class BaselineCommaFixer:
    """Restores commas in text using a pretrained punctuation-restoration NER model.

    The underlying pipeline is built once at construction time (which downloads
    the model/tokenizer on first use) and reused for every call.
    """

    def __init__(self):
        # Build the token-classification pipeline once; reused across calls.
        self._ner = _create_baseline_pipeline()

    def fix_commas(self, s: str) -> str:
        """Return *s* with commas re-placed according to the model's predictions.

        Punctuation is stripped before running the model; the original string
        (minus its commas) is then re-punctuated from the pipeline output.
        """
        depunctuated = _remove_punctuation(s)
        predictions = self._ner(depunctuated)
        return _fix_commas_based_on_pipeline_output(predictions, s)


def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
    """Build a token-classification ('ner') pipeline for the punctuation model.

    Downloads the model and tokenizer from the Hugging Face hub on first use.
    """
    return pipeline(
        'ner',
        model=AutoModelForTokenClassification.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
    )


def _remove_punctuation(s: str) -> str:
    to_remove = ".,?-:"
    for char in to_remove:
        s = s.replace(char, '')
    return s


def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s: str) -> str:
    """Re-insert commas into *original_s* according to the pipeline's token predictions.

    :param pipeline_json: token-classification output; each entry is expected to
        have at least a ``'word'`` and an ``'entity'`` key — TODO confirm the
        pipeline was run without aggregation so tokens appear in order.
    :param original_s: the original input text (with its original commas).
    :return: *original_s* with all commas removed and then re-inserted where the
        model predicted them; all other characters are left intact.
    """
    result = original_s.replace(',', '')  # We will fix the commas, but keep everything else intact
    current_offset = 0

    for i in range(1, len(pipeline_json)):
        # Advance current_offset to just past the text of token i-1 inside `result`.
        current_offset = _find_current_token(current_offset, i, pipeline_json, result)
        # A comma is inserted only at a word boundary: token i-1 predicted ','
        # and token i starts a new word.
        if _should_insert_comma(i, pipeline_json):
            result = result[:current_offset] + ',' + result[current_offset:]
        # Unconditionally skip one character: compensates for a just-inserted
        # comma, and otherwise nudges the search start forward (the word gap
        # means this cannot skip past the start of the next token's text).
        current_offset += 1
    return result


def _should_insert_comma(i, pipeline_json, new_word_indicator='▁') -> bool:
    # Only insert commas for the final token of a word
    return pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith(new_word_indicator)


def _find_current_token(current_offset, i, pipeline_json, result, new_word_indicator='▁') -> int:
    current_word = pipeline_json[i - 1]['word'].replace(new_word_indicator, '')
    # Find the current word in the result string, starting looking at current offset
    current_offset = result.find(current_word, current_offset) + len(current_word)
    return current_offset


if __name__ == "__main__":
    # Instantiating the fixer forces the model and tokenizer to be downloaded
    # into the local Hugging Face cache, so later runs start without a fetch.
    BaselineCommaFixer()  # to pre-download the model and tokenizer