Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
def apply_temperature(scores, tempt):
    """Rescale ``scores`` by a sampling temperature.

    Args:
        scores: Logits to rescale; anything supporting division by ``tempt``.
        tempt: Temperature. Only strictly positive values are applied.

    Returns:
        ``scores / tempt`` when ``tempt > 0``; otherwise ``scores`` itself,
        unmodified.
    """
    # A zero or negative temperature is treated as "no scaling", not an error.
    if not tempt > 0:
        return scores
    return scores / tempt
def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """Apply nucleus (top-p) filtering to a tensor of logits.

    Tokens are sorted by ascending logit along the last dimension; the
    low-probability tail whose cumulative softmax mass is at most
    ``1 - top_p`` is replaced by ``filter_value``.

    Args:
        scores: Logits tensor; filtering is done along the last dimension.
        top_p: Probability mass to keep. Filtering only happens when
            ``0 < top_p < 1``; any other value returns ``scores`` unchanged.
        filter_value: Value written over the removed logits.
        min_tokens_to_keep: Number of highest-scoring tokens that are never
            removed. Values below 1 are treated as 1.

    Returns:
        A new tensor with the filtered logits (``scores`` is not modified),
        or ``scores`` itself when ``top_p`` is outside ``(0, 1)``.
    """
    if not 0 < top_p < 1:
        return scores

    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Drop the tail whose cumulative mass fits inside the discarded 1 - top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # Always protect the best token(s). Without this, a tiny top_p (where
    # 1 - top_p rounds to 1.0) can mask every token and produce NaNs once the
    # result is fed to softmax.
    keep = max(int(min_tokens_to_keep), 1)
    sorted_indices_to_remove[..., -keep:] = False

    # Map the mask from sorted order back to the original token order. The
    # last dimension is used so 1-D inputs work as well as batched ones.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return scores.masked_fill(indices_to_remove, filter_value)
