Spaces:
Runtime error
Runtime error
import re | |
import json | |
import records | |
from typing import List, Dict | |
from sqlalchemy.exc import SQLAlchemyError | |
from utils.sql.all_keywords import ALL_KEY_WORDS | |
class WTQDBEngine:
    """Wrapper around a SQLite database file opened through ``records``.

    ``delete_rows`` operates on a table named ``w`` in that database.
    """

    def __init__(self, fdb):
        """Open ``fdb`` as a SQLite database and keep a single connection to it.

        Args:
            fdb: Filesystem path of the SQLite database file.
        """
        database_url = "sqlite:///{}".format(fdb)
        self.db = records.Database(database_url)
        self.conn = self.db.get_connection()

    def execute_wtq_query(self, sql_query: str):
        """Run ``sql_query`` and return all cells of all rows as one flat list.

        Cells are ordered row by row; within a row they follow the order of the
        record's ``values()``.
        """
        fetched = self.conn.query(sql_query).all()
        return [cell for record in fetched for cell in record.values()]

    def delete_rows(self, row_indices: List[int]):
        """Delete every row of table ``w`` whose ``id`` appears in ``row_indices``.

        All statements are built first, then issued one per id in input order.
        """
        statements = ["delete from w where id == {}".format(row_id) for row_id in row_indices]
        for statement in statements:
            self.conn.query(statement)
def process_table_structure(_wtq_table_content: Dict, _add_all_column: bool = False):
    """Convert a WTQ table record into a header / rows / types dictionary.

    The first two entries of ``headers``, ``types`` and ``contents`` (the id and
    agg columns) are dropped.

    Args:
        _wtq_table_content: Table record with ``"headers"``, ``"types"``,
            ``"contents"`` and ``"is_list"`` entries. Each element of
            ``"contents"`` is a list of column variants; a variant is a dict
            read through its ``"col"`` (alias such as ``c3`` or ``c3_xxx``),
            ``"data"`` and (with ``_add_all_column``) ``"type"`` keys.
        _add_all_column: When False, only the first variant of every column is
            used and the original headers/types are returned. When True, every
            variant whose alias does not contain ``_number`` becomes its own
            column, named after its base header plus the alias suffix with
            underscores turned into spaces.

    Returns:
        Dict with ``"header"``, ``"rows"`` (row-major lists of cell strings),
        ``"types"`` and ``"alias"`` (the keys of ``"is_list"``).
    """

    def normalize(values):
        # Stringify each cell, flatten newlines to spaces, and lower-case it.
        return [str(value).replace("\n", " ").lower() for value in values]

    base_headers = [name.replace("\n", " ").lower() for name in _wtq_table_content["headers"][2:]]
    alias_to_header = {"c{}".format(pos): name for pos, name in enumerate(base_headers, start=1)}
    base_types = _wtq_table_content["types"][2:]

    columns = []
    expanded_headers = []
    expanded_types = []
    for variants in _wtq_table_content["contents"][2:]:
        if not _add_all_column:
            # Only the first variant of each column is used.
            columns.append(normalize(variants[0]["data"]))
            continue
        for variant in variants:
            alias = variant["col"]
            if "_number" in alias:
                # The numbered variant is not added as a column.
                continue
            columns.append(normalize(variant["data"]))
            base, separator, suffix = alias.partition("_")
            if separator:
                expanded_headers.append(alias_to_header[base] + " " + suffix.replace("_", " "))
            else:
                expanded_headers.append(alias_to_header[alias])
            expanded_types.append("text" if variant["type"] == "TEXT" else "number")

    # Transpose column-major cell lists into row-major rows.
    rows = [list(row) for row in zip(*columns)]
    if _add_all_column:
        header, types = expanded_headers, expanded_types
    else:
        header, types = base_headers, base_types
    return {
        "header": header,
        "rows": rows,
        "types": types,
        "alias": list(_wtq_table_content["is_list"].keys()),
    }
