Update app.py
Browse files
app.py
CHANGED
@@ -140,7 +140,30 @@ def sqlquery(input): #, history=[]):
|
|
140 |
# Get a batch of records
|
141 |
batch_data = table[start_idx:end_idx]
|
142 |
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
for idx, record in enumerate(batch_data):
|
146 |
# Maintain conversation context by appending history
|
@@ -168,7 +191,7 @@ def sqlquery(input): #, history=[]):
|
|
168 |
# Update conversation history
|
169 |
conversation_history.append("User: " + record["question"])
|
170 |
conversation_history.append("Bot: " + response)
|
171 |
-
|
172 |
|
173 |
# ==========================================================================
|
174 |
'''
|
|
|
140 |
# Get a batch of records
|
141 |
batch_data = table[start_idx:end_idx]
|
142 |
|
143 |
+
# Tokenize the batch
|
144 |
+
tokenized_batch = sql_tokenizer.batch_encode_plus(
|
145 |
+
batch_data, padding=True, truncation=True, return_tensors="pt"
|
146 |
+
)
|
147 |
+
|
148 |
+
# Perform inference
|
149 |
+
with torch.no_grad():
|
150 |
+
output = sql_model.generate(
|
151 |
+
input_ids=tokenized_batch["input_ids"],
|
152 |
+
max_length=1024,
|
153 |
+
pad_token_id=sql_tokenizer.eos_token_id,
|
154 |
+
)
|
155 |
+
|
156 |
+
# Decode the output and process the responses
|
157 |
+
responses = [sql_tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
|
158 |
+
|
159 |
+
conversation_history.append("User: " + record["question"])
|
160 |
+
for response in enumerate(responses):
|
161 |
+
# Update conversation history
|
162 |
+
conversation_history.append("Bot: " + response)
|
163 |
+
|
164 |
+
'''
|
165 |
+
|
166 |
+
= []
|
167 |
|
168 |
for idx, record in enumerate(batch_data):
|
169 |
# Maintain conversation context by appending history
|
|
|
191 |
# Update conversation history
|
192 |
conversation_history.append("User: " + record["question"])
|
193 |
conversation_history.append("Bot: " + response)
|
194 |
+
'''
|
195 |
|
196 |
# ==========================================================================
|
197 |
'''
|