import torch
import torch.nn.functional as F


def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
    Adapted from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Note that ``logits`` is modified in place and also returned, and that at least
    ``min_tokens_to_keep`` tokens are kept per batch example.

    Args:
        logits (torch.Tensor): Logits distribution with shape (batch size, vocabulary size).
        top_k (int, optional): Keep only the top k tokens with the highest probability
            (top-k filtering). Set to 0 to disable. Defaults to 0.
        top_p (float, optional): Keep the top tokens with a cumulative probability >= top_p
            (nucleus filtering). Must be between 0 and 1, inclusive. Defaults to 1.0.
        filter_value (float, optional): The value to assign to filtered logits.
            Defaults to -float("Inf").
        min_tokens_to_keep (int, optional): Ensure that at least this number of tokens
            is kept per batch example. Defaults to 1.

    Returns:
        torch.Tensor: The filtered logits.
    """
    if top_k > 0:
        # Clamp top_k to the valid range [min_tokens_to_keep, vocab_size].
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k).values[..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mask tokens with a cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep tokens (use min_tokens_to_keep - 1
            # here because the shift below always adds the first token back).
            sorted_indices_to_remove[..., : min_tokens_to_keep - 1] = 0
        # Shift the mask to the right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original (unsorted) index order.
        # (Scatter into sorted_indices_to_remove, not sorted_indices: the mask must
        # stay boolean and shaped like logits.)
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
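
# A minimal usage sketch of the filter in isolation (illustrative toy values,
# not from the original module): with top_k=2, only the two largest logits
# survive; everything else is set to -inf.
#
#     logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#     top_k_top_p_filtering(logits, top_k=2)
#     # -> tensor([[-inf, -inf, 3., 4.]])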


def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0):
    """
    Perform top-k and top-p (nucleus) sampling on logits.

    Args:
        logits (torch.Tensor): The logits to sample from, with shape (batch size, vocabulary size).
        top_k (int, optional): The number of highest-probability tokens to keep for top-k filtering.
            Must be a positive integer. Defaults to 50.
        top_p (float, optional): The cumulative probability threshold for nucleus sampling.
            Must be between 0 and 1. Defaults to 1.0.
        temperature (float, optional): The scaling factor applied to the logits before filtering.
            Must be strictly positive. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token ids, with shape (batch size, 1).
    """
    # Temperature scaling: values below 1.0 sharpen the distribution,
    # values above 1.0 flatten it.
    if temperature != 1.0:
        logits = logits / temperature

    # Mask everything outside the top-k / nucleus set with -inf.
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    # Sample one token id per batch row from the renormalized distribution.
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token
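

# A runnable end-to-end sketch (illustrative assumptions: the vocabulary size of
# 100 and the sampling parameters below are not from the original module).
if __name__ == "__main__":
    torch.manual_seed(0)  # for a reproducible demo only
    dummy_logits = torch.randn(2, 100)  # (batch size, vocabulary size)
    sampled = topk_sampling(dummy_logits, top_k=50, top_p=0.9, temperature=0.8)
    print(sampled.shape)  # torch.Size([2, 1])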