# caif/sampling.py

import torch
from torch.nn import functional as F
import transformers


def sample_from_values(unscaled_probs, values):
    """Sample one index per row from `unscaled_probs` and return the matching entry of `values`."""
    # torch.multinomial normalizes non-negative weights internally, so the
    # inputs do not need to sum to 1.
    samples = torch.multinomial(unscaled_probs, 1)
    return torch.take_along_dim(values, samples, dim=1)


class TopKWithTemperatureSampler:
    """Plain top-k sampling with temperature, used as a baseline."""

    def __call__(self, input_ids, output_logits, top_k, temperature, **kwargs):
        # Logits for the next-token position only.
        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        # Keep the k most likely tokens, rescale by temperature, and sample.
        topk_log_probs = next_token_log_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            torch.exp(topk_log_probs[0] / temperature), topk_log_probs[1]
        ).squeeze(1)
        return next_tokens
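

# A minimal usage sketch, not part of the original file: how the sampler above
# can be driven by a Hugging Face causal LM. The "gpt2" checkpoint, the prompt,
# and the helper name `_topk_sampler_demo` are illustrative assumptions.
def _topk_sampler_demo():
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    input_ids = tokenizer("The movie was", return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(input_ids).logits
    sampler = TopKWithTemperatureSampler()
    next_tokens = sampler(input_ids, logits, top_k=50, temperature=0.9)
    print(tokenizer.decode(next_tokens))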


class CAIFSampler:
    """Classifier-guided (CAIF) sampling: reweights the LM's top next-token
    candidates with an external attribute classifier before sampling."""

    def __init__(self, classifier_name, lm_tokenizer, device, invert_cls_probs: bool = False):
        self.device = device
        self.classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(
            classifier_name
        )
        self.classifier_model = (
            transformers.AutoModelForSequenceClassification.from_pretrained(
                classifier_name
            ).to(device)
        )
        self.classifier_model.eval()
        self.lm_tokenizer = lm_tokenizer
        # If True, steer away from the attribute by using log(1 - p) instead of log(p).
        self.invert_cls_probs = invert_cls_probs

    def __call__(
        self,
        input_ids,
        output_logits,
        top_k,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None,
        **kwargs
    ):
        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        # Classifier-reweighted (unnormalized) probabilities over the
        # `top_k_classifier` candidate tokens.
        next_token_unnormalized_probs, topk_indices = self.get_unnormalized_probs(
            input_ids,
            next_token_log_probs,
            temperature,
            top_k_classifier,
            classifier_weight,
            caif_tokens_num=caif_tokens_num,
        )
        # Narrow further to `top_k` and sample; map the sampled positions back
        # to vocabulary ids.
        topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            topk_probs[0],
            torch.take_along_dim(topk_indices, topk_probs[1], dim=1),
        ).squeeze(1)
        return next_tokens

    def get_unnormalized_probs(
        self,
        input_ids,
        next_token_log_probs,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None,
    ):
        if classifier_weight == 0.0:
            raise ValueError(
                "classifier weight equal to 0 is not supported for CAIF sampling"
            )
        top_next_token_log_probs = next_token_log_probs.topk(top_k_classifier, -1)
        # Build one candidate sequence per (prefix, candidate-token) pair:
        # repeat each prefix `top_k_classifier` times and append a candidate id.
        classifier_input = torch.cat(
            [
                input_ids.unsqueeze(1).repeat(1, top_k_classifier, 1).flatten(0, 1),
                top_next_token_log_probs[1].view(-1).unsqueeze(-1),
            ],
            -1,
        )
        # The classifier has its own tokenizer, so round-trip through text.
        classifier_input = [
            self.lm_tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in classifier_input
        ]
        if self.invert_cls_probs:
            classifier_log_probs = torch.log(
                1 - self.get_classifier_probs(
                    classifier_input, caif_tokens_num=caif_tokens_num
                ).view(-1, top_k_classifier)
            )
        else:
            classifier_log_probs = self.get_classifier_log_probs(
                classifier_input, caif_tokens_num=caif_tokens_num
            ).view(-1, top_k_classifier)
        # Shift the LM log-probs by the weighted classifier log-probs and apply
        # temperature; normalization happens downstream in torch.multinomial.
        next_token_probs = torch.exp(
            (top_next_token_log_probs[0] + classifier_weight * classifier_log_probs)
            / temperature
        )
        return next_token_probs, top_next_token_log_probs[1]
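
    # The reweighting in get_unnormalized_probs corresponds to sampling from
    #     p(x_t | x_<t) ∝ exp((log p_LM(x_t | x_<t) + w * log p_cls(c | x_<=t)) / T)
    # restricted to the top-`top_k_classifier` candidates, where w is
    # `classifier_weight`, T is `temperature`, and c is the target attribute.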

    def get_classifier_log_probs(self, texts, caif_tokens_num=None):
        encoded = self.classifier_tokenizer(
            texts, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Keep only the last `caif_tokens_num` tokens of each sequence.
            encoded["input_ids"] = encoded["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in encoded.keys():
                encoded["attention_mask"] = encoded["attention_mask"][:, -caif_tokens_num:]
            if "token_type_ids" in encoded.keys():
                encoded["token_type_ids"] = encoded["token_type_ids"][:, -caif_tokens_num:]
        logits = self.classifier_model(**encoded).logits[:, 0].squeeze(-1)
        # F.logsigmoid is the numerically stable form of log(sigmoid(x)).
        return F.logsigmoid(logits)

    def get_classifier_probs(self, texts, caif_tokens_num=None):
        encoded = self.classifier_tokenizer(
            texts, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Keep only the last `caif_tokens_num` tokens of each sequence
            # (slice the sequence dimension, not the batch dimension).
            encoded["input_ids"] = encoded["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in encoded.keys():
                encoded["attention_mask"] = encoded["attention_mask"][:, -caif_tokens_num:]
            if "token_type_ids" in encoded.keys():
                encoded["token_type_ids"] = encoded["token_type_ids"][:, -caif_tokens_num:]
        logits = self.classifier_model(**encoded).logits[:, 0].squeeze(-1)
        return torch.sigmoid(logits)
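

# A minimal end-to-end sketch, not part of the original file. The "gpt2" and
# "distilbert-base-uncased-finetuned-sst-2-english" checkpoints, the prompt,
# and the hyperparameter values are illustrative assumptions; CAIFSampler only
# needs a causal LM plus a sequence-classification head whose first logit
# scores the attribute being steered toward.
def _caif_sampler_demo():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    lm = transformers.AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    sampler = CAIFSampler(
        classifier_name="distilbert-base-uncased-finetuned-sst-2-english",
        lm_tokenizer=lm_tokenizer,
        device=device,
    )
    input_ids = lm_tokenizer("The movie was", return_tensors="pt").input_ids.to(device)
    # Decoding loop: append one CAIF-sampled token per step.
    for _ in range(20):
        with torch.no_grad():
            logits = lm(input_ids).logits
        next_tokens = sampler(
            input_ids,
            logits,
            top_k=50,
            temperature=1.0,
            top_k_classifier=100,
            classifier_weight=5.0,
        )
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
    print(lm_tokenizer.decode(input_ids[0], skip_special_tokens=True))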