kargaranamir committed
Commit: 2c9efe4
1 Parent(s): d8f31aa
add app.
Browse files:
- README.md +5 -5
- app.py +78 -0
- masklid.py +268 -0
- requirements.txt +3 -0
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
 title: MaskLID
-emoji:
-colorFrom:
-colorTo:
+emoji: 🐨
+colorFrom: purple
+colorTo: pink
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 ---
 
-
+This code applies [MaskLID](https://arxiv.org/abs/2406.06263) with [GlotLID](https://arxiv.org/abs/2310.16248), a fasttext-based language identification tool.
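For reference, GlotLID on its own is a standard fastText classifier. The following is a minimal sketch (an illustration, not part of this commit) of downloading it and getting a single-label prediction for a whole input, i.e. what plain LID does before MaskLID's masking is applied:

# Minimal sketch (assumption, not part of this commit): plain GlotLID prediction,
# one label for the whole input, without MaskLID's iterative masking.
import fasttext
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model_v3.bin")
model = fasttext.load_model(model_path)

labels, probs = model.predict("bir kahve dükkanında geçen film tadında", k=3)
print(list(zip(labels, probs)))  # top-3 labels (e.g. __label__tur_Latn) with scores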
app.py ADDED
@@ -0,0 +1,78 @@

# Author: Amir Hossein Kargaran
# Date: August, 2023

# Description: This code applies MaskLID (code-switching language identification) with GlotLID, a fastText-based language identification tool.

# MIT License

import gradio as gr
from masklid import MaskLID
from huggingface_hub import hf_hub_download
from fasttext.FastText import _FastText


def render_metadata():
    """Renders the demo description shown above the interface."""
    html_content = """
    <p align="center">
    <a href="https://github.com/cisnlp/MaskLID"><img alt="GitHub stars" src="https://img.shields.io/github/stars/cisnlp/MaskLID"></a>
    This is the demo for the <a href="https://arxiv.org/abs/2406.06263">MaskLID</a> paper (ACL 2024). The full code is available on our GitHub. Note that if you increase the number of languages, you also need larger alpha and beta values.
    MaskLID does not add much overhead to language identification: you first fix the set of languages your model is limited to and then run the MaskLID code. In this demo, however, we reload the model each time you hit Submit (which takes a couple of seconds) to ensure results are not cached and to make it possible to change the set of languages on every run. We may later change the demo code to resolve this.
    </p>
    """
    return html_content


def get_model_path():
    # Download the GlotLID fastText language identification model from the Hugging Face Hub
    model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model_v3.bin")
    return model_path


def get_masklid():
    # Load the MaskLID model
    masklid_model = MaskLID(get_model_path())

    # Get all labels, excluding the undetermined (und) and non-linguistic (zxx) ones
    labels = masklid_model.model.get_labels()
    labels = [l for l in labels if not l.startswith('__label__und') and not l.startswith('__label__zxx')]

    return masklid_model, labels


def predict_codeswitch(text, top_labels=200, beta=20, alpha=3, max_lambda=3, min_length=10, min_prob=0.90, max_retry=3, alpha_step_increase=3, beta_step_increase=5):

    # Constraints: beta cannot exceed the number of candidate labels, and alpha cannot exceed beta
    beta = top_labels if beta > top_labels else beta
    alpha = beta if alpha > beta else alpha

    # Restrict the MaskLID label set to the first top_labels labels
    masklid_model, labels = get_masklid()
    masklid_model.language_indices = masklid_model._compute_language_indices(labels[:top_labels])
    masklid_model.labels = [masklid_model.model.get_labels()[i] for i in masklid_model.language_indices]

    ans = masklid_model.predict_codeswitch(text, beta=beta, alpha=alpha, max_lambda=max_lambda, min_length=min_length, min_prob=min_prob, max_retry=max_retry, alpha_step_increase=alpha_step_increase, beta_step_increase=beta_step_increase)

    return ans


inputs = gr.Textbox(lines=2, label="Enter the text", value="bir kahve dükkanında geçen film tadında güzel bir şarkıya ayrılsın gece falling in love at a coffee shop")

parameters = {
    "top_labels": gr.Slider(minimum=2, maximum=len(get_masklid()[1]), step=1, value=200, label="Limit LID to X Top Languages"),
    "beta": gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Beta"),
    "alpha": gr.Slider(minimum=1, maximum=30, value=3, step=1, label="Alpha"),
    "max_lambda": gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Iteration"),
    "min_length": gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Min Length"),
    "min_prob": gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Min Probability"),
    "max_retry": gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Retry In Total"),
    "alpha_step_increase": gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Alpha Step Increase"),
    "beta_step_increase": gr.Slider(minimum=1, maximum=15, value=5, step=1, label="Beta Step Increase")
}

output = gr.JSON(label="Output")

gr.Interface(
    fn=predict_codeswitch,
    inputs=[inputs, *parameters.values()],
    outputs=output,
    title="MaskLID (Code-Switch Language Identification)",
    description=render_metadata(),
    cache_examples=False
).launch()
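The embedded description explains that the model is reloaded on every submit so that results are not cached and the language set can change per request. A hypothetical alternative (an assumption, not what this commit does) would be to cache the download and model construction once and only recompute the label subset per request:

# Hypothetical sketch, not part of this commit: cache the GlotLID download and
# MaskLID construction so repeated submits reuse one model object. The per-request
# label subset (language_indices / labels) would still need to be reset on each call,
# which is the flexibility the current demo keeps by reloading the model.
from functools import lru_cache

@lru_cache(maxsize=1)
def get_cached_masklid():
    return MaskLID(get_model_path())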
masklid.py ADDED
@@ -0,0 +1,268 @@

import fasttext
import numpy as np
import re
import string
from copy import deepcopy


class MaskLID:
    """A class for code-switching language identification using iterative masking."""

    def __init__(self, model_path, languages=-1):
        """Initialize the MaskLID class.

        Args:
            model_path (str): The path to the fastText model.
            languages (int or list, optional): The list of language labels to consider, or -1 for all labels. Defaults to -1.
        """
        self.model = fasttext.load_model(model_path)
        self.output_matrix = self.model.get_output_matrix()
        self.labels = self.model.get_labels()
        self.language_indices = self._compute_language_indices(languages)
        self.labels = [self.labels[i] for i in self.language_indices]

    def _compute_language_indices(self, languages):
        """Compute indices of the selected languages.

        Args:
            languages (int or list): The list of language labels, or -1 for all labels.

        Returns:
            list: Indices of the selected languages.
        """
        if languages != -1 and isinstance(languages, list):
            return [self.labels.index(l) for l in set(languages) if l in self.labels]
        return list(range(len(self.labels)))

    def _softmax(self, x):
        """Compute softmax values for each score in array x.

        Args:
            x (numpy.ndarray): Input array.

        Returns:
            numpy.ndarray: Softmax output.
        """
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x)

    def _normalize_text(self, text):
        """Normalize input text.

        Args:
            text (str): Input text.

        Returns:
            str: Normalized text.
        """
        replace_by = " "
        replacement_map = {ord(c): replace_by for c in '_:' + '•#{|}' + string.digits}
        text = text.replace('\n', replace_by)
        text = text.translate(replacement_map)
        return re.sub(r'\s+', replace_by, text).strip()

    def predict(self, text, k=1):
        """Predict the language of the input text.

        Args:
            text (str): Input text.
            k (int, optional): Number of top predictions to retrieve. Defaults to 1.

        Returns:
            tuple: Top predicted labels and their probabilities.
        """
        sentence_vector = self.model.get_sentence_vector(text)
        result_vector = np.dot(self.output_matrix, sentence_vector)
        softmax_result = self._softmax(result_vector)[self.language_indices]
        top_k_indices = np.argsort(softmax_result)[-k:][::-1]
        top_k_labels = [self.labels[i] for i in top_k_indices]
        top_k_probs = softmax_result[top_k_indices]
        return tuple(top_k_labels), top_k_probs

    def compute_v(self, sentence_vector):
        """Compute the language logits for a given sentence vector.

        Args:
            sentence_vector (numpy.ndarray): Sentence vector.

        Returns:
            list: (label, logit) pairs sorted by logit in descending order.
        """
        result_vector = np.dot(self.output_matrix[self.language_indices, :], sentence_vector)
        return sorted(zip(self.labels, result_vector), key=lambda x: x[1], reverse=True)

    def compute_v_per_word(self, text):
        """Compute language logits for each word in the input text.

        Args:
            text (str): Input text.

        Returns:
            dict: Dictionary mapping each word key to its language logits.
        """
        text = self._normalize_text(text)
        words = self.model.get_line(text)[0]
        words = [w for w in words if w != '</s>']
        subword_ids = [self.model.get_subwords(sw)[1] for sw in words]
        sentence_vector = [np.sum([self.model.get_input_vector(id) for id in sid], axis=0) for sid in subword_ids]

        dict_text = {}
        for i, word in enumerate(words):
            key = f"{i}_{word}"
            dict_text[key] = {'logits': self.compute_v(sentence_vector[i])}

        return dict_text

    def mask_label_top_k(self, dict_text, label, top_keep, top_remove):
        """Mask words whose top predictions include a given label.

        Args:
            dict_text (dict): Dictionary containing language logits for each word.
            label (str): Label to mask.
            top_keep (int): A word is copied to the masked set if the label is among its top `top_keep` predictions.
            top_remove (int): A word is removed from the remaining text if the label is among its top `top_remove` predictions.

        Returns:
            tuple: Dictionaries of remaining and masked words.
        """
        dict_remained = deepcopy(dict_text)
        dict_deleted = {}

        for key, value in dict_text.items():
            logits = value['logits']
            labels = [t[0] for t in logits]

            if label in labels[:top_keep]:
                dict_deleted[key] = dict_remained[key]

            if label in labels[:top_remove]:
                dict_remained.pop(key, None)

        return dict_remained, dict_deleted

    @staticmethod
    def get_sizeof(text):
        """Compute the size of text in bytes.

        Args:
            text (str): Input text.

        Returns:
            int: Size of text in bytes.
        """
        return len(text.encode('utf-8'))

    @staticmethod
    def custom_sort(word):
        """Sorting key for word keys of the form '<index>_<word>'.

        Args:
            word (str): Word key.

        Returns:
            int or float: The word index, or infinity if the key has no numeric prefix.
        """
        match = re.match(r'^(\d+)_', word)
        if match:
            return int(match.group(1))
        else:
            return float('inf')  # Keys without a leading number sort last

    def sum_logits(self, dict_data, label):
        """Compute the sum of logits for a specific label across all words.

        Args:
            dict_data (dict): Dictionary containing language logits for each word.
            label (str): Label to sum logits for.

        Returns:
            float: Total sum of logits for the given label.
        """
        total = 0
        for value in dict_data.values():
            logits = value['logits']
            labels = [t[0] for t in logits]
            if label in labels:
                total += logits[labels.index(label)][1]
        return total

    def predict_codeswitch(self, text, beta, alpha, min_prob, min_length, max_lambda=1, max_retry=3, alpha_step_increase=5, beta_step_increase=5):
        """Predict the languages of the code-switched segments in the input text.

        Args:
            text (str): Input text.
            beta (int): Number of top predictions used to keep a word for the current label.
            alpha (int): Number of top predictions used to remove a word from the remaining text.
            min_prob (float): Minimum probability threshold for accepting a masked segment.
            min_length (int): Minimum length (in bytes) of the masked or remaining text.
            max_lambda (int, optional): Maximum number of iterations (segments to extract). Defaults to 1.
            max_retry (int, optional): Maximum number of retries. Defaults to 3.
            alpha_step_increase (int, optional): Step increase for alpha on retry. Defaults to 5.
            beta_step_increase (int, optional): Step increase for beta on retry. Defaults to 5.

        Returns:
            dict: Mapping from predicted labels to the text assigned to each label.
        """
        info = {}
        index = 0
        retry = 0

        # Compute per-word language logits
        dict_data = self.compute_v_per_word(text)

        while index < max_lambda and retry < max_retry:

            # Predict the dominant language of the current text
            pred = self.predict(text, k=1)
            label = pred[0][0]

            # Save the current text in case of a step back
            prev_text = text
            # Mask the words associated with the predicted label
            dict_data, dict_masked = self.mask_label_top_k(dict_data, label, beta, alpha)

            # Rebuild the masked text and the remaining text from the word keys
            masked_text = ' '.join(x.split('_', 1)[1] for x in dict_masked.keys())
            text = ' '.join(x.split('_', 1)[1] for x in dict_data.keys())

            # Save info
            if self.get_sizeof(masked_text) > min_length or index == 0:
                temp_pred = self.predict(masked_text)

                if (temp_pred[1][0] > min_prob and temp_pred[0][0] == label) or index == 0:
                    info[index] = {
                        'label': label,
                        'text': masked_text,
                        'text_keys': dict_masked.keys(),
                        'size': self.get_sizeof(masked_text),
                        'sum_logit': self.sum_logits(dict_masked, label)
                    }
                    index += 1
                else:
                    # Step back and retry with larger alpha and beta
                    text = prev_text
                    beta += beta_step_increase
                    alpha += alpha_step_increase
                    retry += 1
            else:
                # Step back and retry with larger alpha and beta
                text = prev_text
                beta += beta_step_increase
                alpha += alpha_step_increase
                retry += 1

            if self.get_sizeof(text) < min_length:
                break

        # Post-process: group the word keys of all extracted segments by label
        post_info = {}
        for value in info.values():
            key = value['label']
            if key in post_info:
                post_info[key].extend(value['text_keys'])
            else:
                post_info[key] = list(value['text_keys'])

        # Join the words of each label back in their original order
        for key in post_info:
            post_info[key] = ' '.join([x.split('_', 1)[1] for x in sorted(set(post_info[key]), key=self.custom_sort)])

        return post_info
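A minimal usage sketch (an illustration, not part of this commit) running the class above directly on the demo's example sentence, assuming the GlotLID model file has already been downloaded:

# Minimal sketch, assuming "model_v3.bin" was obtained via hf_hub_download as in app.py.
from masklid import MaskLID

masklid = MaskLID("model_v3.bin")
segments = masklid.predict_codeswitch(
    "bir kahve dükkanında geçen film tadında güzel bir şarkıya ayrılsın gece "
    "falling in love at a coffee shop",
    beta=20, alpha=3, min_prob=0.9, min_length=10, max_lambda=3,
)
print(segments)  # a {label: text} dict, e.g. a Turkish and an English segment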
requirements.txt ADDED
@@ -0,0 +1,3 @@

fasttext>=0.9.2
huggingface-hub>=0.14.1
numpy>=1.24.3,<2