Linly-ChatFlow / generate.py
wmpscc's picture
Update generate.py
45d104f
raw
history blame
5.89 kB
import torch
import torch.nn.functional as F
def apply_temperature(scores, tempt):
    """Scale ``scores`` by a sampling temperature.

    Args:
        scores: logits (anything supporting ``/``).
        tempt: temperature; only a strictly positive value is applied.

    Returns:
        ``scores / tempt`` when ``tempt > 0``; otherwise ``scores`` itself,
        untouched (the same object is returned).
    """
    return scores / tempt if tempt > 0 else scores
def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """Nucleus (top-p) filtering along the last dimension of ``scores``.

    Tokens are sorted by score in ascending order. Those whose cumulative
    softmax probability is at most ``1 - top_p`` are replaced with
    ``filter_value``. Filtering only runs when ``0 < top_p < 1``.

    Args:
        scores: logits tensor of any rank; filtering uses the last dimension.
        top_p: cumulative-probability threshold.
        filter_value: value written into removed positions.
        min_tokens_to_keep: the highest-scoring ``min_tokens_to_keep``
            entries are never removed (only enforced when > 1).

    Returns:
        A new tensor with removed positions set to ``filter_value``
        (``scores`` is not modified). Returns ``scores`` itself when
        filtering is skipped.
    """
    if not 0 < top_p < 1:
        return scores

    # Ascending sort: the most probable tokens end up at the tail.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove the low-probability head whose total mass is <= 1 - top_p; the
    # token that pushes the running sum past that mass is kept.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    if min_tokens_to_keep > 1:
        # The tail holds the highest scores, so protect the last k entries.
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = False

    # Map the mask back to the original vocabulary order. Scattering along
    # the last dim (not a fixed dim 1) works for inputs of any rank.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return scores.masked_fill(indices_to_remove, filter_value)
def apply_top_k(logits, top_k):
    """Keep only the ``top_k`` highest logits along the last dimension.

    Entries strictly below the k-th largest value are set to ``-inf``. Ties
    with the k-th value are kept.

    Args:
        logits: logits tensor; filtering uses the last dimension.
        top_k: number of entries to keep. It is clipped to the last-dim size.
            ``None`` or a value <= 0 disables filtering.

    Returns:
        A new filtered tensor (``logits`` is not modified). Returns ``logits``
        itself when filtering is disabled.
    """
    if top_k is None:
        return logits
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k <= 0:
        return logits

    # k-th largest value per row, kept as a trailing singleton for broadcasting.
    kth_largest = torch.topk(logits.float(), top_k)[0][..., -1, None]
    # Out-of-place, matching apply_top_p, so callers' logits are left intact.
    return logits.masked_fill(logits < kth_largest, -float("Inf"))
def apply_advanced_repetition_penalty(
    input_ids, scores, penalty_range, penalty_slope, penalty
):
    """Penalize the scores of tokens that already appear in ``input_ids``.

    For each id in the (optionally range-limited) history, its score is
    multiplied by the penalty when it is <= 0 and divided by it otherwise.
    With a non-zero ``penalty_slope`` and a window of at least two positions,
    the penalty follows a ramp. It starts at 1 at the oldest slot of a full
    ``penalty_range`` window and rises to ``penalty`` at the newest slot.

    Args:
        input_ids: token history, gathered along dim 1 as ``(batch, seq)``.
        scores: logits, gathered/scattered along dim 1 as ``(batch, vocab)``.
            **Modified in place.**
        penalty_range: number of most recent tokens to penalize; ``<= 0``
            means the whole history.
        penalty_slope: ramp shape; 0 means a constant penalty.
        penalty: penalty factor; 1.0 disables the penalty entirely.

    Returns:
        ``scores`` (the same tensor, updated in place).
    """
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    if penalty == 1.0:
        return scores

    if penalty_range > 0:
        if clipped_penalty_range < input_ids.shape[1]:
            # Only the last `penalty_range` tokens are penalized.
            input_ids = input_ids[..., -clipped_penalty_range:]

        # The ramp needs at least two positions: with penalty_range == 1 the
        # normalisation below would compute 0 / 0 = NaN and poison the row.
        if penalty_slope != 0 and penalty_range > 1:
            # Positions mapped linearly onto [-1, 1] (oldest -> newest).
            _penalty = (
                torch.arange(penalty_range, dtype=scores.dtype, device=scores.device)
                / (penalty_range - 1)
            ) * 2.0 - 1
            # Slope-controlled curve; maps -1 -> -1 and 1 -> 1.
            _penalty = (penalty_slope * _penalty) / (
                1 + torch.abs(_penalty) * (penalty_slope - 1)
            )
            # Rescale [-1, 1] -> [1, penalty]; shape (1, penalty_range).
            _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
            # Align with a history shorter than the window.
            penalty = _penalty[..., -clipped_penalty_range:]

    score = torch.gather(scores, 1, input_ids)
    score = torch.where(score <= 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)
    return scores
