# Binder / utils / wtq / utils.py  (Timothyxxx · "Init" · f6f97d8)
import re
import json
import records
from typing import List, Dict
from sqlalchemy.exc import SQLAlchemyError
from utils.sql.all_keywords import ALL_KEY_WORDS
class WTQDBEngine:
    """Thin wrapper around a SQLite database accessed through ``records``.

    A single connection is opened at construction time and reused by every
    query. ``delete_rows`` operates on a table named ``w``.
    """

    def __init__(self, fdb):
        """Open the SQLite database at ``fdb`` and keep one connection to it.

        Args:
            fdb: Filesystem path of the SQLite database file.
        """
        self.db = records.Database("sqlite:///{}".format(fdb))
        self.conn = self.db.get_connection()

    def execute_wtq_query(self, sql_query: str):
        """Run ``sql_query`` and return every result cell as one flat list.

        Rows are concatenated in result order and each row contributes its
        cells in column order, so a result ``[[a, b], [c, d]]`` is returned as
        ``[a, b, c, d]``.

        Args:
            sql_query: SQL text, executed verbatim.

        Returns:
            Flat list of cell values.
        """
        out = self.conn.query(sql_query)
        return [value for record in out.all() for value in record.values()]

    def delete_rows(self, row_indices: List[int]):
        """Delete the rows of table ``w`` whose ``id`` is in ``row_indices``.

        Args:
            row_indices: ``id`` values of the rows to delete.

        Raises:
            ValueError, TypeError: if an entry cannot be converted to ``int``.
                All entries are validated before any row is deleted.
        """
        # Validate every id up front so a bad entry cannot leave the table
        # half-modified. int() also normalises integer-like values (e.g. numpy
        # integers) to a plain int the DB driver can bind.
        row_ids = [int(row) for row in row_indices]
        for row_id in row_ids:
            # Bind the id as a parameter rather than formatting it into the SQL.
            self.conn.query("delete from w where id == :row_id", row_id=row_id)
def process_table_structure(_wtq_table_content: Dict, _add_all_column: bool = False):
    """Convert a WTQ table JSON into a row-major table dict.

    The first two entries of ``headers``, ``types`` and ``contents`` are
    skipped (the id and agg columns). Header and cell text has newlines
    replaced by spaces and is lower-cased; cells are converted with ``str``.

    Args:
        _wtq_table_content: Parsed table with keys ``headers``, ``types``,
            ``contents`` and ``is_list``. Each entry of ``contents`` is a list
            of column variants, each a dict with keys ``col`` (alias such as
            ``c1`` or ``c1_first``), ``type`` and ``data`` (the cells).
        _add_all_column: If False, keep only the first variant of each column
            and return the original headers and types. If True, keep every
            variant except those whose alias contains ``_number``; a variant
            ``cN_foo_bar`` is named "<header of cN> foo bar", and its type is
            "text" when ``type == "TEXT"`` else "number".

    Returns:
        Dict with ``header``, ``rows`` (list of row lists; ``zip`` truncates
        to the shortest column), ``types`` and ``alias`` (keys of
        ``is_list``).
    """
    def _normalize_cell(value):
        return str(value).replace("\n", " ").lower()

    # remove id and agg column
    headers = [h.replace("\n", " ").lower() for h in _wtq_table_content["headers"][2:]]
    # Column aliases are 1-based: c1 -> headers[0], c2 -> headers[1], ...
    header_map = {"c" + str(i): h for i, h in enumerate(headers, start=1)}
    header_types = _wtq_table_content["types"][2:]

    all_headers = []
    all_header_types = []
    vertical_content = []
    for column_variants in _wtq_table_content["contents"][2:]:
        if not _add_all_column:
            # Only take the first variant of each column.
            vertical_content.append([_normalize_cell(c) for c in column_variants[0]["data"]])
            continue
        for variant in column_variants:
            column_alias = variant["col"]
            # Do not add the numbered column.
            if "_number" in column_alias:
                continue
            vertical_content.append([_normalize_cell(c) for c in variant["data"]])
            # "c1_first_name" -> base "c1", suffix "first_name".
            base_alias, underscore, suffix = column_alias.partition("_")
            if underscore:
                column_name = header_map[base_alias] + " " + suffix.replace("_", " ")
            else:
                column_name = header_map[column_alias]
            all_headers.append(column_name)
            all_header_types.append("text" if variant["type"] == "TEXT" else "number")

    # Transpose the column-major cell lists into rows.
    row_content = [list(row) for row in zip(*vertical_content)]

    if _add_all_column:
        ret_header, ret_types = all_headers, all_header_types
    else:
        ret_header, ret_types = headers, header_types
    return {
        "header": ret_header,
        "rows": row_content,
        "types": ret_types,
        "alias": list(_wtq_table_content["is_list"].keys()),
    }
