Keyven committed
Commit dfc4fa8 • 1 Parent(s): a076c9d

Update response function

Files changed (1)
  1. app.py +2 -3
app.py CHANGED
@@ -48,10 +48,7 @@ def format_text(text):
 
 def get_chat_response(chatbot, task_history):
     """Generate a response using the model."""
-    chat_query = chatbot[-1][0]
-    query = task_history[-1][0]
     history_cp = copy.deepcopy(task_history)
-    full_response = ""
 
     history_filter = []
     pic_idx = 1
@@ -68,6 +65,7 @@ def get_chat_response(chatbot, task_history):
     history, message = history_filter[:-1], history_filter[-1][0]
 
     inputs = tokenizer.encode_plus(message, return_tensors='pt')
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}  # Ensure inputs are on the same device as the model
     outputs = model.generate(inputs['input_ids'], max_length=150, num_beams=4, length_penalty=2.0, early_stopping=True)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
@@ -77,6 +75,7 @@ def get_chat_response(chatbot, task_history):
     return chatbot, task_history
 
 
+
 def handle_text_input(history, task_history, text):
     """Handle text input from the user."""
     task_text = text
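
For context, the only functional change in this commit is the added line that moves the tokenized inputs onto the model's device before generation; the removed lines were unused local variables. Below is a minimal standalone sketch of that pattern. The checkpoint name and prompt are placeholders rather than values from this Space, and the repo's own model and tokenizer are assumed to be standard Hugging Face transformers objects.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

checkpoint = "gpt2"  # placeholder checkpoint, not the model this Space actually loads
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.to("cuda" if torch.cuda.is_available() else "cpu")

message = "Hello, how are you?"  # placeholder prompt
inputs = tokenizer.encode_plus(message, return_tensors="pt")
# The commit's key line: put every input tensor on the same device as the model,
# otherwise generate() fails when the model lives on GPU and the tensors stay on CPU.
inputs = {k: v.to(model.device) for k, v in inputs.items()}

outputs = model.generate(
    inputs["input_ids"],
    max_length=150,
    num_beams=4,
    length_penalty=2.0,
    early_stopping=True,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Passing the whole inputs dict to generate (so the attention_mask travels along with input_ids) would be slightly more robust, but the sketch mirrors the call exactly as it appears in the diff.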