17Goutham commited on
Commit
8a57a60
1 Parent(s): 3dd5459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -7
app.py CHANGED
@@ -1,13 +1,39 @@
1
- from app_tapex import execute_query
 
 
 
2
  import gradio as gr
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def main():
6
- description = "Querying a csv using TAPEX model. You can ask a question about tabular data. TAPAS model " \
7
- "will produce the result. Finetuned TAPEX model runs on max 5000 rows and 20 columns data. " \
8
- "A sample data of shopify store sales is provided"
9
 
10
- article = "<p style='text-align: center'><a href='https://unscrambl.com/' target='_blank'>Unscrambl</a> | <a href='https://huggingface.co/google/tapas-base-finetuned-wtq' target='_blank'>TAPAS Model</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=abaranovskij_tablequery' alt='visitor badge'></center>"
11
 
12
  iface = gr.Interface(fn=execute_query,
13
  inputs=[gr.Textbox(label="Search query"),
@@ -22,6 +48,5 @@ def main():
22
  # iface.launch(server_name="0.0.0.0", server_port=7000)
23
  iface.launch(enable_queue=True)
24
 
25
-
26
  if __name__ == "__main__":
27
- main()
 
1
+ from transformers import TapexTokenizer, BartForConditionalGeneration
2
+ import pandas as pd
3
+ import datetime
4
+ import torch
5
  import gradio as gr
6
 
7
+ def execute_query(query, csv_file):
8
+ a = datetime.datetime.now()
9
+
10
+ table = pd.read_csv(csv_file.name, delimiter=",")
11
+ table = table.astype(str)
12
+
13
+ model_name = "microsoft/tapex-large-finetuned-wtq"
14
+ model = BartForConditionalGeneration.from_pretrained(model_name)
15
+ tokenizer = TapexTokenizer.from_pretrained(model_name)
16
+
17
+ queries = [query]
18
+
19
+ encoding = tokenizer(table=table, query=queries, padding=True, return_tensors="pt", truncation=True)
20
+ outputs = model.generate(**encoding)
21
+ ans = tokenizer.batch_decode(outputs, skip_special_tokens=True)
22
+
23
+ query_result = {
24
+ "query": query,
25
+ "answer": ans[0]
26
+ }
27
+
28
+ b = datetime.datetime.now()
29
+ print(b - a)
30
+
31
+ return query_result, table
32
 
33
  def main():
34
+ description = "Querying a CSV using the TAPEX model. You can ask a question about tabular data, and the TAPEX model will produce the result. The finetuned TAPEX model runs on data with a maximum of 5000 rows and 20 columns. A sample dataset of Shopify store sales is provided."
 
 
35
 
36
+ article = "<p style='text-align: center'><a href='https://unscrambl.com/' target='_blank'>Unscrambl</a> | <a href='https://huggingface.co/microsoft/tapex-large-finetuned-wtq' target='_blank'>TAPEX Model</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=abaranovskij_tablequery' alt='visitor badge'></center>"
37
 
38
  iface = gr.Interface(fn=execute_query,
39
  inputs=[gr.Textbox(label="Search query"),
 
48
  # iface.launch(server_name="0.0.0.0", server_port=7000)
49
  iface.launch(enable_queue=True)
50
 
 
51
  if __name__ == "__main__":
52
+ main()