import torch
from torch.nn import functional as F

import transformers

from utils import get_cls


def sample_from_values(unscaled_probs, values):
    # Draw one sample per row from the unnormalized weights `unscaled_probs`,
    # then map each sampled column index back to the corresponding entry of
    # `values` (here: candidate token ids).
    samples = torch.multinomial(unscaled_probs, 1)
    return torch.take_along_dim(values, samples, dim=1)
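
# For example (tensors and numbers below are illustrative, not from the
# original module):
#
#   weights = torch.tensor([[0.7, 0.3]])  # unnormalized sampling weights
#   values = torch.tensor([[42, 7]])      # candidate token ids
#   sample_from_values(weights, values)   # -> tensor([[42]]) or tensor([[7]])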


class TopKWithTemperatureSampler:
    # Plain top-k sampling with temperature scaling.

    def __call__(self, input_ids, output_logits, top_k, temperature, **kwargs):
        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)

        topk_log_probs = next_token_log_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            torch.exp(topk_log_probs[0] / temperature), topk_log_probs[1]
        ).squeeze(1)

        return next_tokens
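
    # Note: exp(log_softmax(z) / T) is proportional to exp(z / T), so this is
    # equivalent to sampling from softmax(logits / temperature) restricted to
    # the top_k candidates.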


class CAIFSampler:
    # Classifier-guided (CAIF) sampling: reweights the language model's top
    # candidate tokens with an external attribute classifier.

    def __init__(self, classifier_name, lm_tokenizer, device, invert_cls_probs: bool = False):
        self.device = device
        self.classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(
            classifier_name
        )
        self.classifier_model = get_cls(classifier_name).to(device)
        self.classifier_model.eval()
        self.lm_tokenizer = lm_tokenizer
        self.invert_cls_probs = invert_cls_probs

    def __call__(
        self,
        input_ids,
        output_logits,
        top_k,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None,
        act_type: str = "sigmoid",
        target_cls_id: int = 0,
        **kwargs
    ):
        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)

        next_token_unnormalized_probs, topk_indices = self.get_unnormalized_probs(
            input_ids,
            next_token_log_probs,
            temperature,
            top_k_classifier,
            classifier_weight,
            caif_tokens_num=caif_tokens_num,
            act_type=act_type,
            target_cls_id=target_cls_id,
        )
        # Restrict to the top-k reweighted candidates, sample one, and map the
        # sampled position back to a vocabulary id via the classifier top-k
        # indices.
        topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            topk_probs[0],
            torch.take_along_dim(topk_indices, topk_probs[1], dim=1),
        ).squeeze(1)

        return next_tokens
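
    # A rough sketch of how this sampler might sit in a generation loop (the
    # name `lm` and the hyperparameter values are illustrative, not part of
    # this module):
    #
    #   sampler = CAIFSampler(classifier_name, lm_tokenizer, device)
    #   input_ids = lm_tokenizer("The movie was", return_tensors="pt").input_ids
    #   for _ in range(20):
    #       logits = lm(input_ids).logits
    #       next_tokens = sampler(
    #           input_ids, logits, top_k=50, temperature=1.0,
    #           top_k_classifier=100, classifier_weight=5.0,
    #       )
    #       input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)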

    def get_unnormalized_probs(
        self,
        input_ids,
        next_token_log_probs,
        temperature,
        top_k_classifier,
        classifier_weight,
        target_cls_id: int = 0,
        act_type: str = "sigmoid",
        caif_tokens_num=None,
    ):
        if classifier_weight == 0.0:
            raise ValueError(
                "classifier weight equal to 0 is not supported for CAIF sampling"
            )

        top_next_token_log_probs = next_token_log_probs.topk(top_k_classifier, -1)
        # Append each top-k candidate token to its prompt so the classifier
        # can score every candidate continuation.
        classifier_input = torch.cat(
            [
                input_ids.unsqueeze(1).repeat(1, top_k_classifier, 1).flatten(0, 1),
                top_next_token_log_probs[1].view(-1).unsqueeze(-1),
            ],
            -1,
        )
        classifier_input = [
            self.lm_tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in classifier_input
        ]

        if self.invert_cls_probs:
            # Steer away from the target class by scoring its complement.
            classifier_log_probs = torch.log(
                1 - self.get_classifier_probs(
                    classifier_input,
                    caif_tokens_num=caif_tokens_num,
                    target_cls_id=target_cls_id,
                ).view(-1, top_k_classifier)
            )
        else:
            classifier_log_probs = self.get_classifier_log_probs(
                classifier_input,
                caif_tokens_num=caif_tokens_num,
                target_cls_id=target_cls_id,
                act_type=act_type,
            ).view(-1, top_k_classifier)

        # keepdim=True so the per-row means broadcast over the top-k dimension
        # (without it the subtraction fails for batch sizes > 1).
        next_token_probs = torch.exp(
            (
                top_next_token_log_probs[0]
                + classifier_weight
                * (classifier_log_probs - classifier_log_probs.mean(-1, keepdim=True))
                - top_next_token_log_probs[0].mean(-1, keepdim=True)
            )
            / temperature
        )
        return next_token_probs, top_next_token_log_probs[1]
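
    # Up to the per-row mean shifts (which cancel under normalization), the
    # weights above implement the CAIF reweighting
    #
    #     p(x_t) ∝ exp((log p_LM(x_t | x_<t)
    #                   + classifier_weight * log p_cls(target | x_<=t)) / temperature)
    #
    # i.e. the language model distribution tilted toward continuations the
    # attribute classifier scores highly.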

    def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0, act_type: str = "sigmoid"):
        input_ids = self.classifier_tokenizer(
            input, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Keep only the last `caif_tokens_num` tokens of every sequence.
            input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in input_ids.keys():
                input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
            if "token_type_ids" in input_ids.keys():
                input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]

        if act_type == "sigmoid":
            logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
            return F.logsigmoid(logits)
        if act_type == "softmax":
            return F.log_softmax(
                self.classifier_model(**input_ids).logits, dim=-1
            )[:, target_cls_id].squeeze(-1)
        raise ValueError(f"unknown act_type: {act_type}")

    def get_classifier_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
        input_ids = self.classifier_tokenizer(
            input, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Slice along the sequence dimension (dim 1), not the batch
            # dimension, mirroring get_classifier_log_probs.
            input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in input_ids.keys():
                input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
        logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
        return torch.sigmoid(logits)
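

if __name__ == "__main__":
    # Minimal smoke test for TopKWithTemperatureSampler (illustrative only,
    # not part of the original module); CAIFSampler is not exercised here
    # because it requires a real classifier checkpoint.
    torch.manual_seed(0)
    sampler = TopKWithTemperatureSampler()
    logits = torch.randn(2, 5, 100)  # (batch, seq_len, vocab_size)
    next_tokens = sampler(None, logits, top_k=10, temperature=0.8)
    assert next_tokens.shape == (2,)
    print(next_tokens)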