import torch
from torch.nn import functional as F
import transformers


def sample_from_values(unscaled_probs, values):
    # Draw one index per row from the (unnormalized) weights, then map it
    # back to the corresponding entry of `values`.
    samples = torch.multinomial(unscaled_probs, 1)
    return torch.take_along_dim(values, samples, dim=1)
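
# For example (shapes assumed for illustration: one row, two candidates),
# sampling with weights [0.1, 0.9] over values [10, 20] returns
# tensor([[20]]) about 90% of the time:
#     sample_from_values(torch.tensor([[0.1, 0.9]]), torch.tensor([[10, 20]]))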


class TopKWithTemperatureSampler:
    # Plain top-k sampling with temperature: keep the k most likely next
    # tokens and sample from their temperature-scaled distribution.
    def __call__(self, input_ids, output_logits, top_k, temperature, **kwargs):
        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        topk_log_probs = next_token_log_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            torch.exp(topk_log_probs[0] / temperature), topk_log_probs[1]
        ).squeeze(1)
        return next_tokens
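

# A minimal usage sketch (not part of the original module): driving the
# top-k sampler in an autoregressive generation loop. The "gpt2" checkpoint,
# the prompt, and the step count are illustrative assumptions.
def _example_top_k_generation():
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    sampler = TopKWithTemperatureSampler()
    input_ids = tokenizer("The weather today is", return_tensors="pt").input_ids
    with torch.no_grad():
        for _ in range(20):
            logits = model(input_ids).logits
            next_tokens = sampler(input_ids, logits, top_k=50, temperature=0.9)
            # Append the sampled token and continue from the longer prefix.
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
    return tokenizer.decode(input_ids[0])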


class CAIFSampler:
    # Classifier-guided (CAIF) sampling: an external sequence classifier
    # rescores the language model's top candidate continuations at each step.
    def __init__(self, classifier_name, lm_tokenizer, device, invert_cls_probs: bool = False):
        self.device = device
        self.classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(
            classifier_name
        )
        self.classifier_model = (
            transformers.AutoModelForSequenceClassification.from_pretrained(
                classifier_name
            ).to(device)
        )
        self.classifier_model.eval()
        self.lm_tokenizer = lm_tokenizer
        self.invert_cls_probs = invert_cls_probs

    def __call__(
        self,
        input_ids,
        output_logits,
        top_k,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None,
        act_type: str = "softmax",
        **kwargs
    ):
        target_cls_id = kwargs["target_cls_id"]
        next_token_logits = output_logits[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        # Rescore the LM's candidates with the classifier, then sample from
        # the top_k of the combined distribution.
        next_token_unnormalized_probs, topk_indices = self.get_unnormalized_probs(
            input_ids,
            next_token_log_probs,
            temperature,
            top_k_classifier,
            classifier_weight,
            caif_tokens_num=caif_tokens_num,
            target_cls_id=target_cls_id,
        )
        topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            topk_probs[0],
            torch.take_along_dim(topk_indices, topk_probs[1], dim=1),
        ).squeeze(1)
        return next_tokens

    def get_unnormalized_probs(
        self,
        input_ids,
        next_token_log_probs,
        temperature,
        top_k_classifier,
        classifier_weight,
        target_cls_id: int = 0,
        caif_tokens_num=None,
    ):
        if classifier_weight == 0.0:
            raise ValueError(
                "classifier weight equal to 0 is not supported for CAIF Sampling"
            )
        top_next_token_log_probs = next_token_log_probs.topk(top_k_classifier, -1)
        # Append each of the top_k_classifier candidate tokens to the prefix,
        # yielding top_k_classifier full sequences per batch element.
        classifier_input = torch.cat(
            [
                input_ids.unsqueeze(1).repeat(1, top_k_classifier, 1).flatten(0, 1),
                top_next_token_log_probs[1].view(-1).unsqueeze(-1),
            ],
            -1,
        )
        classifier_input = [
            self.lm_tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in classifier_input
        ]
        if self.invert_cls_probs:
            classifier_log_probs = torch.log(
                1 - self.get_classifier_probs(
                    classifier_input, caif_tokens_num=caif_tokens_num
                ).view(-1, top_k_classifier)
            )
        else:
            classifier_log_probs = self.get_classifier_log_probs(
                classifier_input,
                caif_tokens_num=caif_tokens_num,
                target_cls_id=target_cls_id,
            ).view(-1, top_k_classifier)
        # Combine LM and classifier scores in log space, then exponentiate
        # to get unnormalized sampling weights.
        next_token_probs = torch.exp(
            (top_next_token_log_probs[0] + classifier_weight * classifier_log_probs)
            / temperature
        )
        return next_token_probs, top_next_token_log_probs[1]

    def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
        input_ids = self.classifier_tokenizer(
            input, padding=True, return_tensors="pt"
        ).to(self.device)
        # Optionally keep only the last caif_tokens_num tokens of each sequence.
        if caif_tokens_num is not None:
            input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in input_ids.keys():
                input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
            if "token_type_ids" in input_ids.keys():
                input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
        logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
        return torch.log(torch.sigmoid(logits))

    def get_classifier_probs(self, input, caif_tokens_num=None):
        input_ids = self.classifier_tokenizer(
            input, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Slice the sequence dimension, as in get_classifier_log_probs
            # (the original sliced the batch dimension here, which was a bug).
            input_ids["input_ids"] = input_ids["input_ids"][:, -caif_tokens_num:]
            if "attention_mask" in input_ids.keys():
                input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
        logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
        return torch.sigmoid(logits)
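

# A minimal sketch of one CAIF decoding step (illustrative assumptions: the
# "gpt2" LM, the "distilbert-base-uncased-finetuned-sst-2-english" classifier,
# and all hyperparameter values). The classifier rescores the LM's
# top_k_classifier candidate continuations before one token is sampled from
# the reweighted top_k.
def _example_caif_step():
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    lm = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    sampler = CAIFSampler(
        classifier_name="distilbert-base-uncased-finetuned-sst-2-english",
        lm_tokenizer=lm_tokenizer,
        device="cpu",
    )
    input_ids = lm_tokenizer("This movie is", return_tensors="pt").input_ids
    with torch.no_grad():
        logits = lm(input_ids).logits
        next_tokens = sampler(
            input_ids,
            logits,
            top_k=20,
            temperature=1.0,
            top_k_classifier=100,
            classifier_weight=5.0,
            target_cls_id=1,  # class index of the attribute being steered toward
        )
    return next_tokens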