Binder / utils /normalizer.py
Timothyxxx
Init
f6f97d8
from typing import List, Dict
import pandas as pd
import recognizers_suite
from recognizers_suite import Culture
import re
import unicodedata
from fuzzywuzzy import fuzz
from utils.sql.extraction_from_sql import *
from utils.sql.all_keywords import ALL_KEY_WORDS
culture = Culture.English
def str_normalize(user_input, recognition_types=None):
"""A string normalizer which recognize and normalize value based on recognizers_suite"""
user_input = str(user_input)
user_input = user_input.replace("\\n", "; ")
def replace_by_idx_pairs(orig_str, strs_to_replace, idx_pairs):
assert len(strs_to_replace) == len(idx_pairs)
last_end = 0
to_concat = []
for idx_pair, str_to_replace in zip(idx_pairs, strs_to_replace):
to_concat.append(orig_str[last_end:idx_pair[0]])
to_concat.append(str_to_replace)
last_end = idx_pair[1]
to_concat.append(orig_str[last_end:])
return ''.join(to_concat)
if recognition_types is None:
recognition_types = ["datetime",
"number",
# "ordinal",
# "percentage",
# "age",
# "currency",
# "dimension",
# "temperature",
]
for recognition_type in recognition_types:
if re.match("\d+/\d+", user_input):
# avoid calculating str as 1991/92
continue
recognized_list = getattr(recognizers_suite, "recognize_{}".format(recognition_type))(user_input,
culture) # may match multiple parts
strs_to_replace = []
idx_pairs = []
for recognized in recognized_list:
if not recognition_type == 'datetime':
recognized_value = recognized.resolution['value']
if str(recognized_value).startswith("P"):
# if the datetime is a period:
continue
else:
strs_to_replace.append(recognized_value)
idx_pairs.append((recognized.start, recognized.end + 1))
else:
if recognized.resolution: # in some cases, this variable could be none.
if len(recognized.resolution['values']) == 1:
strs_to_replace.append(
recognized.resolution['values'][0]['timex']) # We use timex as normalization
idx_pairs.append((recognized.start, recognized.end + 1))
if len(strs_to_replace) > 0:
user_input = replace_by_idx_pairs(user_input, strs_to_replace, idx_pairs)
if re.match("(.*)-(.*)-(.*) 00:00:00", user_input):
user_input = user_input[:-len("00:00:00") - 1]
# '2008-04-13 00:00:00' -> '2008-04-13'
return user_input
def prepare_df_for_neuraldb_from_table(table: Dict, add_row_id=True, normalize=True, lower_case=True):
header, rows = table['header'], table['rows']
if add_row_id and 'row_id' not in header:
header = ["row_id"] + header
rows = [["{}".format(i)] + row for i, row in enumerate(rows)]
if normalize:
df = convert_df_type(pd.DataFrame(data=rows, columns=header), lower_case=lower_case)
else:
df = pd.DataFrame(data=rows, columns=header)
return df
def convert_df_type(df: pd.DataFrame, lower_case=True):
"""
A simple converter of dataframe data type from string to int/float/datetime.
"""
def get_table_content_in_column(table):
if isinstance(table, pd.DataFrame):
header = table.columns.tolist()
rows = table.values.tolist()
else:
# Standard table dict format
header, rows = table['header'], table['rows']
all_col_values = []
for i in range(len(header)):
one_col_values = []
for _row in rows:
one_col_values.append(_row[i])
all_col_values.append(one_col_values)
return all_col_values
# Rename empty columns
new_columns = []
for idx, header in enumerate(df.columns):
if header == '':
new_columns.append('FilledColumnName') # Fixme: give it a better name when all finished!
else:
new_columns.append(header)
df.columns = new_columns
# Rename duplicate columns
new_columns = []
for idx, header in enumerate(df.columns):
if header in new_columns:
new_header, suffix = header, 2
while new_header in new_columns:
new_header = header + '_' + str(suffix)
suffix += 1
new_columns.append(new_header)
else:
new_columns.append(header)
df.columns = new_columns
# Recognize null values like "-"
null_tokens = ['', '-', '/']
for header in df.columns:
df[header] = df[header].map(lambda x: str(None) if x in null_tokens else x)
# Convert the null values in digit column to "NaN"
all_col_values = get_table_content_in_column(df)
for col_i, one_col_values in enumerate(all_col_values):
all_number_flag = True
for row_i, cell_value in enumerate(one_col_values):
try:
float(cell_value)
except Exception as e:
if not cell_value in [str(None), str(None).lower()]:
# None or none
all_number_flag = False
if all_number_flag:
_header = df.columns[col_i]
df[_header] = df[_header].map(lambda x: "NaN" if x in [str(None), str(None).lower()] else x)
# Normalize cell values.
for header in df.columns:
df[header] = df[header].map(lambda x: str_normalize(x))
# Strip the mis-added "01-01 00:00:00"
all_col_values = get_table_content_in_column(df)
for col_i, one_col_values in enumerate(all_col_values):
all_with_00_00_00 = True
all_with_01_00_00_00 = True
all_with_01_01_00_00_00 = True
for row_i, cell_value in enumerate(one_col_values):
if not str(cell_value).endswith(" 00:00:00"):
all_with_00_00_00 = False
if not str(cell_value).endswith("-01 00:00:00"):
all_with_01_00_00_00 = False
if not str(cell_value).endswith("-01-01 00:00:00"):
all_with_01_01_00_00_00 = False
if all_with_01_01_00_00_00:
_header = df.columns[col_i]
df[_header] = df[_header].map(lambda x: x[:-len("-01-01 00:00:00")])
continue
if all_with_01_00_00_00:
_header = df.columns[col_i]
df[_header] = df[_header].map(lambda x: x[:-len("-01 00:00:00")])
continue
if all_with_00_00_00:
_header = df.columns[col_i]
df[_header] = df[_header].map(lambda x: x[:-len(" 00:00:00")])
continue
# Do header and cell value lower case
if lower_case:
new_columns = []
for header in df.columns:
lower_header = str(header).lower()
if lower_header in new_columns:
new_header, suffix = lower_header, 2
while new_header in new_columns:
new_header = lower_header + '-' + str(suffix)
suffix += 1
new_columns.append(new_header)
else:
new_columns.append(lower_header)
df.columns = new_columns
for header in df.columns:
# df[header] = df[header].map(lambda x: str(x).lower())
df[header] = df[header].map(lambda x: str(x).lower().strip())
# Recognize header type
for header in df.columns:
float_able = False
int_able = False
datetime_able = False
# Recognize int & float type
try:
df[header].astype("float")
float_able = True
except:
pass
if float_able:
try:
if all(df[header].astype("float") == df[header].astype(int)):
int_able = True
except:
pass
if float_able:
if int_able:
df[header] = df[header].astype(int)
else:
df[header] = df[header].astype(float)
# Recognize datetime type
try:
df[header].astype("datetime64")
datetime_able = True
except:
pass
if datetime_able:
df[header] = df[header].astype("datetime64")
return df
def normalize(x):
""" Normalize string. """
# Copied from WikiTableQuestions dataset official evaluator.
if x is None:
return None
# Remove diacritics
x = ''.join(c for c in unicodedata.normalize('NFKD', x)
if unicodedata.category(c) != 'Mn')
# Normalize quotes and dashes
x = re.sub("[‘’´`]", "'", x)
x = re.sub("[“”]", "\"", x)
x = re.sub("[‐‑‒–—−]", "-", x)
while True:
old_x = x
# Remove citations
x = re.sub("((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", x.strip())
# Remove details in parenthesis
x = re.sub("(?<!^)( \([^)]*\))*$", "", x.strip())
# Remove outermost quotation mark
x = re.sub('^"([^"]*)"$', r'\1', x.strip())
if x == old_x:
break
# Remove final '.'
if x and x[-1] == '.':
x = x[:-1]
# Collapse whitespaces and convert to lower case
x = re.sub('\s+', ' ', x, flags=re.U).lower().strip()
return x
def post_process_sql(sql_str, df, table_title=None, process_program_with_fuzzy_match_on_db=True, verbose=False):
"""Post process SQL: including basic fix and further fuzzy match on cell and SQL to process"""
def basic_fix(sql_str, all_headers, table_title=None):
def finditer(sub_str: str, mother_str: str):
result = []
start_index = 0
while True:
start_index = mother_str.find(sub_str, start_index, -1)
if start_index == -1:
break
end_idx = start_index + len(sub_str)
result.append((start_index, end_idx))
start_index = end_idx
return result
if table_title:
sql_str = sql_str.replace("FROM " + table_title, "FROM w")
sql_str = sql_str.replace("FROM " + table_title.lower(), "FROM w")
"""Case 1: Fix the `` missing. """
# Remove the null header.
while '' in all_headers:
all_headers.remove('')
# Remove the '\n' in header.
# This is because the WikiTQ won't actually show the str in two lines,
# they use '\n' to mean that, and display it in the same line when print.
sql_str = sql_str.replace("\\n", "\n")
sql_str = sql_str.replace("\n", "\\n")
# Add `` in SQL.
all_headers.sort(key=lambda x: len(x), reverse=True)
have_matched = [0 for i in range(len(sql_str))]
# match quotation
idx_s_single_quotation = [_ for _ in range(1, len(sql_str)) if
sql_str[_] in ["\'"] and sql_str[_ - 1] not in ["\'"]]
idx_s_double_quotation = [_ for _ in range(1, len(sql_str)) if
sql_str[_] in ["\""] and sql_str[_ - 1] not in ["\""]]
for idx_s in [idx_s_single_quotation, idx_s_double_quotation]:
if len(idx_s) % 2 == 0:
for idx in range(int(len(idx_s) / 2)):
start_idx = idx_s[idx * 2]
end_idx = idx_s[idx * 2 + 1]
have_matched[start_idx: end_idx] = [2 for _ in range(end_idx - start_idx)]
# match headers
for header in all_headers:
if (header in sql_str) and (header not in ALL_KEY_WORDS):
all_matched_of_this_header = finditer(header, sql_str)
for matched_of_this_header in all_matched_of_this_header:
start_idx, end_idx = matched_of_this_header
if all(have_matched[start_idx: end_idx]) == 0 and (not sql_str[start_idx - 1] == "`") and (
not sql_str[end_idx] == "`"):
have_matched[start_idx: end_idx] = [1 for _ in range(end_idx - start_idx)]
# a bit ugly, but anyway.
# re-compose sql from the matched idx.
start_have_matched = [0] + have_matched
end_have_matched = have_matched + [0]
start_idx_s = [idx - 1 for idx in range(1, len(start_have_matched)) if
start_have_matched[idx - 1] == 0 and start_have_matched[idx] == 1]
end_idx_s = [idx for idx in range(len(end_have_matched) - 1) if
end_have_matched[idx] == 1 and end_have_matched[idx + 1] == 0]
assert len(start_idx_s) == len(end_idx_s)
spans = []
current_idx = 0
for start_idx, end_idx in zip(start_idx_s, end_idx_s):
spans.append(sql_str[current_idx:start_idx])
spans.append(sql_str[start_idx:end_idx + 1])
current_idx = end_idx + 1
spans.append(sql_str[current_idx:])
sql_str = '`'.join(spans)
return sql_str
def fuzzy_match_process(sql_str, df, verbose=False):
"""
Post-process SQL by fuzzy matching value with table contents.
"""
def _get_matched_cells(value_str, df, fuzz_threshold=70):
"""
Get matched table cells with value token.
"""
matched_cells = []
for row_id, row in df.iterrows():
for cell in row:
cell = str(cell)
fuzz_score = fuzz.ratio(value_str, cell)
if fuzz_score == 100:
matched_cells = [(cell, fuzz_score)]
return matched_cells
if fuzz_score >= fuzz_threshold:
matched_cells.append((cell, fuzz_score))
matched_cells = sorted(matched_cells, key=lambda x: x[1], reverse=True)
return matched_cells
def _check_valid_fuzzy_match(value_str, matched_cell):
"""
Check if the fuzzy match is valid, now considering:
1. The number/date should not be disturbed, but adding new number or deleting number is valid.
"""
number_pattern = "[+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?"
numbers_in_value = re.findall(number_pattern, value_str)
numbers_in_matched_cell = re.findall(number_pattern, matched_cell)
try:
numbers_in_value = [float(num.replace(',', '')) for num in numbers_in_value]
except:
print(f"Can't convert number string {numbers_in_value} into float in _check_valid_fuzzy_match().")
try:
numbers_in_matched_cell = [float(num.replace(',', '')) for num in numbers_in_matched_cell]
except:
print(
f"Can't convert number string {numbers_in_matched_cell} into float in _check_valid_fuzzy_match().")
numbers_in_value = set(numbers_in_value)
numbers_in_matched_cell = set(numbers_in_matched_cell)
if numbers_in_value.issubset(numbers_in_matched_cell) or numbers_in_matched_cell.issubset(numbers_in_value):
return True
else:
return False
# Drop trailing '\n```', a pattern that may appear in Codex SQL generation
sql_str = sql_str.rstrip('```').rstrip('\n')
# Replace QA module with placeholder
qa_pattern = "QA\(.+?;.*?`.+?`.*?\)"
qas = re.findall(qa_pattern, sql_str)
for idx, qa in enumerate(qas):
sql_str = sql_str.replace(qa, f"placeholder{idx}")
# Parse and replace SQL value with table contents
sql_tokens = tokenize(sql_str)
sql_template_tokens = extract_partial_template_from_sql(sql_str)
# Fix 'between' keyword bug in parsing templates
fixed_sql_template_tokens = []
sql_tok_bias = 0
for idx, sql_templ_tok in enumerate(sql_template_tokens):
sql_tok = sql_tokens[idx + sql_tok_bias]
if sql_tok == 'between' and sql_templ_tok == '[WHERE_OP]':
fixed_sql_template_tokens.extend(['[WHERE_OP]', '[VALUE]', 'and'])
sql_tok_bias += 2 # pass '[VALUE]', 'and'
else:
fixed_sql_template_tokens.append(sql_templ_tok)
sql_template_tokens = fixed_sql_template_tokens
for idx, tok in enumerate(sql_tokens):
if tok in ALL_KEY_WORDS:
sql_tokens[idx] = tok.upper()
if verbose:
print(sql_tokens)
print(sql_template_tokens)
assert len(sql_tokens) == len(sql_template_tokens)
value_indices = [idx for idx in range(len(sql_template_tokens)) if sql_template_tokens[idx] == '[VALUE]']
for value_idx in value_indices:
# Skip the value if the where condition column is QA module
if value_idx >= 2 and sql_tokens[value_idx - 2].startswith('placeholder'):
continue
value_str = sql_tokens[value_idx]
# Drop \"\" for fuzzy match
is_string = False
if value_str[0] == "\"" and value_str[-1] == "\"":
value_str = value_str[1:-1]
is_string = True
# If already fuzzy match, skip
if value_str[0] == '%' or value_str[-1] == '%':
continue
value_str = value_str.lower()
# Fuzzy Match
matched_cells = _get_matched_cells(value_str, df)
if verbose:
print(matched_cells)
new_value_str = value_str
if matched_cells:
# new_value_str = matched_cells[0][0]
for matched_cell, fuzz_score in matched_cells:
if _check_valid_fuzzy_match(value_str, matched_cell):
new_value_str = matched_cell
if verbose and new_value_str != value_str:
print("\tfuzzy match replacing!", value_str, '->', matched_cell, f'fuzz_score:{fuzz_score}')
break
if is_string:
new_value_str = f"\"{new_value_str}\""
sql_tokens[value_idx] = new_value_str
# Compose new sql string
# Clean column name in SQL since columns may have been tokenized in the postprocessing, e.g., (ppp) -> ( ppp )
new_sql_str = ' '.join(sql_tokens)
sql_columns = re.findall('`\s(.*?)\s`', new_sql_str)
for sql_col in sql_columns:
matched_columns = []
for col in df.columns:
score = fuzz.ratio(sql_col.lower(), col)
if score == 100:
matched_columns = [(col, score)]
break
if score >= 80:
matched_columns.append((col, score))
matched_columns = sorted(matched_columns, key=lambda x: x[1], reverse=True)
if matched_columns:
matched_col = matched_columns[0][0]
new_sql_str = new_sql_str.replace(f"` {sql_col} `", f"`{matched_col}`")
else:
new_sql_str = new_sql_str.replace(f"` {sql_col} `", f"`{sql_col}`")
# Restore QA modules
for idx, qa in enumerate(qas):
new_sql_str = new_sql_str.replace(f"placeholder{idx}", qa)
# Fix '<>' when composing the new sql
new_sql_str = new_sql_str.replace('< >', '<>')
return new_sql_str
sql_str = basic_fix(sql_str, list(df.columns), table_title)
if process_program_with_fuzzy_match_on_db:
try:
sql_str = fuzzy_match_process(sql_str, df, verbose)
except:
pass
return sql_str