def apply_top_k(logits, top_k):
    """Keep only the ``top_k`` highest logits along the last dimension.

    Args:
        logits: Logits tensor; filtering is done along the last dimension.
        top_k: Number of tokens to keep. It is clamped to the size of the last
            dimension; a value of zero or less disables filtering.

    Returns:
        A new tensor in which every logit strictly below the k-th largest one
        is ``-inf``. Ties with the k-th value are kept. When filtering is
        disabled, ``logits`` itself is returned. The input tensor is never
        modified.
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k <= 0:
        return logits

    # Threshold is the smallest of the k largest values, kept broadcastable.
    kth_best = torch.topk(logits.float(), top_k)[0][..., -1, None]

    # Out-of-place fill so the caller's tensor is not clobbered.
    return logits.masked_fill(logits < kth_best, -float("Inf"))
def apply_advanced_repetition_penalty(
    input_ids, scores, penalty_range, penalty_slope, penalty
):
    """Penalize the logits of tokens that already appear in ``input_ids``.

    Positive logits of seen tokens are divided by the penalty and
    non-positive ones are multiplied by it. With a non-zero
    ``penalty_slope`` the penalty is ramped across the window, going from 1
    (no penalty) at the oldest position to ``penalty`` at the most recent one.

    Args:
        input_ids: 2-D integer tensor of previously seen token ids; its values
            index dimension 1 of ``scores``.
        scores: 2-D logits tensor. It is modified in place.
        penalty_range: Number of trailing tokens of ``input_ids`` to penalize
            (converted with ``int``). Zero or less means the whole of
            ``input_ids`` is penalized with a flat penalty.
        penalty_slope: Shape of the ramp; ``0`` disables the ramp and applies
            ``penalty`` uniformly.
        penalty: Repetition penalty. ``1.0`` disables the whole step.

    Returns:
        ``scores`` (the same tensor object that was passed in).
    """
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    if penalty == 1.0:
        return scores

    if penalty_range > 0:
        # Only the most recent `clipped_penalty_range` tokens are penalized.
        if clipped_penalty_range < input_ids.shape[1]:
            input_ids = input_ids[..., -clipped_penalty_range:]

        if penalty_slope != 0:
            if penalty_range > 1:
                # Evenly spaced positions in [-1, 1], oldest to newest.
                ramp = (
                    torch.arange(
                        penalty_range, dtype=scores.dtype, device=scores.device
                    )
                    / (penalty_range - 1)
                ) * 2.0 - 1
            else:
                # A one-token window would divide 0 by 0 above and yield NaN.
                # Treat the single position as the most recent one so it
                # receives the full penalty.
                ramp = torch.ones(1, dtype=scores.dtype, device=scores.device)

            # Bend the linear ramp according to the slope, then rescale it
            # from [-1, 1] to [1, penalty].
            ramp = (penalty_slope * ramp) / (
                1 + torch.abs(ramp) * (penalty_slope - 1)
            )
            ramp = 1 + ((ramp + 1) / 2).unsqueeze(0) * (penalty - 1)
            penalty = ramp[..., -clipped_penalty_range:]

    score = torch.gather(scores, 1, input_ids)
    # Both branches push the logit towards "less likely".
    score = torch.where(score <= 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)
    return scores
class LmGeneration:
    """Batched sampling-based text generation around a model and a tokenizer."""

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by :meth:`generate`.

        Args:
            model: Object exposing ``forward(tokens, prev_pos, example_ids)``.
            tokenizer: Object exposing ``pad_token``, ``eos_token`` and
                ``decode``.
        """
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, args, prompts, cut_off=None, cut_off_times=1):
        """Generate a continuation for every prompt and return decoded text.

        Args:
            args: Settings object. This method reads ``batch_size``,
                ``seq_length``, ``tokenizer`` (used for encoding prompts),
                ``top_k``, ``top_p``, ``temperature``, ``repetition_penalty``,
                ``repetition_penalty_range`` and ``repetition_penalty_slope``.
            prompts: List of prompt strings; at most ``args.batch_size``.
            cut_off: Optional stop string. A sequence stops once its decoded
                text has ended with ``cut_off`` ``cut_off_times`` times.
            cut_off_times: Number of ``cut_off`` matches required to stop.

        Returns:
            One decoded string per prompt, truncated at the first pad token
            and at the first EOS token. An empty list for empty ``prompts``.
        """
        batch = len(prompts)
        assert batch <= args.batch_size
        if batch == 0:
            # Nothing to generate; avoids min() failing on an empty sequence.
            return []

        if cut_off is not None:
            # Independent countdown of remaining matches for every prompt.
            cut_off_times = [cut_off_times] * batch

        pad_token = self.tokenizer.pad_token
        eos_token = self.tokenizer.eos_token

        prompt_tokens = [
            args.tokenizer.encode(x, bos=True, eos=False) for x in prompts
        ]
        min_prompt_len = min(len(x) for x in prompt_tokens)
        total_len = args.seq_length

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokens = torch.full((batch, total_len), pad_token).to(device).long()
        for idx, t in enumerate(prompt_tokens):
            tokens[idx, : len(t)] = torch.tensor(t).long()
        # True wherever a prompt token is present; those are never overwritten.
        mask = tokens != pad_token

        start_pos = min_prompt_len
        prev_pos = 0
        active = list(range(batch))

        with torch.no_grad():
            for cur_pos in range(start_pos, total_len):
                # Only the tokens added since the previous step are fed in.
                # Presumably the model caches earlier positions keyed by
                # prev_pos and the example ids -- TODO confirm.
                logits = self.model.forward(
                    tokens[active, prev_pos:cur_pos], prev_pos, active
                ).float()

                next_token_scores = apply_top_k(logits, top_k=args.top_k)
                next_token_scores = apply_top_p(next_token_scores, args.top_p)
                next_token_scores = apply_temperature(
                    next_token_scores, args.temperature
                )
                next_token_scores = apply_advanced_repetition_penalty(
                    tokens[active, :cur_pos],
                    next_token_scores,
                    args.repetition_penalty_range,
                    args.repetition_penalty_slope,
                    args.repetition_penalty,
                )

                probs = F.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                next_token = next_token.reshape(-1)
                # Longer prompts still supply their own token at this position.
                next_token = torch.where(
                    mask[active, cur_pos], tokens[active, cur_pos], next_token
                )
                tokens[active, cur_pos] = next_token
                prev_pos = cur_pos

                # Rebuild the list of sequences that keep generating.
                # NOTE(review): every row is re-evaluated from scratch each
                # step, so a row dropped for a cut_off match can become active
                # again if its decoded prefix changes -- confirm intended.
                active = []
                for i, row in enumerate(tokens.tolist()):
                    if eos_token in row:
                        continue
                    if cut_off is not None:
                        text = self.tokenizer.decode(row[: cur_pos + 1])
                        if cut_off == text[-len(cut_off):]:
                            if cut_off_times[i] == 1:
                                continue
                            cut_off_times[i] -= 1
                    active.append(i)
                if not active:
                    break

        decoded = []
        for row in tokens.tolist():
            # Truncate at pad and at EOS independently: a row that was filled
            # up to seq_length has no pad token but may still contain EOS.
            if pad_token in row:
                row = row[: row.index(pad_token)]
            if eos_token in row:
                row = row[: row.index(eos_token)]
            decoded.append(self.tokenizer.decode(row))
        return decoded