import unicodedata
from typing import List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .logit_lens import LogitLens


def generate_interactive(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    top_k: int = 5,
    max_out_len: int = 200,
    compare_against: Optional[AutoModelForCausalLM] = None,
    use_logit_lens: bool = False,
    layer_module_tmp: str = "transformer.h.{}",
    ln_f_module: str = "transformer.ln_f",
    lm_head_module: str = "lm_head",
):
    """
    Runs generation in a loop, allowing the user to repeatedly enter prompts
    from which text is generated.
    """

    if use_logit_lens:
        llens_gen = LogitLens(
            model,
            tok,
            layer_module_tmp,
            ln_f_module,
            lm_head_module,
            disabled=not use_logit_lens,
        )
        if compare_against:
            llens_vanilla = LogitLens(
                compare_against,
                tok,
                layer_module_tmp,
                ln_f_module,
                lm_head_module,
                disabled=not use_logit_lens,
            )

    while True:
        prompt = input("Enter a prompt: ").strip(" \r\t\n")

        print(
            f"Argument Model: "
            f"{generate_fast(model, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
        )
        if compare_against:
            print(
                f"Baseline Model: "
                f"{generate_fast(compare_against, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
            )

        if use_logit_lens:
            inp_prompt = tok([prompt], padding=True, return_tensors="pt").to(
                next(model.parameters()).device
            )

            with llens_gen:
                model(**inp_prompt)
            print("\n--- Argument Model Logit Lens ---")
            llens_gen.pprint()

            if compare_against:
                with llens_vanilla:
                    compare_against(**inp_prompt)
                print("--- Baseline Model Logit Lens ---")
                llens_vanilla.pprint()

        print()


def generate_fast(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    n_gen_per_prompt: int = 1,
    top_k: int = 5,
    max_out_len: int = 200,
    vanilla_generation: bool = False,
):
    """
    Fast, parallelized auto-regressive text generation with top-k sampling.
    Our custom implementation.
    """

    # Repeat each prompt n_gen_per_prompt times so all generations run in one batch.
    inp = [prompt for prompt in prompts for _ in range(n_gen_per_prompt)]
    inp_tok = tok(inp, padding=True, return_tensors="pt").to(
        next(model.parameters()).device
    )
    input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"]
    if vanilla_generation:
        gen_txt = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_out_len,
        )
        txt = [tok.decode(x, skip_special_tokens=True) for x in gen_txt.detach().cpu().numpy().tolist()]
        txt = [
            unicodedata.normalize("NFKD", x)
            .replace("\n\n", " ")
            .replace("<|endoftext|>", "")
            for x in txt
        ]
        return txt

    batch_size = input_ids.size(0)

    # Set up storage for fast generation with attention caches.
    # `cur_context` is the range of positions that have not yet been stored in
    # `past_key_values`; at each step we generate the token that follows it.
    past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item())

    with torch.no_grad():
        while input_ids.size(1) < max_out_len:  # stop once the max output length is reached
            model_out = model(
                input_ids=input_ids[:, cur_context],
                # LLaMA/Baichuan-style models are called without an explicit attention mask.
                attention_mask=(
                    None
                    if any(name in model.name_or_path.lower() for name in ("llama", "baichuan"))
                    else attention_mask[:, cur_context]
                ),
                past_key_values=past_key_values,
                use_cache=True,
            )
            logits, past_key_values = model_out.logits, model_out.past_key_values
            softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)

            # Top-k sampling: keep the k most probable next tokens and renormalize.
            tk = torch.topk(softmax_out, top_k, dim=1).indices
            softmax_out_top_k = torch.gather(softmax_out, 1, tk)
            softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None]
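            # For example, with top_k=2 and next-token probabilities (0.5, 0.3, 0.2),
            # the kept pair (0.5, 0.3) is rescaled to (0.625, 0.375) before sampling.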
            new_tok_indices = torch.multinomial(softmax_out_top_k, 1)
            new_toks = torch.gather(tk, 1, new_tok_indices)

            # If generation has reached the last position of `input_ids`, grow
            # `input_ids` and `attention_mask` by one column so the new tokens
            # can be written in.
            if cur_context.stop == input_ids.size(1):
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1
                )
                input_ids = torch.cat(
                    [
                        input_ids,
                        input_ids.new_ones(batch_size, 1) * tok.pad_token_id,
                    ],
                    dim=1,
                )

            last_non_masked = attention_mask.sum(1) - 1
            for i in range(batch_size):
                new_idx = last_non_masked[i] + 1
                # Only sequences whose next free position matches the current
                # context receive a sampled token this step; longer prompts are
                # still being consumed.
                if last_non_masked[i].item() + 1 != cur_context.stop:
                    continue

                # Stop writing once this sequence has reached the maximum output length.
                if new_idx < max_out_len:
                    input_ids[i][new_idx] = new_toks[i]
                    attention_mask[i][new_idx] = 1

            # After the first pass, only the newly appended position needs to be
            # fed to the model; earlier positions are served from `past_key_values`.
            cur_context = slice(cur_context.stop, cur_context.stop + 1)

    txt = [tok.decode(x, skip_special_tokens=True) for x in input_ids.detach().cpu().numpy().tolist()]
    txt = [
        unicodedata.normalize("NFKD", x)
        .replace("\n\n", " ")
        .replace("<|endoftext|>", "")
        for x in txt
    ]

    return txt
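

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; the "gpt2" checkpoint and the prompt are
    # assumptions). Because of the relative import above, run this as a module from
    # the package root, e.g. `python -m <package>.generate`.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tok = AutoTokenizer.from_pretrained("gpt2")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token  # generate_fast relies on tok.pad_token_id
    print(
        generate_fast(
            model,
            tok,
            ["The Eiffel Tower is located in"],
            n_gen_per_prompt=1,
            top_k=5,
            max_out_len=50,
        )
    )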