Spaces:
Runtime error
Runtime error
freemt
committed on
Commit
•
a15cd26
1
Parent(s):
5a186f5
Update sent-ali fast and slow
Browse files
- img/plt.png +0 -0
- radiobee/__main__.py +53 -6
- radiobee/align_sents.py +0 -72
- radiobee/align_sents.pyc +0 -0
- radiobee/error_msg.py +2 -2
- radiobee/gradiobee.py +68 -5
- radiobee/paras2sents.py +110 -0
- requirements.txt +1 -0
- tests/test_paras2sents.py +8 -2
img/plt.png
CHANGED
radiobee/__main__.py
CHANGED
@@ -139,7 +139,6 @@ if __name__ == "__main__":
|
|
139 |
gr.inputs.File(label="file 2", optional=True),
|
140 |
]
|
141 |
|
142 |
-
# modi 1
|
143 |
_ = """
|
144 |
tf_type: Literal[linear, sqrt, log, binary] = 'linear'
|
145 |
idf_type: Optional[Literal[standard, smooth, bm25]] = None
|
@@ -159,10 +158,13 @@ if __name__ == "__main__":
|
|
159 |
) # ditto
|
160 |
input_norm_type = gr.inputs.Radio(["None", "l1", "l2"], default="None") # ditto
|
161 |
|
162 |
-
inputs
|
|
|
|
|
|
|
163 |
gr.inputs.File(label="file 1"),
|
164 |
gr.inputs.File(label="file 2", optional=True),
|
165 |
-
input_tf_type, # modi inputs
|
166 |
input_idf_type,
|
167 |
input_dl_type,
|
168 |
input_norm_type,
|
@@ -178,6 +180,7 @@ if __name__ == "__main__":
|
|
178 |
step=1,
|
179 |
default=6,
|
180 |
),
|
|
|
181 |
]
|
182 |
|
183 |
examples = [
|
@@ -190,6 +193,29 @@ if __name__ == "__main__":
|
|
190 |
"None",
|
191 |
10,
|
192 |
6,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
],
|
194 |
[
|
195 |
"data/test_en.txt",
|
@@ -200,6 +226,7 @@ if __name__ == "__main__":
|
|
200 |
"None",
|
201 |
10,
|
202 |
6,
|
|
|
203 |
],
|
204 |
[
|
205 |
"data/shakespeare_zh500.txt",
|
@@ -210,6 +237,7 @@ if __name__ == "__main__":
|
|
210 |
"None",
|
211 |
10,
|
212 |
6,
|
|
|
213 |
],
|
214 |
[
|
215 |
"data/shakespeare_en500.txt",
|
@@ -220,6 +248,7 @@ if __name__ == "__main__":
|
|
220 |
"None",
|
221 |
10,
|
222 |
6,
|
|
|
223 |
],
|
224 |
[
|
225 |
"data/hlm-ch1-zh.txt",
|
@@ -230,6 +259,7 @@ if __name__ == "__main__":
|
|
230 |
"None",
|
231 |
10,
|
232 |
6,
|
|
|
233 |
],
|
234 |
[
|
235 |
"data/hlm-ch1-en.txt",
|
@@ -240,6 +270,7 @@ if __name__ == "__main__":
|
|
240 |
"None",
|
241 |
10,
|
242 |
6,
|
|
|
243 |
],
|
244 |
[
|
245 |
"data/ps-cn.txt",
|
@@ -250,6 +281,7 @@ if __name__ == "__main__":
|
|
250 |
"None",
|
251 |
10,
|
252 |
4,
|
|
|
253 |
],
|
254 |
[
|
255 |
"data/test-dual.txt",
|
@@ -260,6 +292,7 @@ if __name__ == "__main__":
|
|
260 |
"None",
|
261 |
10,
|
262 |
6,
|
|
|
263 |
],
|
264 |
[
|
265 |
"data/英译中国现代散文选1(汉外对照丛书).txt",
|
@@ -270,6 +303,7 @@ if __name__ == "__main__":
|
|
270 |
"None",
|
271 |
10,
|
272 |
6,
|
|
|
273 |
],
|
274 |
[
|
275 |
"data/test-zh-ja.txt",
|
@@ -280,6 +314,7 @@ if __name__ == "__main__":
|
|
280 |
"None",
|
281 |
10,
|
282 |
6,
|
|
|
283 |
],
|
284 |
[
|
285 |
"data/xiyouji-ch1-zh.txt",
|
@@ -290,6 +325,7 @@ if __name__ == "__main__":
|
|
290 |
"None",
|
291 |
10,
|
292 |
6,
|
|
|
293 |
],
|
294 |
[
|
295 |
"data/demian-hesse-de.txt",
|
@@ -300,6 +336,7 @@ if __name__ == "__main__":
|
|
300 |
"None",
|
301 |
10,
|
302 |
6,
|
|
|
303 |
],
|
304 |
[
|
305 |
"data/catcher-in-the-rye-shixianrong-zh.txt",
|
@@ -310,6 +347,7 @@ if __name__ == "__main__":
|
|
310 |
"None",
|
311 |
10,
|
312 |
6,
|
|
|
313 |
],
|
314 |
]
|
315 |
|
@@ -340,14 +378,23 @@ if __name__ == "__main__":
|
|
340 |
out_file_dl_excel = gr.outputs.File(
|
341 |
label="Click to download xlsx",
|
342 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
-
# modi outputs
|
345 |
-
outputs = [
|
346 |
out_df,
|
347 |
-
# "plot",
|
348 |
gr.outputs.Image(label="plot"),
|
349 |
out_file_dl,
|
350 |
out_file_dl_excel,
|
|
|
|
|
351 |
out_df_aligned,
|
352 |
gr.outputs.HTML(),
|
353 |
]
|
|
|
139 |
gr.inputs.File(label="file 2", optional=True),
|
140 |
]
|
141 |
|
|
|
142 |
_ = """
|
143 |
tf_type: Literal[linear, sqrt, log, binary] = 'linear'
|
144 |
idf_type: Optional[Literal[standard, smooth, bm25]] = None
|
|
|
158 |
) # ditto
|
159 |
input_norm_type = gr.inputs.Radio(["None", "l1", "l2"], default="None") # ditto
|
160 |
|
161 |
+
# modi inputs 1, definitions
|
162 |
+
sent_ali_algo = gr.inputs.Radio(["None", "fast", "slow"], default="None")
|
163 |
+
|
164 |
+
inputs = [ # tot. 9, need to modify input of gradio & examples
|
165 |
gr.inputs.File(label="file 1"),
|
166 |
gr.inputs.File(label="file 2", optional=True),
|
167 |
+
input_tf_type, # modi inputs 2
|
168 |
input_idf_type,
|
169 |
input_dl_type,
|
170 |
input_norm_type,
|
|
|
180 |
step=1,
|
181 |
default=6,
|
182 |
),
|
183 |
+
sent_ali_algo,
|
184 |
]
|
185 |
|
186 |
examples = [
|
|
|
193 |
"None",
|
194 |
10,
|
195 |
6,
|
196 |
+
"None",
|
197 |
+
],
|
198 |
+
[
|
199 |
+
"data/test_zh.txt",
|
200 |
+
"data/test_en.txt",
|
201 |
+
"linear",
|
202 |
+
"None",
|
203 |
+
"None",
|
204 |
+
"None",
|
205 |
+
10,
|
206 |
+
6,
|
207 |
+
"fast",
|
208 |
+
],
|
209 |
+
[
|
210 |
+
"data/test_zh.txt",
|
211 |
+
"data/test_en.txt",
|
212 |
+
"linear",
|
213 |
+
"None",
|
214 |
+
"None",
|
215 |
+
"None",
|
216 |
+
10,
|
217 |
+
6,
|
218 |
+
"slow",
|
219 |
],
|
220 |
[
|
221 |
"data/test_en.txt",
|
|
|
226 |
"None",
|
227 |
10,
|
228 |
6,
|
229 |
+
"None",
|
230 |
],
|
231 |
[
|
232 |
"data/shakespeare_zh500.txt",
|
|
|
237 |
"None",
|
238 |
10,
|
239 |
6,
|
240 |
+
"None",
|
241 |
],
|
242 |
[
|
243 |
"data/shakespeare_en500.txt",
|
|
|
248 |
"None",
|
249 |
10,
|
250 |
6,
|
251 |
+
"None",
|
252 |
],
|
253 |
[
|
254 |
"data/hlm-ch1-zh.txt",
|
|
|
259 |
"None",
|
260 |
10,
|
261 |
6,
|
262 |
+
"None",
|
263 |
],
|
264 |
[
|
265 |
"data/hlm-ch1-en.txt",
|
|
|
270 |
"None",
|
271 |
10,
|
272 |
6,
|
273 |
+
"None",
|
274 |
],
|
275 |
[
|
276 |
"data/ps-cn.txt",
|
|
|
281 |
"None",
|
282 |
10,
|
283 |
4,
|
284 |
+
"None",
|
285 |
],
|
286 |
[
|
287 |
"data/test-dual.txt",
|
|
|
292 |
"None",
|
293 |
10,
|
294 |
6,
|
295 |
+
"None",
|
296 |
],
|
297 |
[
|
298 |
"data/英译中国现代散文选1(汉外对照丛书).txt",
|
|
|
303 |
"None",
|
304 |
10,
|
305 |
6,
|
306 |
+
"None",
|
307 |
],
|
308 |
[
|
309 |
"data/test-zh-ja.txt",
|
|
|
314 |
"None",
|
315 |
10,
|
316 |
6,
|
317 |
+
"None",
|
318 |
],
|
319 |
[
|
320 |
"data/xiyouji-ch1-zh.txt",
|
|
|
325 |
"None",
|
326 |
10,
|
327 |
6,
|
328 |
+
"None",
|
329 |
],
|
330 |
[
|
331 |
"data/demian-hesse-de.txt",
|
|
|
336 |
"None",
|
337 |
10,
|
338 |
6,
|
339 |
+
"None",
|
340 |
],
|
341 |
[
|
342 |
"data/catcher-in-the-rye-shixianrong-zh.txt",
|
|
|
347 |
"None",
|
348 |
10,
|
349 |
6,
|
350 |
+
"None",
|
351 |
],
|
352 |
]
|
353 |
|
|
|
378 |
out_file_dl_excel = gr.outputs.File(
|
379 |
label="Click to download xlsx",
|
380 |
)
|
381 |
+
out_sents_dl = gr.outputs.File(
|
382 |
+
label="Click to download sents csv",
|
383 |
+
)
|
384 |
+
out_sents_dl_excel = gr.outputs.File(
|
385 |
+
label="Click to download sents xlsx",
|
386 |
+
)
|
387 |
+
|
388 |
+
# modi outputs 1, definitions
|
389 |
|
390 |
+
# modi outputs 2, need to modify gradio error_msg
|
391 |
+
outputs = [ # tot. 8
|
392 |
out_df,
|
|
|
393 |
gr.outputs.Image(label="plot"),
|
394 |
out_file_dl,
|
395 |
out_file_dl_excel,
|
396 |
+
out_sents_dl,
|
397 |
+
out_sents_dl_excel,
|
398 |
out_df_aligned,
|
399 |
gr.outputs.HTML(),
|
400 |
]
|
radiobee/align_sents.py
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
"""Align sents via gale-church."""
|
2 |
-
# pylint: disable=invalid-name
|
3 |
-
|
4 |
-
from typing import List, Tuple # noqa
|
5 |
-
|
6 |
-
import re
|
7 |
-
|
8 |
-
# from itertools import tee
|
9 |
-
# from more_itertools import ilen
|
10 |
-
from nltk.translate.gale_church import align_blocks
|
11 |
-
|
12 |
-
from radiobee.amend_avec import amend_avec
|
13 |
-
|
14 |
-
|
15 |
-
def align_sents(lst1: List[str], lst2: List[str]) -> List[Tuple[str, str]]:
|
16 |
-
"""Align sents.
|
17 |
-
|
18 |
-
>>> lst1, lst2 = ['a', 'bs',], ['aaa', '34', 'a', 'b']
|
19 |
-
"""
|
20 |
-
if isinstance(lst1, str):
|
21 |
-
lst1 = [lst1]
|
22 |
-
|
23 |
-
if isinstance(lst2, str):
|
24 |
-
lst2 = [lst2]
|
25 |
-
|
26 |
-
src_blocks = [len(re.sub(r"\s+", "", elm)) for elm in lst1]
|
27 |
-
tgt_blocks = [len(re.sub(r"\s+", "", elm)) for elm in lst2]
|
28 |
-
|
29 |
-
avec = align_blocks(src_blocks, tgt_blocks)
|
30 |
-
|
31 |
-
len1, len2 = len(lst1), len(lst2)
|
32 |
-
# lst1, _ = tee(lst1)
|
33 |
-
# len1 = ilen(_)
|
34 |
-
# lst2, _ = tee(lst2)
|
35 |
-
# len2 = ilen(_)
|
36 |
-
|
37 |
-
amended_avec = amend_avec(avec, len1, len2)
|
38 |
-
|
39 |
-
texts = []
|
40 |
-
# for elm in aset:
|
41 |
-
# for elm0, elm1 in amended_avec:
|
42 |
-
for elm in amended_avec:
|
43 |
-
# elm0, elm1, elm2 = elm
|
44 |
-
elm0, elm1 = elm[:2]
|
45 |
-
_ = []
|
46 |
-
|
47 |
-
# src_text first
|
48 |
-
if isinstance(elm0, str):
|
49 |
-
_.append("")
|
50 |
-
else:
|
51 |
-
# _.append(src_text[int(elm0)])
|
52 |
-
_.append(lst1[int(elm0)])
|
53 |
-
|
54 |
-
if isinstance(elm1, str):
|
55 |
-
_.append("")
|
56 |
-
else:
|
57 |
-
# _.append(tgt_text[int(elm0)])
|
58 |
-
_.append(lst2[int(elm1)])
|
59 |
-
|
60 |
-
_a = """
|
61 |
-
if isinstance(elm2, str):
|
62 |
-
_.append("")
|
63 |
-
else:
|
64 |
-
_.append(round(elm2, 2))
|
65 |
-
# """
|
66 |
-
del _a
|
67 |
-
|
68 |
-
texts.append(tuple(_))
|
69 |
-
|
70 |
-
return texts
|
71 |
-
|
72 |
-
# return ["", ""]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
radiobee/align_sents.pyc
ADDED
Binary file (1.55 kB). View file
|
|
radiobee/error_msg.py
CHANGED
@@ -8,7 +8,7 @@ import pandas as pd
|
|
8 |
def error_msg(
|
9 |
msg: Optional[Union[str, Exception]],
|
10 |
title: str = "error message",
|
11 |
-
) -> Tuple[Union[pd.DataFrame, None], None, None, None, None, None]:
|
12 |
"""Prepare an error message for gradiobee outputs."""
|
13 |
if msg is None:
|
14 |
msg = "none..."
|
@@ -21,4 +21,4 @@ def error_msg(
|
|
21 |
df = pd.DataFrame([msg], columns=[title])
|
22 |
|
23 |
# return df, *((None,) * 4) # pyright complains
|
24 |
-
return df, None, None, None, None, None
|
|
|
8 |
def error_msg(
|
9 |
msg: Optional[Union[str, Exception]],
|
10 |
title: str = "error message",
|
11 |
+
) -> Tuple[Union[pd.DataFrame, None], None, None, None, None, None, None, None]:
|
12 |
"""Prepare an error message for gradiobee outputs."""
|
13 |
if msg is None:
|
14 |
msg = "none..."
|
|
|
21 |
df = pd.DataFrame([msg], columns=[title])
|
22 |
|
23 |
# return df, *((None,) * 4) # pyright complains
|
24 |
+
return df, None, None, None, None, None, None, None
|
radiobee/gradiobee.py
CHANGED
@@ -30,6 +30,10 @@ from radiobee.trim_df import trim_df
|
|
30 |
from radiobee.error_msg import error_msg
|
31 |
from radiobee.text2lists import text2lists
|
32 |
|
|
|
|
|
|
|
|
|
33 |
uname = platform.uname()
|
34 |
HFSPACES = False
|
35 |
if "amzn2" in uname.release: # on hf spaces
|
@@ -43,7 +47,7 @@ debug = False
|
|
43 |
debug = True
|
44 |
|
45 |
|
46 |
-
def gradiobee(
|
47 |
file1,
|
48 |
file2,
|
49 |
tf_type,
|
@@ -53,6 +57,7 @@ def gradiobee(
|
|
53 |
eps,
|
54 |
min_samples,
|
55 |
# debug=False,
|
|
|
56 |
):
|
57 |
"""Process inputs and return outputs."""
|
58 |
logger.debug(" *debug* ")
|
@@ -382,7 +387,7 @@ def gradiobee(
|
|
382 |
df_aligned = df_aligned[["text2", "text1", "likelihood"]]
|
383 |
df_aligned.columns = ["text1", "text2", "likelihood"]
|
384 |
|
385 |
-
ic(df_aligned.head())
|
386 |
|
387 |
# round the last column to 2
|
388 |
# df_aligned.likelihood = df_aligned.likelihood.round(2)
|
@@ -434,8 +439,66 @@ def gradiobee(
|
|
434 |
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, df_aligned
|
435 |
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, styled, df_html # gradio cant handle style
|
436 |
|
437 |
-
ic("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
|
439 |
-
|
|
|
440 |
|
441 |
-
#
|
|
|
|
30 |
from radiobee.error_msg import error_msg
|
31 |
from radiobee.text2lists import text2lists
|
32 |
|
33 |
+
from radiobee.align_sents import align_sents
|
34 |
+
from radiobee.shuffle_sents import shuffle_sents # type: ignore
|
35 |
+
from radiobee.paras2sents import paras2sents # type: ignore
|
36 |
+
|
37 |
uname = platform.uname()
|
38 |
HFSPACES = False
|
39 |
if "amzn2" in uname.release: # on hf spaces
|
|
|
47 |
debug = True
|
48 |
|
49 |
|
50 |
+
def gradiobee( # noqa
|
51 |
file1,
|
52 |
file2,
|
53 |
tf_type,
|
|
|
57 |
eps,
|
58 |
min_samples,
|
59 |
# debug=False,
|
60 |
+
sent_ali_algo,
|
61 |
):
|
62 |
"""Process inputs and return outputs."""
|
63 |
logger.debug(" *debug* ")
|
|
|
387 |
df_aligned = df_aligned[["text2", "text1", "likelihood"]]
|
388 |
df_aligned.columns = ["text1", "text2", "likelihood"]
|
389 |
|
390 |
+
ic("paras aligned: ", df_aligned.head(10))
|
391 |
|
392 |
# round the last column to 2
|
393 |
# df_aligned.likelihood = df_aligned.likelihood.round(2)
|
|
|
439 |
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, df_aligned
|
440 |
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, styled, df_html # gradio cant handle style
|
441 |
|
442 |
+
ic("sent-ali-algo: ", sent_ali_algo)
|
443 |
+
|
444 |
+
# ### sent-ali-algo is None: para align
|
445 |
+
if sent_ali_algo in ["None"]:
|
446 |
+
ic("returning para-ali outputs")
|
447 |
+
return df_trimmed, output_plot, file_dl, file_dl_xlsx, None, None, df_aligned, df_html
|
448 |
+
|
449 |
+
# ### proceed with sent align
|
450 |
+
if sent_ali_algo in ["fast"]:
|
451 |
+
ic(sent_ali_algo)
|
452 |
+
align_func = align_sents
|
453 |
+
|
454 |
+
ic(df_aligned.shape, df_aligned.columns)
|
455 |
+
|
456 |
+
aligned_sents = paras2sents(df_aligned, align_func)
|
457 |
+
|
458 |
+
# ic(pd.DataFrame(aligned_sents).shape, aligned_sents)
|
459 |
+
ic(pd.DataFrame(aligned_sents).shape)
|
460 |
+
|
461 |
+
df_aligned_sents = pd.DataFrame(aligned_sents, columns=["text1", "text2"])
|
462 |
+
else: # ["slow"]
|
463 |
+
ic(sent_ali_algo)
|
464 |
+
align_func = shuffle_sents
|
465 |
+
aligned_sents = paras2sents(df_aligned, align_func, lang1, lang2)
|
466 |
+
|
467 |
+
# add extra entry if necessary
|
468 |
+
aligned_sents = [list(sent) + [""] if len(sent) == 2 else list(sent) for sent in aligned_sents]
|
469 |
+
|
470 |
+
df_aligned_sents = pd.DataFrame(aligned_sents, columns=["text1", "text2", "likelihood"])
|
471 |
+
|
472 |
+
# prepare sents downloads
|
473 |
+
file_dl_sents = Path(f"{file_dl.stem}-sents{file_dl.suffix}")
|
474 |
+
file_dl_xlsx_sents = Path(f"{file_dl_xlsx.stem}-sents{file_dl_xlsx.suffix}")
|
475 |
+
_ = df_aligned_sents.to_csv(index=False)
|
476 |
+
file_dl_sents.write_text(_, encoding="utf8")
|
477 |
+
|
478 |
+
df_aligned_sents.to_excel(file_dl_xlsx_sents)
|
479 |
+
|
480 |
+
# prepare html output
|
481 |
+
if len(df_aligned_sents) > 200:
|
482 |
+
df_html = None
|
483 |
+
else: # show a one-batch table in html
|
484 |
+
# style
|
485 |
+
styled = df_aligned_sents.style.set_properties(
|
486 |
+
**{
|
487 |
+
"font-size": "10pt",
|
488 |
+
"border-color": "black",
|
489 |
+
"border": "1px black solid !important"
|
490 |
+
}
|
491 |
+
# border-color="black",
|
492 |
+
).set_table_styles([{
|
493 |
+
"selector": "", # noqs
|
494 |
+
"props": [("border", "2px black solid !important")]}] # noqs
|
495 |
+
).format(
|
496 |
+
precision=2
|
497 |
+
)
|
498 |
+
df_html = styled.to_html()
|
499 |
|
500 |
+
# aligned sents outputs
|
501 |
+
ic("aligned sents outputs")
|
502 |
|
503 |
+
# return df_trimmed, output_plot, file_dl, file_dl_xlsx, None, None, df_aligned, df_html
|
504 |
+
return df_trimmed, output_plot, file_dl, file_dl_xlsx, file_dl_sents, file_dl_xlsx_sents, df_aligned_sents, df_html
|
radiobee/paras2sents.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Convert paras to sents."""
|
2 |
+
# pylint: disable=unused-import, too-many-branches, ungrouped-imports
|
3 |
+
|
4 |
+
from typing import Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
from itertools import zip_longest
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from logzero import logger
|
10 |
+
|
11 |
+
from radiobee.align_sents import align_sents
|
12 |
+
from radiobee.seg_text import seg_text
|
13 |
+
from radiobee.detect import detect
|
14 |
+
|
15 |
+
try:
|
16 |
+
from radiobee.shuffle_sents import shuffle_sents
|
17 |
+
except Exception as exc:
|
18 |
+
logger.error("shuffle_sents not available: %s, using align_sents", exc)
|
19 |
+
shuffle_sents = lambda x1, x2, lang1="", lang2="": align_sents(x1, x2) # noqa
|
20 |
+
|
21 |
+
|
22 |
+
def paras2sents(
|
23 |
+
paras_: Union[pd.DataFrame, List[Tuple[str, str, Union[str, float]]], np.ndarray],
|
24 |
+
align_func: Optional[Union[Callable, str]] = None,
|
25 |
+
lang1: Optional[str] = None,
|
26 |
+
lang2: Optional[str] = None,
|
27 |
+
) -> List[Tuple[str, str, Union[str, float]]]:
|
28 |
+
"""Convert paras to sents using align_func.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
paras_: list of 3-tuples or numpy or pd.DataFrame
|
32 |
+
lang1: first lang code
|
33 |
+
lang2: second lang code
|
34 |
+
align_func: func used in the sent level
|
35 |
+
if set to None, default to align_sents
|
36 |
+
Returns:
|
37 |
+
list of sents (possibly with likelihood for shuffle_sents)
|
38 |
+
"""
|
39 |
+
# wrap everything in pd.DataFrame
|
40 |
+
# necessary to make pyright happy
|
41 |
+
paras = pd.DataFrame(paras_).fillna("")
|
42 |
+
|
43 |
+
# take the first three columns at maximum
|
44 |
+
paras = paras.iloc[:, :3]
|
45 |
+
|
46 |
+
if len(paras.columns) < 2:
|
47 |
+
logger.error(
|
48 |
+
"Need at least two columns, got %s",
|
49 |
+
len(paras.columns)
|
50 |
+
)
|
51 |
+
raise Exception("wrong data")
|
52 |
+
|
53 |
+
# append the third col (all "") if there are only two cols
|
54 |
+
if len(paras.columns) < 3:
|
55 |
+
paras.insert(2, "likelihood", [""] * len(paras))
|
56 |
+
|
57 |
+
if lang1 is None:
|
58 |
+
lang1 = detect(" ".join(paras.iloc[:, 0]))
|
59 |
+
if lang2 is None:
|
60 |
+
lang2 = detect(" ".join(paras.iloc[:, 1]))
|
61 |
+
|
62 |
+
left, right = [], []
|
63 |
+
row0, row1 = [], []
|
64 |
+
for elm0, elm1, elm2 in paras.values:
|
65 |
+
sents0 = seg_text(elm0, lang1)
|
66 |
+
sents1 = seg_text(elm1, lang2)
|
67 |
+
if isinstance(elm2, float) and elm2 > 0:
|
68 |
+
if row0 or row1:
|
69 |
+
left.append(row0)
|
70 |
+
right.append(row1)
|
71 |
+
row0, row1 = [], [] # collect and prepare
|
72 |
+
|
73 |
+
if sents0:
|
74 |
+
left.append(sents0)
|
75 |
+
if sents1:
|
76 |
+
right.append(sents1)
|
77 |
+
else:
|
78 |
+
if sents0:
|
79 |
+
row0.extend(sents0)
|
80 |
+
if sents1:
|
81 |
+
row1.extend(sents1)
|
82 |
+
# collect possible last batch
|
83 |
+
if row0 or row1:
|
84 |
+
left.append(row0)
|
85 |
+
right.append(row1)
|
86 |
+
|
87 |
+
# res = [*zip(left, right)]
|
88 |
+
|
89 |
+
# align each batch using align_func
|
90 |
+
|
91 |
+
# ready align_func
|
92 |
+
if align_func is None:
|
93 |
+
align_func = align_sents
|
94 |
+
if isinstance(align_func, str) and align_func.startswith("shuffle") or not isinstance(align_func, str) and align_func.__name__ in ["shuffle_sents"]:
|
95 |
+
align_func = lambda row0, row1: shuffle_sents(row0, row1, lang1=lang1, lang2=lang2) # noqa
|
96 |
+
else:
|
97 |
+
align_func = align_sents
|
98 |
+
|
99 |
+
res = []
|
100 |
+
for row0, row1 in zip(left, right):
|
101 |
+
try:
|
102 |
+
_ = align_func(row0, row1)
|
103 |
+
except Exception as exc:
|
104 |
+
logger.info("probably empty para supplied: %s, resorting to zip_longest", exc)
|
105 |
+
_ = [*zip_longest(row0, row1, fillvalue="")]
|
106 |
+
|
107 |
+
# res.append(_)
|
108 |
+
res.extend(_)
|
109 |
+
|
110 |
+
return res
|
requirements.txt
CHANGED
@@ -23,6 +23,7 @@ pyicu
|
|
23 |
pycld2
|
24 |
tqdm
|
25 |
polyglot
|
|
|
26 |
sentence_splitter
|
27 |
icecream
|
28 |
# lazy
|
|
|
23 |
pycld2
|
24 |
tqdm
|
25 |
polyglot
|
26 |
+
nltk
|
27 |
sentence_splitter
|
28 |
icecream
|
29 |
# lazy
|
tests/test_paras2sents.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
"""Test paras2sents."""
|
2 |
# pylint: disable=invalid-name
|
3 |
|
|
|
4 |
import pandas as pd
|
5 |
from radiobee.paras2sents import paras2sents
|
6 |
from radiobee.shuffle_sents import shuffle_sents
|
@@ -14,15 +15,20 @@ def test_paras2sents_dual():
|
|
14 |
"""Test paras2sents_dual."""
|
15 |
sents = paras2sents(paras)
|
16 |
|
|
|
|
|
17 |
assert len(sents) > 202 # 208
|
18 |
# assert not sents
|
19 |
|
20 |
|
21 |
def test_paras2sents_dual_model_s():
|
22 |
"""Test paras2sents_dual_model_s."""
|
23 |
-
|
|
|
|
|
|
|
24 |
|
25 |
-
assert len(
|
26 |
# assert not sents
|
27 |
|
28 |
|
|
|
1 |
"""Test paras2sents."""
|
2 |
# pylint: disable=invalid-name
|
3 |
|
4 |
+
import numpy as np
|
5 |
import pandas as pd
|
6 |
from radiobee.paras2sents import paras2sents
|
7 |
from radiobee.shuffle_sents import shuffle_sents
|
|
|
15 |
"""Test paras2sents_dual."""
|
16 |
sents = paras2sents(paras)
|
17 |
|
18 |
+
assert np.array(sents).shape.__len__() > 1
|
19 |
+
|
20 |
assert len(sents) > 202 # 208
|
21 |
# assert not sents
|
22 |
|
23 |
|
24 |
def test_paras2sents_dual_model_s():
|
25 |
"""Test paras2sents_dual_model_s."""
|
26 |
+
sents1 = paras2sents(paras, shuffle_sents)
|
27 |
+
|
28 |
+
# assert np.array(sents1).shape.__len__() > 1
|
29 |
+
assert pd.DataFrame(sents1).shape.__len__() > 1
|
30 |
|
31 |
+
assert len(sents1) > 201 # 207
|
32 |
# assert not sents
|
33 |
|
34 |
|