class LmGeneration:
    """Batched autoregressive sampler wrapping a model and its tokenizer."""

    def __init__(self, model, tokenizer):
        # Model called as model.forward(tokens, prev_pos, active_indices).
        self.model = model
        # Tokenizer providing pad_token, eos_token and decode().
        self.tokenizer = tokenizer

    def generate(self, args, prompts, cut_off=None, cut_off_times=1):
        """Sample a continuation for every prompt and return the decoded rows.

        Args:
            args: config object. Reads ``batch_size``, ``tokenizer``
                (used to encode prompts), ``seq_length``, ``top_k``,
                ``top_p``, ``temperature``, ``repetition_penalty_range``,
                ``repetition_penalty_slope`` and ``repetition_penalty``.
            prompts: list of prompt strings, at most ``args.batch_size``.
            cut_off: optional stop string. An example stops generating once
                its decoded text ends with ``cut_off`` for the
                ``cut_off_times``-th time.
            cut_off_times: number of ``cut_off`` matches allowed per example.

        Returns:
            One decoded string per prompt. It covers the whole row (prompt
            plus generated tokens), truncated at the first pad token and at
            the first eos token.

        Raises:
            AssertionError: if more prompts than ``args.batch_size`` are given.
        """
        if cut_off is not None:
            # Per-example budget of cut_off matches still allowed.
            cut_off_times = [cut_off_times for _ in range(len(prompts))]
        batch = len(prompts)
        assert batch <= args.batch_size

        prompt_tokens = [args.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        prompt_lens = [len(t) for t in prompt_tokens]
        min_prompt_len = min(prompt_lens)
        total_len = args.seq_length

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokens = torch.full((batch, total_len), self.tokenizer.pad_token).to(device).long()
        for idx, t in enumerate(prompt_tokens):
            tokens[idx, : len(t)] = torch.tensor(t).long()
        # True where a prompt token was pre-filled; such positions are copied
        # instead of sampled while shorter prompts are already generating.
        mask = tokens != self.tokenizer.pad_token
        start_pos = min_prompt_len
        prev_pos = 0
        # Batch indices that are still being generated.
        continue_exsample = list(range(batch))
        with torch.no_grad():
            for cur_pos in range(start_pos, total_len):
                # NOTE(review): only tokens[prev_pos:cur_pos] plus prev_pos are
                # passed, which presumably relies on a cache inside the model
                # keyed by the active indices -- confirm in the model code.
                logits = self.model.forward(
                    tokens[continue_exsample, prev_pos:cur_pos], prev_pos, continue_exsample
                ).float()
                next_token_scores = apply_top_k(logits, top_k=args.top_k)
                next_token_scores = apply_top_p(next_token_scores, args.top_p)
                # Temperature is applied after top-k/top-p, so nucleus
                # filtering sees the untempered distribution.
                next_token_scores = apply_temperature(next_token_scores, args.temperature)
                next_token_scores = apply_advanced_repetition_penalty(
                    tokens[continue_exsample, :cur_pos],
                    next_token_scores,
                    args.repetition_penalty_range,
                    args.repetition_penalty_slope,
                    args.repetition_penalty,
                )
                scores = F.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(scores, num_samples=1).squeeze(1)
                next_token = next_token.reshape(-1)
                # Rows whose prompt still covers cur_pos keep their prompt token.
                next_token = torch.where(
                    mask[continue_exsample, cur_pos], tokens[continue_exsample, cur_pos], next_token
                )
                tokens[continue_exsample, cur_pos] = next_token
                prev_pos = cur_pos

                # Drop examples that emitted eos or used up their cut_off budget.
                # Only currently active examples are re-checked, so a finished
                # example can never be revived. The cut_off test only starts
                # once the example is past its own prompt.
                rows = tokens.tolist()
                still_active = []
                for i in continue_exsample:
                    row = rows[i]
                    if self.tokenizer.eos_token in row:
                        continue
                    if cut_off is not None and cur_pos >= prompt_lens[i]:
                        text = self.tokenizer.decode(row[: cur_pos + 1])
                        if text[-len(cut_off):] == cut_off:
                            if cut_off_times[i] == 1:
                                continue
                            cut_off_times[i] -= 1
                    still_active.append(i)
                continue_exsample = still_active
                if not continue_exsample:
                    break

        decoder = []
        for row in tokens.tolist():
            row = row[: args.seq_length]
            # Cut at the first pad and at the first eos independently: a row
            # that filled every position has no pad but may still hold eos.
            for stop_token in (self.tokenizer.pad_token, self.tokenizer.eos_token):
                if stop_token in row:
                    row = row[: row.index(stop_token)]
            decoder.append(self.tokenizer.decode(row))
        return decoder