def retrieve_wtq_query_answer(_engine, _table_content, _sql_struct: List):
    """Flatten a structured WTQ SQL query, execute it, and normalize the answers.

    Args:
        _engine: Object exposing ``execute_wtq_query(sql: str)``. A
            ``SQLAlchemyError`` raised by it is treated as "no answer".
        _table_content: Dict whose ``"header"`` list gives the display name of
            column token ``c1`` at index 0, ``c2`` at index 1, and so on.
        _sql_struct: Sequence of token entries; only element ``[1]`` of each
            entry (the token text) is used, e.g. ``["Keyword", "select", []]``.

    Returns:
        Tuple ``(encode_sql, answers, exec_sql)``:
        ``encode_sql`` is the query with column tokens replaced by back-quoted
        header names; ``answers`` holds the non-None results as strings with
        newlines replaced by spaces (emptied if any answer is exactly
        ``"none"``); ``exec_sql`` is the query that was executed.
    """
    # Headers are expected to exclude the id / agg columns (c1 -> headers[0]).
    headers = _table_content["header"]
    # A column token: "c<N>", optionally followed by "_<suffix>" (e.g. c4_list).
    column_token = re.compile(r"c\d+(_.+)?")

    def flatten_sql(_ex_sql_struct: List):
        """Return ``(exec_sql, encode_sql)`` strings for a token sequence."""
        _encode_sql = []
        _execute_sql = []
        for _ex_tuple in _ex_sql_struct:
            keyword = str(_ex_tuple[1])
            # Upper-case SQL keywords.
            if keyword in ALL_KEY_WORDS:
                keyword = keyword.upper()
            is_column = column_token.fullmatch(keyword) is not None
            if is_column:
                # Only the "c<N>" part selects the header; wrap it in backquotes
                # so the header is a single quoted identifier.
                index_key = int(keyword.split("_")[0][1:]) - 1
                _encode_sql.append("`{}`".format(headers[index_key]))
            else:
                _encode_sql.append(keyword)
            # In the executed query, replace c4_list / c4_address with the
            # original column c4. Only column tokens are rewritten, so literals
            # such as 'my_list' are neither altered nor cause an IndexError.
            if is_column and ("_address" in keyword or "_list" in keyword):
                keyword = keyword.split("_")[0]
            _execute_sql.append(keyword)
        return " ".join(_execute_sql), " ".join(_encode_sql)

    _exec_sql_str, _encode_sql_str = flatten_sql(_sql_struct)
    try:
        _sql_answers = _engine.execute_wtq_query(_exec_sql_str)
    except SQLAlchemyError:
        # Best effort: a query that fails to execute yields no answers.
        _sql_answers = []
    _norm_sql_answers = [str(answer).replace("\n", " ") for answer in _sql_answers if answer is not None]
    # A literal "none" answer invalidates the whole result.
    if "none" in _norm_sql_answers:
        _norm_sql_answers = []
    return _encode_sql_str, _norm_sql_answers, _exec_sql_str
def _load_table_w_page(table_path, page_title_path=None) -> dict:
    """Load a WikiTableQuestions table and attach its page title.

    Attention: ``table_path`` must be the ``.tsv`` path of the table.

    The table is loaded with ``utils.utils._load_table``. Per the original
    contract, the result is a dict in a format like
    ``{"header": [header1, header2, ...], "rows": [[row11, row12, ...], [row21, ...], ..., [..., rownm]]}``.
    A ``"page_title"`` key is added to it.

    Args:
        table_path: Path to the table's ``.tsv`` file.
        page_title_path: Path to the JSON file holding the page metadata. When
            falsy, it is derived from ``table_path`` by replacing every
            ``"csv"`` with ``"page"`` and ``".tsv"`` with ``".json"``
            (e.g. ``csv/200-csv/1.tsv`` -> ``page/200-page/1.json``).

    Returns:
        The loaded table dict with ``"page_title"`` set to the JSON ``"title"``.

    Raises:
        OSError: If the page JSON file cannot be opened.
        KeyError: If the page JSON has no ``"title"`` entry.
    """
    # Imported lazily; presumably to avoid an import cycle with utils.utils — verify.
    from utils.utils import _load_table

    table_item = _load_table(table_path)
    if not page_title_path:
        page_title_path = table_path.replace("csv", "page").replace(".tsv", ".json")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(page_title_path, "r", encoding="utf-8") as f:
        page_title = json.load(f)["title"]
    table_item["page_title"] = page_title
    return table_item