Pierce Maloney committed on
Commit
355a0ec
1 Parent(s): a36be93
Files changed (1) hide show
  1. handler.py +3 -5
handler.py CHANGED
@@ -53,13 +53,11 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, Stopping
53
  # prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
54
  # return prediction
55
 
56
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
57
-
58
  class EndpointHandler():
59
  def __init__(self, path=""):
60
  self.model_path = path
61
  tokenizer = AutoTokenizer.from_pretrained(path)
62
- tokenizer.pad_token = self.tokenizer.eos_token
63
  self.tokenizer = tokenizer
64
  # Initialize the pipeline for text generation
65
  self.text_generation_pipeline = pipeline("text-generation", model=path, tokenizer=self.tokenizer, device=0) # device=0 for CUDA
@@ -82,10 +80,10 @@ class EndpointHandler():
82
  # Generate text using the pipeline
83
  generation_kwargs = {
84
  "max_length": 75, # Adjust as needed
85
- "temperature": 1,
86
  "top_k": 40,
87
  "bad_words_ids": bad_words_ids,
88
- "pad_token_id": self.tokenizer.eos_token_id # Ensure padding with EOS token
89
  }
90
  generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)
91
 
 
53
  # prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
54
  # return prediction
55
 
 
 
56
  class EndpointHandler():
57
  def __init__(self, path=""):
58
  self.model_path = path
59
  tokenizer = AutoTokenizer.from_pretrained(path)
60
+ tokenizer.pad_token = tokenizer.eos_token
61
  self.tokenizer = tokenizer
62
  # Initialize the pipeline for text generation
63
  self.text_generation_pipeline = pipeline("text-generation", model=path, tokenizer=self.tokenizer, device=0) # device=0 for CUDA
 
80
  # Generate text using the pipeline
81
  generation_kwargs = {
82
  "max_length": 75, # Adjust as needed
83
+ "temperature": 0.7,
84
  "top_k": 40,
85
  "bad_words_ids": bad_words_ids,
86
+ # "pad_token_id": self.tokenizer.eos_token_id # Ensure padding with EOS token
87
  }
88
  generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)
89