Spaces:
Runtime error
Runtime error
Replicate default cc_net preprocessing at inference time on KenlmModel.get_perplexity
Browse files
perplexity_lenses/perplexity.py
CHANGED
@@ -1,10 +1,53 @@
|
|
1 |
import os
|
|
|
|
|
2 |
import urllib.request
|
|
|
3 |
|
4 |
import kenlm
|
5 |
|
6 |
|
7 |
class KenlmModel:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
def __init__(self, language):
|
9 |
download_kenlm_model(language)
|
10 |
try:
|
@@ -19,7 +62,9 @@ class KenlmModel:
|
|
19 |
def from_pretrained(cls, language: str):
|
20 |
return cls(language)
|
21 |
|
22 |
-
def get_perplexity(self, doc: str):
|
|
|
|
|
23 |
doc_log_score, doc_length = 0, 0
|
24 |
for line in doc.split("\n"):
|
25 |
log_score = self.model.score(line)
|
@@ -28,6 +73,48 @@ class KenlmModel:
|
|
28 |
doc_length += length
|
29 |
return 10.0 ** (-doc_log_score / doc_length)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def download_kenlm_model(language: str):
|
33 |
root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"
|
|
|
1 |
import os
|
2 |
+
import re
|
3 |
+
import unicodedata
|
4 |
import urllib.request
|
5 |
+
from typing import Dict
|
6 |
|
7 |
import kenlm
|
8 |
|
9 |
|
10 |
class KenlmModel:
|
11 |
+
# Matches one decimal digit; normalize() replaces every match with "0".
digit_re: re.Pattern = re.compile(r"\d")

# Punctuation replacement table used by replace_unicode_punct().
# The fullwidth keys are written as \u escapes: in the extracted text they had
# been collapsed to their ASCII look-alikes, which produced identity entries
# ("," -> ",") and harmful ones ("." -> ". ", "1" -> '"').
# NOTE(review): restored to match cc_net's text_normalizer table - confirm
# against upstream before relying on exact parity.
unicode_punct: Dict[str, str] = {
    "\uff0c": ",",  # fullwidth comma
    "。": ".",
    "、": ",",
    "„": '"',
    "”": '"',
    "“": '"',
    "«": '"',
    "»": '"',
    # fullwidth digit one; presumably an upstream typo, kept for parity.
    "\uff11": '"',
    "」": '"',
    "「": '"',
    "《": '"',
    "》": '"',
    "´": "'",
    "∶": ":",
    "\uff1a": ":",  # fullwidth colon
    "\uff1f": "?",  # fullwidth question mark
    "\uff01": "!",  # fullwidth exclamation mark
    "\uff08": "(",  # fullwidth left parenthesis
    "\uff09": ")",  # fullwidth right parenthesis
    "\uff1b": ";",  # fullwidth semicolon
    "–": "-",
    "—": " - ",
    "\uff0e": ". ",  # fullwidth full stop
    "\uff5e": "~",  # fullwidth tilde
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "\uff05": "%",  # fullwidth percent sign
    "►": "-",
}

# Character class matching any key of unicode_punct. Keys are escaped so that a
# future key such as "]" or "-" cannot change the meaning of the class.
unicode_punct_re = re.compile("[" + "".join(map(re.escape, unicode_punct)) + "]")

# Character class matching code points 0-31 and 127-159.
non_printing_chars_re = re.compile(
    "[" + "".join(map(chr, [*range(0, 32), *range(127, 160)])) + "]"
)
|
50 |
+
|
51 |
def __init__(self, language):
|
52 |
download_kenlm_model(language)
|
53 |
try:
|
|
|
62 |
def from_pretrained(cls, language: str):
|
63 |
return cls(language)
|
64 |
|
65 |
+
def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
|
66 |
+
if normalize_cc_net:
|
67 |
+
doc = self.normalize(doc)
|
68 |
doc_log_score, doc_length = 0, 0
|
69 |
for line in doc.split("\n"):
|
70 |
log_score = self.model.score(line)
|
|
|
73 |
doc_length += length
|
74 |
return 10.0 ** (-doc_log_score / doc_length)
|
75 |
|
76 |
+
def normalize(
    self,
    line: str,
    accent: bool = True,
    case: bool = True,
    numbers: bool = True,
    punct: int = 1,
) -> str:
    """Return a normalized copy of ``line``.

    Surrounding whitespace is stripped first; if nothing is left the
    stripped (empty) string is returned immediately. Otherwise the enabled
    steps run in this fixed order:

    1. ``case``: lowercase the text.
    2. ``accent``: pass the text through ``strip_accents``.
    3. ``numbers``: replace every ``digit_re`` match with ``"0"``.
    4. ``punct``: ``1`` calls ``replace_unicode_punct``, ``2`` calls
       ``remove_unicode_punct``, any other value leaves punctuation alone.
    5. Always: pass the text through ``remove_non_printing_char``.
    """
    text = line.strip()
    if text:
        if case:
            text = text.lower()
        if accent:
            text = self.strip_accents(text)
        if numbers:
            text = self.digit_re.sub("0", text)
        if punct == 1:
            text = self.replace_unicode_punct(text)
        elif punct == 2:
            text = self.remove_unicode_punct(text)
        text = self.remove_non_printing_char(text)
    return text
|
99 |
+
|
100 |
+
def strip_accents(self, line: str) -> str:
    """Strip accents from a piece of text.

    The text is decomposed with Unicode NFD and every character whose
    category is "Mn" is dropped; the remaining characters are joined and
    returned.
    """
    # The previous version tried to return `line` early with
    # `len(output) == line`, which compares an int to a str and is never
    # true, so the filtered NFD text was always returned. The dead branch is
    # removed rather than "repaired" to `len(output) == len(line)`: a
    # precomposed character such as U+00E9 decomposes to two code points and
    # strips back to one, so equal lengths do not mean nothing was stripped.
    decomposed = unicodedata.normalize("NFD", line)
    return "".join(c for c in decomposed if unicodedata.category(c) != "Mn")
|
107 |
+
|
108 |
+
def replace_unicode_punct(self, text: str) -> str:
    """Map each character of ``text`` through ``self.unicode_punct``.

    Characters that are keys of the table are replaced by the mapped value;
    every other character is kept unchanged.
    """
    table = self.unicode_punct
    pieces = []
    for ch in text:
        pieces.append(table.get(ch, ch))
    return "".join(pieces)
|
110 |
+
|
111 |
+
def remove_unicode_punct(self, text: str) -> str:
    """Delete every match of ``self.unicode_punct_re`` from ``text``.

    More aggressive than ``replace_unicode_punct``: matched characters are
    dropped instead of being substituted.
    """
    pattern = self.unicode_punct_re
    stripped = pattern.sub("", text)
    return stripped
|
114 |
+
|
115 |
+
def remove_non_printing_char(self, text: str) -> str:
    """Delete every match of ``self.non_printing_chars_re`` from ``text``."""
    pattern = self.non_printing_chars_re
    cleaned = pattern.sub("", text)
    return cleaned
|
117 |
+
|
118 |
|
119 |
def download_kenlm_model(language: str):
|
120 |
root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"
|