import os
import torch
import numpy as np
from tqdm import tqdm
from util import extraction


def inference_sample(model, tok, request, tok_type='subject_final', return_logits=False):
    """Single-token inference for a single sample."""
    if isinstance(request, str):
        request = {'prompt': '{}', 'subject': request}
    all_prompts = [request["prompt"]]

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        extraction.find_token_index(
            tok, prompt, request["subject"], tok_type, verbose=False
        )
        for prompt in all_prompts
    ]

    input_tok = tok(
        [prompt.format(request["subject"]) for prompt in all_prompts],
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # inference
    logits = model(**input_tok).logits

    # logits at the lookup position and the corresponding argmax token
    located_logit = logits[0][lookup_idxs[0]]
    output_token = torch.argmax(located_logit)
    output_decoded = tok.decode(output_token)
    output_token = output_token.detach().cpu().item()

    if return_logits:
        return output_token, output_decoded, located_logit.detach().cpu().numpy()
    return output_token, output_decoded
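

# Usage sketch (not part of the original file): `model` and `tok` are assumed to be a
# CUDA-resident causal LM and its tokenizer (e.g. from Hugging Face transformers) with
# a pad token set, since the tokenization above uses padding=True.
def _example_inference_sample(model, tok):
    # A dict request supplies a '{}' prompt template plus the subject to fill in.
    request = {"prompt": "{} is the capital of", "subject": "Paris"}
    token_id, token_str = inference_sample(model, tok, request)
    # A bare string is treated as the subject with the trivial '{}' prompt; with
    # return_logits=True the raw logit vector at the lookup position is also returned.
    token_id, token_str, logit_vec = inference_sample(model, tok, "Paris", return_logits=True)
    return token_id, token_str, logit_vec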


def perform_inference(
    model,
    tok,
    requests,
    additional_context=None,
    verbose=1,
):
    """Run inference_sample over a list of requests and collect predicted token ids."""
    output_tokens = []
    disable_tqdm = (verbose == 0)
    for i in tqdm(range(len(requests)), disable=disable_tqdm):
        request = requests[i]
        if additional_context is not None:
            # Note: the request dict is modified in place here.
            request["prompt"] = additional_context.format(request['prompt'])
        output_token, _ = inference_sample(model, tok, request)
        output_tokens.append(output_token)
    return np.array(output_tokens)
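

# Usage sketch (not part of the original file): drives perform_inference over a small
# list of requests; the optional additional_context is a '{}' template that wraps each
# prompt template before inference.
def _example_perform_inference(model, tok):
    requests = [
        {"prompt": "{} plays the sport of", "subject": "LeBron James"},
        {"prompt": "{} is the capital of", "subject": "Paris"},
    ]
    token_ids = perform_inference(model, tok, requests, additional_context="Fact: {}", verbose=0)
    return token_ids  # np.ndarray with one predicted token id per request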


def inference_batch(
    model,
    tok,
    all_subjects,
    all_prompts,
    batch_size=256,
    additional_context=None,
    return_logits=False,
    disable_tqdms=False,
):
    """Batched single-token inference over paired (subject, prompt) lists."""
    from util import nethook

    # find total number of batches
    num_batches = int(np.ceil(len(all_prompts) / batch_size))

    # a single subject string is broadcast across all prompts
    if isinstance(all_subjects, str):
        all_subjects = [all_subjects] * len(all_prompts)
    all_prompts = list(all_prompts)
    all_subjects = list(all_subjects)

    final_tokens = []
    final_logits = []

    if not disable_tqdms and (additional_context is not None):
        print('Adding context: ', additional_context)

    model.eval()
    nethook.set_requires_grad(False, model)
    with torch.no_grad():
        for i in tqdm(range(num_batches), disable=disable_tqdms):
            # find batch prompts and subjects
            prompts = all_prompts[i * batch_size:(i + 1) * batch_size]
            subjects = all_subjects[i * batch_size:(i + 1) * batch_size]

            # add additional context if required
            if additional_context is not None:
                if '{}' in additional_context:
                    prompts = [additional_context.format(prompt) for prompt in prompts]
                else:
                    prompts = [additional_context + prompt for prompt in prompts]

            # embed text into tokens
            input_tok = tok(
                [prompt.format(subject) for prompt, subject in zip(prompts, subjects)],
                return_tensors="pt",
                padding=True,
            ).to("cuda")

            # model inference for batch
            logits = model(**input_tok).logits
            logits = logits.detach().cpu().numpy()

            # position of the last attended (non-pad) token in each row;
            # the logits there give the model's next-token prediction
            indices = extraction.find_last_one_in_each_row(input_tok['attention_mask'].cpu().numpy())

            # argmax token (and optionally the full logit vector) at that position
            final_toks = [np.argmax(logits[j][indices[j]]) for j in range(len(indices))]
            final_tokens = final_tokens + final_toks
            if return_logits:
                final_ls = [logits[j][indices[j]] for j in range(len(indices))]
                final_logits = final_logits + final_ls

            del input_tok
            del logits

    final_tokens = np.array(final_tokens)
    if return_logits:
        final_logits = np.array(final_logits)
        return final_tokens, final_logits
    return final_tokens
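

# Usage sketch (not part of the original file): batched inference over paired prompt
# templates and subjects. A plain-string additional_context would simply be prepended
# to each prompt, while a '{}' context wraps it as in perform_inference.
def _example_inference_batch(model, tok):
    prompts = ["The capital of {} is", "The native language of {} is"]
    subjects = ["France", "Michel Foucault"]
    token_ids = inference_batch(
        model, tok, subjects, prompts, batch_size=2, disable_tqdms=True
    )
    token_ids, logit_rows = inference_batch(
        model, tok, subjects, prompts, batch_size=2, return_logits=True, disable_tqdms=True
    )
    return token_ids, logit_rows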