Pierce Maloney committed
Commit 6d8b690
1 Parent(s): 4c4f932

Truncating earliest tokens if input is longer than 4092

Files changed (1)
handler.py +9 -3
handler.py CHANGED
@@ -4,7 +4,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, Stopping
 
 class EndpointHandler():
     def __init__(self, path=""):
-        # Preload all the elements you are going to need at inference.
         tokenizer = AutoTokenizer.from_pretrained(path)
         tokenizer.pad_token = tokenizer.eos_token
         self.model = AutoModelForCausalLM.from_pretrained(path)
@@ -31,11 +30,18 @@ class EndpointHandler():
         bad_words_ids.extend(additional_bad_words_ids)
 
         input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
+        max_generation_length = 75  # Desired number of tokens to generate
+        max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation
 
-        # Generate text using model.generate
+        # Truncate input_ids to the most recent tokens that fit within max_input_length
+        if input_ids.shape[1] > max_input_length:
+            input_ids = input_ids[:, -max_input_length:]
+
+        max_length = input_ids.shape[1] + max_generation_length
+
         generated_ids = self.model.generate(
             input_ids,
-            max_length=input_ids.shape[1] + 50,  # 50 new tokens
+            max_length=max_length,  # 75 new tokens
             bad_words_ids=bad_words_ids,
             temperature=1,
             top_k=40,
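
For context, a minimal standalone sketch of the left-truncation pattern this commit introduces. The "gpt2" checkpoint and the prompt string are placeholders, not the model this handler actually serves; the handler's sampling arguments (bad_words_ids, temperature, top_k) are omitted for brevity; and the 4092-token budget is copied from the commit, so in practice it should match the target model's real context window.

import from transformers: AutoTokenizer, AutoModelForCausalLM

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

max_generation_length = 75
max_input_length = 4092 - max_generation_length  # reserve room for new tokens

input_ids = tokenizer.encode("a very long prompt ...", return_tensors="pt")

# Keep only the most recent tokens: a causal LM conditions on the rightmost
# context, so the earliest tokens are the safest to drop.
if input_ids.shape[1] > max_input_length:
    input_ids = input_ids[:, -max_input_length:]

generated_ids = model.generate(
    input_ids,
    max_length=input_ids.shape[1] + max_generation_length,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True))

Truncating from the left rather than the right preserves the text immediately preceding the completion point, which is what matters most for next-token prediction.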