File size: 1,861 Bytes
67eaf9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np


def is_overlap(existing_spans, new_span):
    """Report whether *new_span* collides with any span in *existing_spans*.

    Every span is a ``(start, end)`` pair and both ends are treated as
    inclusive. Spans that merely touch, e.g. ``(0, 3)`` and ``(3, 5)``,
    therefore count as overlapping.

    Returns:
        True as soon as one colliding span is found, otherwise False.
    """

    def _collides(span):
        lo, hi = span[0], span[1]
        # One endpoint of new_span falls inside span, or new_span swallows it.
        return (
            lo <= new_span[0] <= hi
            or lo <= new_span[1] <= hi
            or (new_span[0] <= lo and new_span[1] >= hi)
        )

    return any(_collides(span) for span in existing_spans)


def get_sequential_spans(a):
    """Return the runs of consecutive truthy elements in *a*, in order.

    Each run is a half-open ``(start, end)`` index pair, so ``a[start:end]``
    is exactly the run. Empty input yields an empty list.

    >>> get_sequential_spans([0, 1, 1, 0, 1])
    [(1, 3), (4, 5)]
    """
    spans = []

    prev = False
    start = 0
    # Pre-bind the index so the trailing-run check below never touches a
    # loop variable that was never assigned (empty input).
    i = -1

    for i, x in enumerate(a):
        if not prev and x:
            start = i
        elif prev and not x:
            spans.append((start, i))

        prev = x

    # A run that reaches the final element has no falsy element to close it.
    if prev:
        spans.append((start, i + 1))

    return spans


def batch_list(iterable, n=1):
    """Yield consecutive chunks ``iterable[start:stop]`` of at most *n* items.

    *iterable* must support ``len()`` and slicing. The last chunk may be
    shorter than *n*. Chunks are produced lazily, one per step of
    ``range(0, len(iterable), n)``.
    """
    total = len(iterable)
    offsets = range(0, total, n)
    for begin in offsets:
        end = min(begin + n, total)
        yield iterable[begin:end]


def pad_seq(seq, max_len):
    """Zero-pad *seq* at the end of its first axis up to *max_len* entries.

    Args:
        seq: A sequence or array; its length is ``len(seq)`` (the first axis).
        max_len: Target length along the first axis.

    Returns:
        *seq* itself, unchanged and not truncated, if it already has at least
        *max_len* entries. Otherwise a new ``np.ndarray`` padded with zeros
        along axis 0 only; any trailing axes keep their size.
    """
    n = len(seq)
    if n >= max_len:
        return seq
    # np.pad broadcasts a bare (before, after) pair to EVERY axis, which would
    # also widen the columns of 2-D input; pad axis 0 only and leave the rest.
    pad_width = [(0, max_len - n)] + [(0, 0)] * (np.ndim(seq) - 1)
    return np.pad(seq, pad_width)


def align_decoded(x, d, y):
    """Rebuild text and per-character labels by walking *d* against *x*.

    *x* is a reference text whose labels are in *y* (``y[j]`` goes with
    ``x[j]``). *d* is another version of that text in which the space before
    one of ``, . ? '`` may be missing. Presumably *d* is a tokenizer decode
    of *x*; confirm against callers. For each position of *d*, the matching
    character of *x* and its label are copied. Where *x* has " <delim>" but
    *d* has only "<delim>", both characters of *x* are emitted and consumed
    in a single step.

    Args:
        x: Reference text, read through the cursor ``j``.
        d: Text being walked, indexed by ``i``.
        y: Labels aligned position-by-position with *x*.

    Returns:
        ``(clean_text, clean_label)``: the rebuilt text (characters taken from
        *x*) and the list of labels collected alongside it.
    """
    clean_text = ""
    clean_label = []
    j = 0  # cursor into x / y; can run ahead of i when a space is collapsed
    for i in range(len(d)):
        found = False
        # At most one delimiter can match, because d[i] is compared for
        # equality against each distinct delimiter.
        for delim in [',', '.', '?', "'"]:
            if (x[j:j + 2] == f" {delim}") and (d[i] == f"{delim}"):
                found = True
                clean_text += f' {delim}'
                # Both emitted characters get the space's label y[j]; the
                # delimiter's own label y[j + 1] is not used.
                clean_label += [y[j], y[j]]
                j += 1  # skip the extra character of x (the delimiter)

        if not found:
            # NOTE(review): raises IndexError if d has more positions than x
            # can supply.
            clean_text += x[j]
            clean_label += [y[j]]
        j += 1

    if (clean_text != x) and (x[-1:] == "\n"):
        clean_text += "\n"
        # NOTE(review): two labels for a single "\n" leaves clean_label one
        # element longer than clean_text; confirm this is intended.
        clean_label += [0, 0]

    return clean_text, clean_label


def clean_entity(t):
    """Lower-case an entity string and flatten its line breaks into spaces.

    A space that directly precedes a newline absorbs it, so the
    space-newline pair becomes one space. Any remaining newline becomes a
    space on its own.
    """
    return t.lower().replace(' \n', " ").replace('\n', " ")