import numpy as np
import pandas as pd
import torch

from dataloader import create_dataloader
from utils import *


def predict_segmentation(inp, model, device, batch_size=8):
    """Run the segmentation model over pre-chunked inputs and return per-token probabilities."""
    test_loader = create_dataloader(inp, batch_size)

    predictions = []
    with torch.no_grad():  # inference only, no gradients needed
        for batch in test_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # sigmoid over the token logits gives per-token segmentation probabilities
            p = torch.sigmoid(model(**batch).logits).cpu().numpy()
            predictions.append(p)

    return np.concatenate(predictions, axis=0)


def create_data(text, tokenizer, seq_len=512):
    """Tokenize text and split it into fixed-length, padded chunks of seq_len tokens."""
    tokens = tokenizer(text, add_special_tokens=False)
    # split every tokenizer output field (input_ids, attention_mask, ...) into
    # seq_len-sized chunks, padding the final chunk to the full length
    _token_batches = {k: [pad_seq(x, seq_len) for x in batch_list(v, seq_len)]
                      for (k, v) in tokens.items()}
    n_batches = len(_token_batches['input_ids'])
    # re-group the fields so each element is one model-ready example
    return [{k: v[i] for k, v in _token_batches.items()}
            for i in range(n_batches)]


def segment_tokens(notes, model, tokenizer, device, batch_size=8):
    """Return a mapping of note_id -> per-token probability array for each note."""
    predictions = {}
    for note in notes.itertuples():
        note_id = note.note_id
        raw_text = note.text.lower()

        # chunk the note into fixed-length token sequences and score each chunk
        inp = create_data(raw_text, tokenizer)
        pred_probs = predict_segmentation(inp, model, device, batch_size=batch_size)
        # drop the trailing label dimension and stitch the chunks back together
        pred_probs = np.squeeze(pred_probs, -1)
        pred_probs = np.concatenate(pred_probs)

        predictions[note_id] = pred_probs

    return predictions


def segment(notes, model, tokenizer, device, thresh, batch_size=8):
    """Predict mention spans for every note and return them as a single DataFrame."""
    predictions = []

    predictions_prob_map = segment_tokens(notes, model, tokenizer, device, batch_size)

    for note in notes.itertuples():

        note_id = note.note_id
        raw_text = note.text

        # round-trip through the tokenizer so token-level scores can be aligned
        # back to character offsets in the original text
        decoded_text = tokenizer.decode(tokenizer.encode(raw_text, add_special_tokens=False))

        pred_probs = predictions_prob_map[note_id]

        _, pred_probs = align_decoded(raw_text, decoded_text, pred_probs)
        pred_probs = np.array(pred_probs, 'float32')
        # binarize the aligned probabilities at the given threshold
        pred = (pred_probs > thresh).astype('uint8')

        # contiguous runs of positive positions become candidate mention spans
        spans = get_sequential_spans(pred)

        note_predictions = {'note_id': [], 'start': [], 'end': [], 'mention': [], 'score': []}
        for (start, end) in spans:
            note_predictions['note_id'].append(note_id)
            note_predictions['score'].append(pred_probs[start:end].mean())
            note_predictions['start'].append(start)
            note_predictions['end'].append(end)
            note_predictions['mention'].append(raw_text[start:end])

        note_predictions = pd.DataFrame(note_predictions)
        note_predictions = note_predictions.sort_values('score', ascending=False)

        # remove overlapping spans, keeping the higher-scoring span of any overlap
        seen_spans = set()
        keep_mask = []
        for span in note_predictions[['start', 'end']].values:
            span = tuple(span)
            keep = False
            if not is_overlap(seen_spans, span):
                seen_spans.add(span)
                keep = True
            keep_mask.append(keep)
        note_predictions = note_predictions[keep_mask]

        predictions.append(note_predictions)

    predictions = pd.concat(predictions).reset_index(drop=True)
    return predictions
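

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline). It assumes the
# checkpoint is a Hugging Face token-classification model with a single output
# label (so the logits squeeze to one score per token, as expected above) and
# that `notes` is a DataFrame with `note_id` and `text` columns. The checkpoint
# path, threshold, and example note below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')
    model = AutoModelForTokenClassification.from_pretrained('path/to/checkpoint', num_labels=1)
    model = model.to(device).eval()

    notes = pd.DataFrame({'note_id': [0], 'text': ['Patient was started on 5 mg lisinopril daily.']})
    mention_spans = segment(notes, model, tokenizer, device, thresh=0.5, batch_size=8)
    print(mention_spans)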