def retrieve_wtq_query_answer(_engine, _table_content, _sql_struct: List):
    """Build the SQL for a tokenised WTQ query, execute it and collect answers.

    Args:
        _engine: Object exposing ``execute_wtq_query(sql) -> list`` (such as
            ``WTQDBEngine``).
        _table_content: Table dict whose ``header`` list maps column alias
            ``cN`` to ``header[N - 1]``.
        _sql_struct: Sequence of tokens; only element ``[1]`` of each token
            (its text) is used, e.g. ``["Column", "c4", []]``.

    Returns:
        Tuple ``(encode_sql, answers, exec_sql)``:
        ``encode_sql`` has column tokens replaced by back-quoted header names;
        ``exec_sql`` is the executed string, with ``cN_list`` / ``cN_address``
        collapsed to ``cN``; ``answers`` is the flattened result, stringified
        with newlines replaced by spaces and ``None`` cells dropped. It is
        empty if execution raised ``SQLAlchemyError`` or if any answer equals
        the string ``"none"``.
    """
    # do not append id / agg
    headers = _table_content["header"]
    # A column token: "cN" optionally followed by "_suffix" (e.g. "c4_list").
    column_pattern = re.compile(r"c\d+(_.+)?")

    def flatten_sql(_ex_sql_struct: List):
        """Return ``(execute_sql, encode_sql)`` strings built from the tokens."""
        # token example: [ "Keyword", "select", [] ], [ "Column", "c4", [] ]
        _encode_sql = []
        _execute_sql = []
        for _ex_tuple in _ex_sql_struct:
            keyword = str(_ex_tuple[1])
            # Upper-case tokens listed in ALL_KEY_WORDS.
            if keyword in ALL_KEY_WORDS:
                keyword = keyword.upper()
            is_column = column_pattern.fullmatch(keyword) is not None
            if is_column:
                # cN / cN_xxx -> `<header of cN>`; the back-quotes keep headers
                # that contain spaces usable as identifiers.
                index_key = int(keyword.split("_")[0][1:]) - 1
                _encode_sql.append("`{}`".format(headers[index_key]))
            else:
                _encode_sql.append(keyword)
            # For execution, collapse cN_list / cN_address to the base column cN.
            # Only column tokens are rewritten: a literal such as 'my_list' is
            # kept verbatim instead of crashing (it contains no cN) or being
            # replaced by an unrelated cN that happens to appear inside it.
            if is_column and ("_address" in keyword or "_list" in keyword):
                keyword = keyword.split("_")[0]
            _execute_sql.append(keyword)
        return " ".join(_execute_sql), " ".join(_encode_sql)

    _exec_sql_str, _encode_sql_str = flatten_sql(_sql_struct)
    try:
        _sql_answers = _engine.execute_wtq_query(_exec_sql_str)
    except SQLAlchemyError:
        # Deliberate best-effort: a failing query simply yields no answers.
        _sql_answers = []
    _norm_sql_answers = [str(_).replace("\n", " ") for _ in _sql_answers if _ is not None]
    if "none" in _norm_sql_answers:
        _norm_sql_answers = []
    return _encode_sql_str, _norm_sql_answers, _exec_sql_str
def _load_table_w_page(table_path, page_title_path=None) -> dict:
    """Load a WikiTableQuestions table and attach its page title.

    Attention: ``table_path`` must be the ``.tsv`` path of the table. The
    table is read with ``utils.utils._load_table``, which yields a dict like
    ``{"header": [header1, header2, ...], "rows": [[row11, row12, ...], ...]}``.

    Args:
        table_path: Path of the table's ``.tsv`` file.
        page_title_path: Path of the page JSON, which must contain a
            ``"title"`` key. If falsy, it is derived from ``table_path`` by
            replacing every ``"csv"`` with ``"page"`` and ``".tsv"`` with
            ``".json"`` (``csv/204-csv/590.tsv`` -> ``page/204-page/590.json``).
            Every occurrence is replaced, including any ``"csv"`` in parent
            directory names.

    Returns:
        The table dict with an added ``"page_title"`` entry.
    """
    from utils.utils import _load_table
    table_item = _load_table(table_path)
    # Load page title
    if not page_title_path:
        page_title_path = table_path.replace("csv", "page").replace(".tsv", ".json")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(page_title_path, "r", encoding="utf-8") as f:
        page_title = json.load(f)["title"]
    table_item["page_title"] = page_title
    return table_item