xlreator's picture
initial commit
67eaf9f verified
import numpy as np
def is_overlap(existing_spans, new_span):
    """Return True if ``new_span`` overlaps any span in ``existing_spans``.

    Spans are indexable pairs ``(start, end)``; only items 0 and 1 are read.
    Both endpoints are compared inclusively, so spans that merely share an
    endpoint count as overlapping.

    Args:
        existing_spans: Iterable of already-accepted spans.
        new_span: Candidate span to test against each existing span.

    Returns:
        True on the first overlapping span found, otherwise False (also
        False when ``existing_spans`` is empty).
    """
    return any(
        # Either end of the candidate falls inside the existing span ...
        span[0] <= new_span[0] <= span[1]
        or span[0] <= new_span[1] <= span[1]
        # ... or the candidate completely covers the existing span.
        or (new_span[0] <= span[0] and new_span[1] >= span[1])
        for span in existing_spans
    )
def get_sequential_spans(a):
    """Return the index spans of consecutive truthy items in ``a``.

    Each span is a half-open ``(start, end)`` tuple: ``start`` is the index
    of the first truthy item of a run and ``end`` is one past its last item.

    Args:
        a: Iterable whose items are tested for truthiness.

    Returns:
        List of ``(start, end)`` tuples in order of appearance; an empty
        list when ``a`` is empty or contains no truthy item.
    """
    spans = []
    in_run = False
    start = 0
    length = 0
    for i, x in enumerate(a):
        if x and not in_run:
            start = i
        elif in_run and not x:
            spans.append((start, i))
        in_run = bool(x)
        length = i + 1
    # Close a run that extends to the end of the input. The state is tracked
    # explicitly rather than read from the loop variables, which are never
    # bound when the input is empty.
    if in_run:
        spans.append((start, length))
    return spans
def batch_list(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items.

    Despite its name, ``iterable`` must support ``len()`` and slice
    indexing; each batch is whatever slicing the input returns.

    Args:
        iterable: Sized, sliceable collection to split.
        n: Maximum number of items per batch (the last batch may be shorter).

    Yields:
        Slices ``iterable[k:k + n]`` for ``k = 0, n, 2n, ...``.
    """
    total = len(iterable)
    for begin in range(0, total, n):
        # Clamp the stop index so the final batch ends at the input's end.
        stop = min(begin + n, total)
        yield iterable[begin:stop]
def pad_seq(seq, max_len):
    """Right-pad ``seq`` up to ``max_len`` items using ``np.pad``.

    Args:
        seq: Sequence to pad (assumed 1-D — TODO confirm with callers).
        max_len: Target length.

    Returns:
        ``seq`` itself, unchanged and not truncated, when it already has at
        least ``max_len`` items; otherwise the array returned by ``np.pad``
        with the missing positions appended at the end (``np.pad``'s default
        mode fills with zeros).
    """
    shortfall = max_len - len(seq)
    if shortfall <= 0:
        return seq
    return np.pad(seq, (0, shortfall))
def align_decoded(x, d, y):
    """Rebuild text and labels by walking ``d`` alongside ``x`` and ``y``.

    For each item of ``d`` a cursor advances through ``x``/``y``. When ``x``
    has a space followed by one of ``, . ? '`` at the cursor and the current
    item of ``d`` is that same delimiter, the two-character " <delim>" pair
    is emitted and the cursor skips both characters; otherwise the single
    character of ``x`` at the cursor is emitted.
    (Presumably ``d`` is a decoded copy of ``x`` in which the space before
    punctuation was dropped — confirm with the caller.)

    Args:
        x: Source text, indexed by the cursor.
        d: Decoded sequence, compared item by item against the delimiters.
        y: Labels indexed in parallel with ``x``.

    Returns:
        Tuple ``(clean_text, clean_label)``: the rebuilt string and the list
        of labels collected for it.
    """
    pieces = []
    labels = []
    pos = 0  # cursor into x and y
    for ch in d:
        matched = False
        for delim in (',', '.', '?', "'"):
            if x[pos:pos + 2] == " " + delim and ch == delim:
                matched = True
                pieces.append(" " + delim)
                # Both emitted characters receive the label found at the
                # space's position.
                labels.extend((y[pos], y[pos]))
                # Extra step past the space; the shared step below then
                # moves past the delimiter itself.
                pos += 1
        if not matched:
            pieces.append(x[pos])
            labels.append(y[pos])
        pos += 1

    text = "".join(pieces)
    # If the rebuilt text differs from x and x ends with a newline, append
    # the newline back.
    # NOTE(review): two 0 labels are added for this single character —
    # confirm that is intended.
    if text != x and x[-1:] == "\n":
        text += "\n"
        labels.extend((0, 0))
    return text, labels
def clean_entity(t):
    """Lower-case ``t`` and replace its newlines with single spaces.

    A space immediately followed by a newline is replaced by one space
    first; every remaining newline is then replaced by a space.

    Args:
        t: Text to normalise.

    Returns:
        The normalised string.
    """
    return t.lower().replace(' \n', " ").replace('\n', " ")