# ClickbaitFighter / utils.py
# Last commit: 50c5b0d "Use Zero" (by Iker) — file size 1.91 kB.
import logging
from typing import List
import torch
from transformers import (
LogitsProcessor,
)
class StopAfterTokenIsGenerated(LogitsProcessor):
    """Force EOS once a row of ``input_ids`` ends with any configured stop sequence.

    On every call after the first, each row of ``input_ids`` is compared with the
    configured stop sequences. When a row ends with one of them, every score in
    that row is set to ``-inf`` except ``eos_token_id``, which is set to ``0``,
    so EOS is the only token left with finite score for that row.

    The first call after construction (or after :meth:`reset`) passes the scores
    through untouched. Presumably ``input_ids`` then holds only the prompt, so a
    prompt that already ends with a stop sequence does not end generation
    immediately. Call :meth:`reset` before reusing an instance for a new
    generation.
    """

    def __init__(self, stops: List[torch.Tensor], eos_token_id: int):
        """
        Args:
            stops: Stop sequences as token-id tensors. Anything accepted by
                ``torch.as_tensor`` also works (e.g. a list of ints, or a 0-d
                tensor for a single token); each is flattened to 1-D before
                matching, and empty sequences are ignored.
            eos_token_id: Token id forced once a stop sequence is matched.
        """
        super().__init__()
        self.stops = stops
        self.eos_token_id = eos_token_id
        logging.info("Stopping criteria words ids: %s", self.stops)
        # True until the first __call__ of a generation; see the class docstring.
        self.first_batch = True

    def _prepared_stops(self, input_ids: torch.LongTensor) -> List[torch.Tensor]:
        """Return non-empty 1-D copies of ``self.stops`` on ``input_ids``'s device and dtype."""
        prepared = []
        for stop in self.stops:
            # Flattening makes 0-d (single-token) and (1, n) tensors compare
            # element-wise against the row tail instead of broadcasting.
            stop = torch.as_tensor(
                stop, device=input_ids.device, dtype=input_ids.dtype
            ).reshape(-1)
            # An empty stop would make ``seq[-0:]`` select the whole row.
            if stop.numel() > 0:
                prepared.append(stop)
        return prepared

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
                search or log softmax for each vocabulary token when using beam search
        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
            ``scores`` is modified in place and also returned.
        """
        if self.first_batch:
            self.first_batch = False
            return scores
        # Convert once per call rather than once per (row, stop) pair.
        stops = self._prepared_stops(input_ids)
        seq_len = input_ids.shape[-1]
        for seq_no, seq in enumerate(input_ids):
            for stop in stops:
                stop_len = stop.numel()
                if stop_len <= seq_len and torch.equal(seq[-stop_len:], stop):
                    scores[seq_no, :] = -float("inf")
                    scores[seq_no, self.eos_token_id] = 0
                    logging.info("Stopping criteria found: %s", stop)
                    break
        return scores

    def reset(self):
        """Re-arm the first-call pass-through so the instance can serve a new generation."""
        self.first_batch = True