import torch
import numpy as np

from tqdm import tqdm

from util import extraction


def inference_sample(model, tok, request, tok_type='subject_final', return_logits=False):
    """Predict a single token for a single request.

    The prediction is read from the logits at the position selected by
    `tok_type` (by default, the final subject token).
    """
    # a bare string is treated as the subject with an identity prompt
    if isinstance(request, str):
        request = {'prompt': '{}', 'subject': request}

    all_prompts = [request["prompt"]]

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        extraction.find_token_index(
            tok, prompt, request["subject"], tok_type, verbose=False
        )
        for prompt in all_prompts
    ]
    input_tok = tok(
        [prompt.format(request["subject"]) for prompt in all_prompts],
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # forward pass
    logits = model(**input_tok).logits

    # logits at the lookup position (e.g. the final subject token)
    located_logit = logits[0][lookup_idxs[0]]

    # greedy prediction at that position
    output_token = torch.argmax(located_logit)
    output_decoded = tok.decode(output_token)
    output_token = output_token.detach().cpu().item()

    if return_logits:
        return output_token, output_decoded, located_logit.detach().cpu().numpy()
    return output_token, output_decoded

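
# Example usage of inference_sample (a minimal sketch: the model name, prompt,
# and subject below are illustrative placeholders, not values from this repo):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token = tok.eos_token  # tok(..., padding=True) above needs a pad token
#   request = {"prompt": "The mother tongue of {}", "subject": "Danielle Darrieux"}
#   token_id, token_str = inference_sample(model, tok, request)
#   # token_id / token_str are the greedy prediction at the lookup position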

def perform_inference(
        model,
        tok,
        requests,
        additional_context=None,
        verbose=1
    ):
    """Run inference_sample over a list of requests and return the predicted
    token ids as a numpy array.
    """
    output_tokens = []

    disable_tqdm = (verbose == 0)

    for request in tqdm(requests, disable=disable_tqdm):

        if additional_context is not None:
            # avoid mutating the caller's request dicts
            request = dict(request)
            # '{}' in the context marks where the prompt goes; otherwise the
            # context is simply prepended (same convention as inference_batch)
            if '{}' in additional_context:
                request["prompt"] = additional_context.format(request["prompt"])
            else:
                request["prompt"] = additional_context + request["prompt"]

        output_token, _ = inference_sample(model, tok, request)
        output_tokens.append(output_token)

    return np.array(output_tokens)

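
# Example usage of perform_inference (a sketch; the requests and context below
# are placeholders, and model/tok are a loaded HuggingFace model and tokenizer):
#
#   requests = [
#       {"prompt": "The capital of {} is", "subject": "France"},
#       {"prompt": "{} plays the sport of", "subject": "Lionel Messi"},
#   ]
#   token_ids = perform_inference(model, tok, requests,
#                                 additional_context="Answer factually. {}")
#   # token_ids is a numpy array with one predicted token id per request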

def inference_batch(
        model,
        tok,
        all_subjects,
        all_prompts,
        batch_size=256,
        additional_context=None,
        return_logits=False,
        disable_tqdms=False
    ):
    """Batched greedy next-token prediction for (subject, prompt) pairs.

    Returns the predicted token id at the last non-padding position of each
    prompt, and optionally the corresponding logits.
    """
    from util import nethook

    # total number of batches
    num_batches = int(np.ceil(len(all_prompts) / batch_size))

    # a single subject string is broadcast across all prompts
    if isinstance(all_subjects, str):
        all_subjects = [all_subjects] * len(all_prompts)

    all_prompts = list(all_prompts)
    all_subjects = list(all_subjects)

    final_tokens = []
    final_logits = []

    if not disable_tqdms and (additional_context is not None):
        print('Adding context: ', additional_context)

    model.eval()
    nethook.set_requires_grad(False, model)

    with torch.no_grad():
        for i in tqdm(range(num_batches), disable=disable_tqdms):

            # find batch prompts and subjects
            prompts = all_prompts[i*batch_size:(i+1)*batch_size]
            subjects = all_subjects[i*batch_size:(i+1)*batch_size]

            # add additional context if required
            if additional_context is not None:
                if '{}' in additional_context:
                    prompts = [additional_context.format(prompt) for prompt in prompts]
                else:
                    prompts = [additional_context + prompt for prompt in prompts]

            # embed text into tokens
            input_tok = tok(
                [prompt.format(subject) for prompt, subject in zip(prompts, subjects)],
                return_tensors="pt",
                padding=True,
            ).to("cuda")

            # model inference for batch
            logits = model(**input_tok).logits
            logits = logits.detach().cpu().numpy()

            # index of the last non-padding token in each row; the logits at
            # this position give the model's next-token prediction
            indices = extraction.find_last_one_in_each_row(input_tok['attention_mask'].cpu().numpy())

            # greedy prediction at that position for each prompt in the batch
            final_toks = [np.argmax(logits[j][indices[j]]) for j in range(len(indices))]
            final_tokens.extend(final_toks)

            if return_logits:
                final_logits.extend(logits[j][indices[j]] for j in range(len(indices)))

            del input_tok
            del logits

    final_tokens = np.array(final_tokens)
    if return_logits:
        final_logits = np.array(final_logits)
        return final_tokens, final_logits
    return final_tokens
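

if __name__ == "__main__":
    # Minimal smoke test for inference_batch (a sketch, not part of the module's
    # API): the model name, prompts, and subjects below are placeholders, and a
    # CUDA device plus the `transformers` package are assumed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # needed because tok(..., padding=True) is used

    subjects = ["France", "Japan", "Brazil"]
    prompts = ["The capital of {} is"] * len(subjects)

    tokens, logits = inference_batch(
        model, tok, subjects, prompts,
        batch_size=2, return_logits=True, disable_tqdms=True,
    )
    for subject, token_id in zip(subjects, tokens):
        print(subject, "->", repr(tok.decode([int(token_id)])))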