from typing import List, Dict

import pandas as pd
import recognizers_suite
from recognizers_suite import Culture
import re
import unicodedata

from fuzzywuzzy import fuzz
from utils.sql.extraction_from_sql import *
from utils.sql.all_keywords import ALL_KEY_WORDS

culture = Culture.English


def str_normalize(user_input, recognition_types=None):
    """A string normalizer which recognizes and normalizes values based on recognizers_suite."""
    user_input = str(user_input)
    user_input = user_input.replace("\\n", "; ")

    def replace_by_idx_pairs(orig_str, strs_to_replace, idx_pairs):
        assert len(strs_to_replace) == len(idx_pairs)
        last_end = 0
        to_concat = []
        for idx_pair, str_to_replace in zip(idx_pairs, strs_to_replace):
            to_concat.append(orig_str[last_end:idx_pair[0]])
            to_concat.append(str_to_replace)
            last_end = idx_pair[1]
        to_concat.append(orig_str[last_end:])
        return ''.join(to_concat)

    if recognition_types is None:
        recognition_types = ["datetime",
                             "number",
                             # "ordinal",
                             # "percentage",
                             # "age",
                             # "currency",
                             # "dimension",
                             # "temperature",
                             ]

    for recognition_type in recognition_types:
        if re.match(r"\d+/\d+", user_input):
            # Avoid interpreting strings like "1991/92" as arithmetic or dates.
            continue

        recognized_list = getattr(
            recognizers_suite, "recognize_{}".format(recognition_type)
        )(user_input, culture)  # May match multiple parts of the input.
        strs_to_replace = []
        idx_pairs = []
        for recognized in recognized_list:
            if not recognition_type == 'datetime':
                recognized_value = recognized.resolution['value']
                if str(recognized_value).startswith("P"):
                    # Skip period/duration values (ISO 8601 durations start with "P").
                    continue
                else:
                    strs_to_replace.append(recognized_value)
                    idx_pairs.append((recognized.start, recognized.end + 1))
            else:
                if recognized.resolution:  # In some cases, this can be None.
                    if len(recognized.resolution['values']) == 1:
                        # We use the timex form as the normalization.
                        strs_to_replace.append(recognized.resolution['values'][0]['timex'])
                        idx_pairs.append((recognized.start, recognized.end + 1))

        if len(strs_to_replace) > 0:
            user_input = replace_by_idx_pairs(user_input, strs_to_replace, idx_pairs)

    if re.match(r"(.*)-(.*)-(.*) 00:00:00", user_input):
        # '2008-04-13 00:00:00' -> '2008-04-13'
        user_input = user_input[:-len("00:00:00") - 1]
    return user_input


def prepare_df_for_neuraldb_from_table(table: Dict, add_row_id=True, normalize=True, lower_case=True):
    header, rows = table['header'], table['rows']
    if add_row_id and 'row_id' not in header:
        header = ["row_id"] + header
        rows = [["{}".format(i)] + row for i, row in enumerate(rows)]
    if normalize:
        df = convert_df_type(pd.DataFrame(data=rows, columns=header), lower_case=lower_case)
    else:
        df = pd.DataFrame(data=rows, columns=header)

    return df
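# Illustrative usage (a sketch, not from the original module): the expected
# input is a table dict with "header" and "rows" keys, as read above. The
# exact normalized cell values depend on the installed recognizers_suite
# version, so treat the outputs here as indicative.
#
#   table = {
#       "header": ["Name", "Date"],
#       "rows": [["Alice", "April 13, 2008"], ["Bob", "May 1, 2009"]],
#   }
#   df = prepare_df_for_neuraldb_from_table(table)
#   # -> columns ['row_id', 'name', 'date']; a leading row_id column is added,
#   #    headers/cells are lower-cased, and date-like values are normalized
#   #    toward ISO form, e.g. 'April 13, 2008' -> '2008-04-13'.
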
def convert_df_type(df: pd.DataFrame, lower_case=True):
    """
    A simple converter of dataframe data types from string to int/float/datetime.
    """

    def get_table_content_in_column(table):
        if isinstance(table, pd.DataFrame):
            header = table.columns.tolist()
            rows = table.values.tolist()
        else:
            # Standard table dict format
            header, rows = table['header'], table['rows']
        all_col_values = []
        for i in range(len(header)):
            one_col_values = []
            for _row in rows:
                one_col_values.append(_row[i])
            all_col_values.append(one_col_values)
        return all_col_values

    # Rename empty columns
    new_columns = []
    for idx, header in enumerate(df.columns):
        if header == '':
            new_columns.append('FilledColumnName')  # Fixme: give it a better name when all finished!
        else:
            new_columns.append(header)
    df.columns = new_columns

    # Rename duplicate columns
    new_columns = []
    for idx, header in enumerate(df.columns):
        if header in new_columns:
            new_header, suffix = header, 2
            while new_header in new_columns:
                new_header = header + '_' + str(suffix)
                suffix += 1
            new_columns.append(new_header)
        else:
            new_columns.append(header)
    df.columns = new_columns

    # Recognize null values like "-"
    null_tokens = ['', '-', '/']
    for header in df.columns:
        df[header] = df[header].map(lambda x: str(None) if x in null_tokens else x)

    # Convert the null values in digit columns to "NaN"
    all_col_values = get_table_content_in_column(df)
    for col_i, one_col_values in enumerate(all_col_values):
        all_number_flag = True
        for row_i, cell_value in enumerate(one_col_values):
            try:
                float(cell_value)
            except Exception:
                if cell_value not in [str(None), str(None).lower()]:
                    # Neither a number nor "None"/"none", so the column is not all numeric.
                    all_number_flag = False
        if all_number_flag:
            _header = df.columns[col_i]
            df[_header] = df[_header].map(lambda x: "NaN" if x in [str(None), str(None).lower()] else x)

    # Normalize cell values.
    for header in df.columns:
        df[header] = df[header].map(lambda x: str_normalize(x))

    # Strip the mis-added "01-01 00:00:00"
    all_col_values = get_table_content_in_column(df)
    for col_i, one_col_values in enumerate(all_col_values):
        all_with_00_00_00 = True
        all_with_01_00_00_00 = True
        all_with_01_01_00_00_00 = True
        for row_i, cell_value in enumerate(one_col_values):
            if not str(cell_value).endswith(" 00:00:00"):
                all_with_00_00_00 = False
            if not str(cell_value).endswith("-01 00:00:00"):
                all_with_01_00_00_00 = False
            if not str(cell_value).endswith("-01-01 00:00:00"):
                all_with_01_01_00_00_00 = False
        if all_with_01_01_00_00_00:
            _header = df.columns[col_i]
            df[_header] = df[_header].map(lambda x: x[:-len("-01-01 00:00:00")])
            continue
        if all_with_01_00_00_00:
            _header = df.columns[col_i]
            df[_header] = df[_header].map(lambda x: x[:-len("-01 00:00:00")])
            continue
        if all_with_00_00_00:
            _header = df.columns[col_i]
            df[_header] = df[_header].map(lambda x: x[:-len(" 00:00:00")])
            continue

    # Lower-case headers and cell values
    if lower_case:
        new_columns = []
        for header in df.columns:
            lower_header = str(header).lower()
            if lower_header in new_columns:
                new_header, suffix = lower_header, 2
                while new_header in new_columns:
                    new_header = lower_header + '-' + str(suffix)
                    suffix += 1
                new_columns.append(new_header)
            else:
                new_columns.append(lower_header)
        df.columns = new_columns
        for header in df.columns:
            # df[header] = df[header].map(lambda x: str(x).lower())
            df[header] = df[header].map(lambda x: str(x).lower().strip())

    # Recognize column types
    for header in df.columns:
        float_able = False
        int_able = False
        datetime_able = False
        # Recognize int & float type
        try:
            df[header].astype("float")
            float_able = True
        except:
            pass
        if float_able:
            try:
                if all(df[header].astype("float") == df[header].astype(int)):
                    int_able = True
            except:
                pass
        if float_able:
            if int_able:
                df[header] = df[header].astype(int)
            else:
                df[header] = df[header].astype(float)
        # Recognize datetime type
        try:
            df[header].astype("datetime64")
            datetime_able = True
        except:
            pass
        if datetime_able:
            df[header] = df[header].astype("datetime64")

    return df
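# Illustrative behavior (a sketch; exact dtypes depend on the pandas and
# recognizers_suite versions installed):
#
#   df = pd.DataFrame({"score": ["1", "2", "-"],
#                      "when": ["April 13, 2008", "May 1, 2009", "June 2, 2010"]})
#   df = convert_df_type(df)
#   # "-" is treated as a null token, so "score" becomes a numeric column with
#   # NaN in the third row; a column whose values all parse as dates is cast
#   # to datetime64.
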
def normalize(x):
    """ Normalize string. """
    # Copied from the WikiTableQuestions dataset official evaluator.
    if x is None:
        return None
    # Remove diacritics
    x = ''.join(c for c in unicodedata.normalize('NFKD', x)
                if unicodedata.category(c) != 'Mn')
    # Normalize quotes and dashes
    x = re.sub(r"[‘’´`]", "'", x)
    x = re.sub(r"[“”]", "\"", x)
    x = re.sub(r"[‐‑‒–—−]", "-", x)
    while True:
        old_x = x
        # Remove citations
        x = re.sub(r"((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", x.strip())
        # Remove details in parenthesis
        x = re.sub(r"(?<!^)( \([^)]*\))*$", "", x.strip())
        # Remove outermost quotation mark
        x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
        if x == old_x:
            break
    # Remove final '.'
    if x and x[-1] == '.':
        x = x[:-1]
    # Collapse whitespaces and convert to lower case
    x = re.sub(r'\s+', ' ', x, flags=re.U).lower().strip()
    return x
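# Illustrative behavior of normalize() (following the WikiTableQuestions
# evaluator's conventions; treat the exact outputs as indicative):
#
#   normalize('“Victoria”')   # -> 'victoria'  (quotes normalized and stripped)
#   normalize('México [1]')   # -> 'mexico'    (diacritics and citation removed)
#   normalize('London (UK)')  # -> 'london'    (trailing parenthetical removed)
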
def post_process_sql(sql_str, df, table_title=None, process_program_with_fuzzy_match_on_db=True, verbose=False):
    """Post-process SQL: basic string fixes plus fuzzy matching of values against database contents."""

    def basic_fix(sql_str, all_headers, table_title=None):
        # Assumed minimal sketch: this step performs light string repairs on
        # the generated SQL. Here we only strip a "table_title." prefix from
        # references and trim surrounding whitespace; the exact repairs in the
        # original are an assumption.
        if table_title:
            sql_str = sql_str.replace("{}.".format(table_title), "")
        return sql_str.strip()

    def fuzzy_match_process(sql_str, df, verbose=False):
        """
        Post-process SQL by fuzzy matching SQL values with table contents.
        """

        def _get_matched_cells(value_str, df, fuzz_threshold=70):
            """
            Get table cells that fuzzily match the value token.
            """
            matched_cells = []
            for row_id, row in df.iterrows():
                for cell in row:
                    cell = str(cell)
                    fuzz_score = fuzz.ratio(value_str, cell)
                    if fuzz_score == 100:
                        # An exact match wins outright.
                        matched_cells = [(cell, fuzz_score)]
                        return matched_cells
                    if fuzz_score >= fuzz_threshold:
                        matched_cells.append((cell, fuzz_score))

            matched_cells = sorted(matched_cells, key=lambda x: x[1], reverse=True)
            return matched_cells

        def _check_valid_fuzzy_match(value_str, matched_cell):
            """
            Check if the fuzzy match is valid, currently considering:
            1. Numbers/dates should not be disturbed, but adding or deleting whole numbers is valid.
            """
            number_pattern = r"[+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?"
            numbers_in_value = re.findall(number_pattern, value_str)
            numbers_in_matched_cell = re.findall(number_pattern, matched_cell)
            try:
                numbers_in_value = [float(num.replace(',', '')) for num in numbers_in_value]
            except:
                print(f"Can't convert number string {numbers_in_value} into float in _check_valid_fuzzy_match().")
            try:
                numbers_in_matched_cell = [float(num.replace(',', '')) for num in numbers_in_matched_cell]
            except:
                print(f"Can't convert number string {numbers_in_matched_cell} into float in _check_valid_fuzzy_match().")
            numbers_in_value = set(numbers_in_value)
            numbers_in_matched_cell = set(numbers_in_matched_cell)

            if numbers_in_value.issubset(numbers_in_matched_cell) or numbers_in_matched_cell.issubset(numbers_in_value):
                return True
            else:
                return False

        # Drop trailing '\n```', a pattern that may appear in Codex SQL generation
        sql_str = sql_str.rstrip('```').rstrip('\n')

        # Replace QA modules with placeholders
        qa_pattern = r"QA\(.+?;.*?`.+?`.*?\)"
        qas = re.findall(qa_pattern, sql_str)
        for idx, qa in enumerate(qas):
            sql_str = sql_str.replace(qa, f"placeholder{idx}")

        # Parse and replace SQL values with table contents
        sql_tokens = tokenize(sql_str)
        sql_template_tokens = extract_partial_template_from_sql(sql_str)
        # Fix 'between' keyword bug in parsing templates
        fixed_sql_template_tokens = []
        sql_tok_bias = 0
        for idx, sql_templ_tok in enumerate(sql_template_tokens):
            sql_tok = sql_tokens[idx + sql_tok_bias]
            if sql_tok == 'between' and sql_templ_tok == '[WHERE_OP]':
                # 'between' consumes two values joined by 'and'
                fixed_sql_template_tokens.extend(['[WHERE_OP]', '[VALUE]', 'and'])
                sql_tok_bias += 2  # pass '[VALUE]', 'and'
            else:
                fixed_sql_template_tokens.append(sql_templ_tok)
        sql_template_tokens = fixed_sql_template_tokens
        for idx, tok in enumerate(sql_tokens):
            if tok in ALL_KEY_WORDS:
                sql_tokens[idx] = tok.upper()

        if verbose:
            print(sql_tokens)
            print(sql_template_tokens)
        assert len(sql_tokens) == len(sql_template_tokens)

        value_indices = [idx for idx in range(len(sql_template_tokens)) if sql_template_tokens[idx] == '[VALUE]']
        for value_idx in value_indices:
            # Skip the value if the where-condition column is a QA module
            if value_idx >= 2 and sql_tokens[value_idx - 2].startswith('placeholder'):
                continue
            value_str = sql_tokens[value_idx]
            # Drop surrounding "" for fuzzy match
            is_string = False
            if value_str[0] == "\"" and value_str[-1] == "\"":
                value_str = value_str[1:-1]
                is_string = True
            # If it is already a fuzzy (LIKE) pattern, skip
            if value_str[0] == '%' or value_str[-1] == '%':
                continue
            value_str = value_str.lower()
            # Fuzzy match
            matched_cells = _get_matched_cells(value_str, df)

            if verbose:
                print(matched_cells)

            new_value_str = value_str
            if matched_cells:
                # new_value_str = matched_cells[0][0]
                for matched_cell, fuzz_score in matched_cells:
                    if _check_valid_fuzzy_match(value_str, matched_cell):
                        new_value_str = matched_cell
                        if verbose and new_value_str != value_str:
                            print("\tfuzzy match replacing!", value_str, '->', matched_cell, f'fuzz_score:{fuzz_score}')
                        break
            if is_string:
                new_value_str = f"\"{new_value_str}\""
            sql_tokens[value_idx] = new_value_str

        # Compose the new SQL string.
        # Clean column names in the SQL, since columns may have been tokenized
        # in the postprocessing, e.g., (ppp) -> ( ppp )
        new_sql_str = ' '.join(sql_tokens)
        sql_columns = re.findall(r'`\s(.*?)\s`', new_sql_str)
        for sql_col in sql_columns:
            matched_columns = []
            for col in df.columns:
                score = fuzz.ratio(sql_col.lower(), col)
                if score == 100:
                    matched_columns = [(col, score)]
                    break
                if score >= 80:
                    matched_columns.append((col, score))
            matched_columns = sorted(matched_columns, key=lambda x: x[1], reverse=True)
            if matched_columns:
                matched_col = matched_columns[0][0]
                new_sql_str = new_sql_str.replace(f"` {sql_col} `", f"`{matched_col}`")
            else:
                new_sql_str = new_sql_str.replace(f"` {sql_col} `", f"`{sql_col}`")

        # Restore QA modules
        for idx, qa in enumerate(qas):
            new_sql_str = new_sql_str.replace(f"placeholder{idx}", qa)

        # Fix '<>' when composing the new SQL
        new_sql_str = new_sql_str.replace('< >', '<>')

        return new_sql_str

    sql_str = basic_fix(sql_str, list(df.columns), table_title)

    if process_program_with_fuzzy_match_on_db:
        try:
            sql_str = fuzzy_match_process(sql_str, df, verbose)
        except:
            pass

    return sql_str
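
if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the table content below is made
    # up, and the fuzzy-match step silently no-ops via the try/except above
    # if the SQL parsing helpers from utils.sql cannot handle the query).
    table = {
        "header": ["Player", "Points"],
        "rows": [["Michael Jordan", "32292"], ["Karl Malone", "36928"]],
    }
    df = prepare_df_for_neuraldb_from_table(table)
    print(df.dtypes)

    # A misspelled value ("jordon") should be fuzzily replaced by the closest
    # valid table cell when the match does not disturb any numbers.
    sql = 'SELECT `player` FROM w WHERE `player` = "michael jordon"'
    print(post_process_sql(sql, df, table_title="w"))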