teaevo commited on
Commit
a46b806
1 Parent(s): 8233187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -2
app.py CHANGED
@@ -140,7 +140,30 @@ def sqlquery(input): #, history=[]):
140
  # Get a batch of records
141
  batch_data = table[start_idx:end_idx]
142
 
143
- batch_responses = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  for idx, record in enumerate(batch_data):
146
  # Maintain conversation context by appending history
@@ -168,7 +191,7 @@ def sqlquery(input): #, history=[]):
168
  # Update conversation history
169
  conversation_history.append("User: " + record["question"])
170
  conversation_history.append("Bot: " + response)
171
-
172
 
173
  # ==========================================================================
174
  '''
 
140
  # Get a batch of records
141
  batch_data = table[start_idx:end_idx]
142
 
143
+ # Tokenize the batch
144
+ tokenized_batch = sql_tokenizer.batch_encode_plus(
145
+ batch_data, padding=True, truncation=True, return_tensors="pt"
146
+ )
147
+
148
+ # Perform inference
149
+ with torch.no_grad():
150
+ output = sql_model.generate(
151
+ input_ids=tokenized_batch["input_ids"],
152
+ max_length=1024,
153
+ pad_token_id=sql_tokenizer.eos_token_id,
154
+ )
155
+
156
+ # Decode the output and process the responses
157
+ responses = [sql_tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
158
+
159
+ conversation_history.append("User: " + record["question"])
160
+ for response in enumerate(responses):
161
+ # Update conversation history
162
+ conversation_history.append("Bot: " + response)
163
+
164
+ '''
165
+
166
+ = []
167
 
168
  for idx, record in enumerate(batch_data):
169
  # Maintain conversation context by appending history
 
191
  # Update conversation history
192
  conversation_history.append("User: " + record["question"])
193
  conversation_history.append("Bot: " + response)
194
+ '''
195
 
196
  # ==========================================================================
197
  '''