import logging
from typing import Tuple

import torch
from outlines.samplers import MultinomialSampler

logger = logging.getLogger(__name__)
class PenalizedMultinomialSampler(MultinomialSampler):
    """Multinomial sampler that stops sampling from groups of token ids once
    each group has been drawn a maximum number of times."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-group bookkeeping: the token ids in each group, the maximum
        # number of draws allowed, and the number of draws so far.
        self.penalized_tokens_group: list[torch.IntTensor] = []
        self.max_repeats_per_token_group: list[int] = []
        self.repeats_per_token_group: list[int] = []
        # Reverse index mapping a token id to the groups it belongs to.
        self.token_id_to_tokens_groups: list[list[int]] = []
    def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
        """Register a group of token ids that may be sampled at most
        ``max_repeats`` times per generation."""
        max_token_id = max(token_ids)
        # Grow the reverse index so it covers every token id in this group.
        if max_token_id >= len(self.token_id_to_tokens_groups):
            self.token_id_to_tokens_groups += [
                [] for _ in range(len(self.token_id_to_tokens_groups), max_token_id + 1)
            ]
        for token_id in token_ids:
            self.token_id_to_tokens_groups[token_id].append(len(self.penalized_tokens_group))
        self.penalized_tokens_group.append(torch.tensor(token_ids, dtype=torch.int32))
        self.max_repeats_per_token_group.append(max_repeats)
        self.repeats_per_token_group.append(0)
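
    # Bookkeeping example (hypothetical ids, not from the original code):
    # after set_max_repeats([5, 9], max_repeats=2), token_id_to_tokens_groups[5]
    # and token_id_to_tokens_groups[9] both reference the new group,
    # penalized_tokens_group holds tensor([5, 9], dtype=torch.int32), and the
    # group's budget/count pair is (max_repeats=2, repeats=0).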
    def __call__(
        self,
        next_token_logits: torch.DoubleTensor,
        sequence_weights: torch.DoubleTensor,
        rng: torch.Generator,
    ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
        """Call the multinomial sampler.

        Parameters
        ----------
        next_token_logits
            A tensor of shape ``(n_seqs, vocab_size)`` that represents the
            probability distribution of the next token over the vocabulary.
        sequence_weights
            A tensor of shape ``(n_seqs,)`` that represents the cumulative
            weight of each sequence.
        rng
            A random number generator.

        Returns
        -------
        A tuple with an array that contains the ids of the sampled tokens of
        shape ``(n_seqs, 1)``, an array that contains the ancestors of each
        sampled id of shape ``(n_seqs,)``, and an array that contains the
        updated cumulative weights of each sequence of shape ``(n_seqs,)``.
        """
        if sequence_weights.min() == sequence_weights.max() == 0:
            # All cumulative weights are zero: a fresh generation is starting,
            # so reset the per-group repeat counters.
            self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)
        else:
            # Mask every group that has exhausted its repeat budget by driving
            # the logits of its tokens to -inf before sampling.
            for penalized_tokens, max_repeats, repeats in zip(
                self.penalized_tokens_group,
                self.max_repeats_per_token_group,
                self.repeats_per_token_group,
            ):
                if repeats >= max_repeats:
                    penalty = torch.zeros_like(next_token_logits)
                    penalty[:, penalized_tokens] = -torch.inf
                    next_token_logits = next_token_logits + penalty
        next_token_ids, ancestors, weights = super().__call__(
            next_token_logits=next_token_logits,
            sequence_weights=sequence_weights,
            rng=rng,
        )
        # Count each sampled token against every group it belongs to.
        for next_token_id in next_token_ids.cpu().flatten().tolist():
            if next_token_id < len(self.token_id_to_tokens_groups):
                for token_group in self.token_id_to_tokens_groups[next_token_id]:
                    self.repeats_per_token_group[token_group] += 1
        return next_token_ids, ancestors, weights
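

# A minimal usage sketch, assuming the outlines 0.x API that this file already
# imports from. The model name, prompt, and token ids below are illustrative
# placeholders, not part of the class above.
if __name__ == "__main__":
    import outlines

    model = outlines.models.transformers("gpt2")
    sampler = PenalizedMultinomialSampler()
    # Suppose ids 42 and 43 encode a phrase we want at most twice per output.
    sampler.set_max_repeats(token_ids=[42, 43], max_repeats=2)
    generator = outlines.generate.text(model, sampler=sampler)
    print(generator("List some fruits:", max_tokens=32))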