# llemma_7b/handler.py
# Author: Pierce Maloney
# Commit 6d8b690: truncate the earliest tokens if the input is longer than 4092 tokens.

from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList


class EndpointHandler:
    def __init__(self, path=""):
        tokenizer = AutoTokenizer.from_pretrained(path)
        # The tokenizer ships without a pad token, so reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = tokenizer
        # Halt generation as soon as a period is produced (see StopAtPeriodCriteria below).
        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: dict with keys:
                inputs (:obj:`str`): the prompt to continue.
                additional_bad_words_ids (:obj:`list`, optional): extra token-id
                    sequences to ban during generation.
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned.
        """
        inputs = data.pop("inputs", data)
        additional_bad_words_ids = data.pop("additional_bad_words_ids", [])

        # Token-id sequences banned from generation:
        #   [3070], [10456], and [313, 334] all decode to "(*", and we do not want to output a comment.
        #   [13] is a newline character.
        #   [1976, 441, 29889] and [4920, 441, 29889] decode to "Abort." ([4920, 441] to "Abort").
        #   [4920, 18054, 29889] decodes to "Aborted."
        #   [2087, 29885, 4430, 29889] decodes to "Admitted."
        bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
        bad_words_ids.extend(additional_bad_words_ids)
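
        # Illustrative note (not part of the original handler): the ids for another
        # banned string can be computed with the tokenizer, e.g.
        #   self.tokenizer.encode("Qed.", add_special_tokens=False)
        # and passed in via "additional_bad_words_ids".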
input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
max_generation_length = 75 # Desired number of tokens to generate
max_input_length = 4092 - max_generation_length # Maximum input length to allow space for generation
# Truncate input_ids to the most recent tokens that fit within the max_input_length
if input_ids.shape[1] > max_input_length:
input_ids = input_ids[:, -max_input_length:]
max_length = input_ids.shape[1] + max_generation_length
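
        # Worked example: a 5000-token prompt keeps only its last 4017 tokens
        # (4092 - 75), so max_length = 4017 + 75 = 4092.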
        generated_ids = self.model.generate(
            input_ids,
            max_length=max_length,  # prompt length + up to 75 new tokens
            bad_words_ids=bad_words_ids,
            do_sample=True,  # temperature and top_k only take effect when sampling is enabled
            temperature=1,
            top_k=40,
            pad_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=self.stopping_criteria,
        )
        # Return only the newly generated tokens, stripping the prompt prefix.
        generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
        prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
        return prediction


class StopAtPeriodCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the most recently generated token to text.
        last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
        # Stop if the decoded token contains a period anywhere, not just at the end.
        return '.' in last_token_text
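

# Minimal local usage sketch, assuming a model directory at "./llemma_7b";
# the path and prompt below are illustrative and not part of this repo.
if __name__ == "__main__":
    handler = EndpointHandler(path="./llemma_7b")
    payload = {
        "inputs": "Theorem add_comm : forall a b : nat, a + b = b + a.\nProof.\n",
        "additional_bad_words_ids": [],  # optionally ban extra token-id sequences
    }
    prediction = handler(payload)
    print(prediction[0]["generated_text"])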