import os
import torch
import numpy as np
from tqdm import tqdm

from util import extraction
def inference_sample(model, tok, request, tok_type='subject_final', return_logits=False):
    """Single-token inference for a single sample.

    Returns the argmax token id (and its decoded string) predicted at the
    subject's lookup position; optionally also returns the raw logits there.
    """
    if isinstance(request, str):
        request = {'prompt': '{}', 'subject': request}
    all_prompts = [request["prompt"]]

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        extraction.find_token_index(
            tok, prompt, request["subject"], tok_type, verbose=False
        )
        for prompt in all_prompts
    ]

    # Tokenize the prompt with the subject filled in
    input_tok = tok(
        [prompt.format(request["subject"]) for prompt in all_prompts],
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # inference
    logits = model(**input_tok).logits

    # logits at the lookup position and the greedy (argmax) prediction there
    located_logit = logits[0][lookup_idxs[0]]
    output_token = torch.argmax(located_logit)
    output_decoded = tok.decode(output_token)
    output_token = output_token.detach().cpu().item()

    if return_logits:
        return output_token, output_decoded, located_logit.detach().cpu().numpy()
    return output_token, output_decoded
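
# Example usage sketch for inference_sample: assumes a CUDA device and the
# Hugging Face `transformers` package; the model name "gpt2" and the example
# request below are illustrative only.
def _example_inference_sample():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default

    request = {"prompt": "The Eiffel Tower is located in the city of {}", "subject": "Paris"}
    token_id, decoded = inference_sample(model, tok, request)
    print(token_id, decoded)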
def perform_inference(
    model,
    tok,
    requests,
    additional_context=None,
    verbose=1
):
    """Run inference_sample over a list of requests and return the predicted token ids."""
    output_tokens = []
    disable_tqdm = (verbose == 0)
    for i in tqdm(range(len(requests)), disable=disable_tqdm):
        # copy the request so the caller's dict is not mutated when adding context
        request = dict(requests[i])
        if additional_context is not None:
            request["prompt"] = additional_context.format(request['prompt'])
        output_token, _ = inference_sample(model, tok, request)
        output_tokens.append(output_token)
    output_tokens = np.array(output_tokens)
    return output_tokens
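
# Example usage sketch for perform_inference: assumes `model` and `tok` are a
# causal LM / tokenizer pair already loaded (e.g. via transformers) and on CUDA,
# and that each prompt contains a '{}' placeholder for the subject. The requests
# below are illustrative only.
def _example_perform_inference(model, tok):
    requests = [
        {"prompt": "{} is the capital of", "subject": "Paris"},
        {"prompt": "{} plays the sport of", "subject": "LeBron James"},
    ]
    tokens = perform_inference(model, tok, requests)
    return tokens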
def inference_batch(
    model,
    tok,
    all_subjects,
    all_prompts,
    batch_size=256,
    additional_context=None,
    return_logits=False,
    disable_tqdms=False
):
    """Batched single-token inference over many (subject, prompt) pairs."""
    from util import nethook

    # find total number of batches
    num_batches = int(np.ceil(len(all_prompts) / batch_size))

    # broadcast a single subject string across all prompts
    if isinstance(all_subjects, str):
        all_subjects = [all_subjects] * len(all_prompts)
    all_prompts = list(all_prompts)
    all_subjects = list(all_subjects)

    final_tokens = []
    final_logits = []

    if not disable_tqdms and (additional_context is not None):
        print('Adding context: ', additional_context)

    model.eval()
    nethook.set_requires_grad(False, model)

    with torch.no_grad():
        for i in tqdm(range(num_batches), disable=disable_tqdms):
            # find batch prompts and subjects
            prompts = all_prompts[i * batch_size:(i + 1) * batch_size]
            subjects = all_subjects[i * batch_size:(i + 1) * batch_size]

            # add additional context if required
            if additional_context is not None:
                if '{}' in additional_context:
                    prompts = [additional_context.format(prompt) for prompt in prompts]
                else:
                    prompts = [additional_context + prompt for prompt in prompts]

            # embed text into tokens
            input_tok = tok(
                [prompt.format(subject) for prompt, subject in zip(prompts, subjects)],
                return_tensors="pt",
                padding=True,
            ).to("cuda")

            # model inference for batch
            logits = model(**input_tok).logits
            logits = logits.detach().cpu().numpy()

            # position of the last real (non-padding) token in each row;
            # the logits there predict the first generated token
            indices = extraction.find_last_one_in_each_row(input_tok['attention_mask'].cpu().numpy())

            # greedy (argmax) prediction for each row in the batch
            final_toks = [np.argmax(logits[j][indices[j]]) for j in range(len(indices))]
            if return_logits:
                final_ls = [logits[j][indices[j]] for j in range(len(indices))]

            final_tokens = final_tokens + final_toks
            if return_logits:
                final_logits = final_logits + final_ls

            del input_tok
            del logits

    final_tokens = np.array(final_tokens)
    if return_logits:
        final_logits = np.array(final_logits)
        return final_tokens, final_logits
    return final_tokens
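
# Example usage sketch for inference_batch: a single subject string is broadcast
# across all prompts and predictions are made in batches. Assumes `model` and
# `tok` are already loaded and on CUDA; the subject, prompts, and batch size
# below are illustrative only.
def _example_inference_batch(model, tok):
    prompts = ["{} was born in", "{} works in the field of"]
    tokens, logits = inference_batch(
        model, tok,
        all_subjects="Albert Einstein",  # broadcast to every prompt
        all_prompts=prompts,
        batch_size=2,
        return_logits=True,
    )
    # decode the greedy predictions for inspection
    print([tok.decode(int(t)) for t in tokens])
    return tokens, logits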