"""Visualizer for TAPAS Implementation heavily based on `EncodingVisualizer` from `tokenizers.tools`. """ import os from typing import Any, List, Dict from collections import defaultdict import pandas as pd from transformers import TapasTokenizer dirname = os.path.dirname(__file__) css_filename = os.path.join(dirname, "tapas-styles.css") with open(css_filename) as f: css = f.read() def HTMLBody(table_html: str, css_styles: str = css) -> str: """ Generates the full html with css from a list of html spans Args: table_html (str): The html string of the table css_styles (str): CSS styling to be embedded inline Returns: :obj:`str`: An HTML string with style markup """ return f"""
{table_html}
""" class TapasVisualizer: def __init__(self, tokenizer: TapasTokenizer) -> None: self.tokenizer = tokenizer def normalize_token_str(self, token_str: str) -> str: # Normalize subword tokens to org subword str return token_str.replace("##", "") def style_span(self, span_text: str, css_classes: List[str]) -> str: css = f'''class="{' '.join(css_classes)}"''' return f"{span_text}" def text_to_html(self, org_text: str, tokens: List[str]) -> str: """Create html based on the original text and its tokens. Note: The tokens need to be in same order as in the original text Args: org_text (str): Original string before tokenization tokens (List[str]): The tokens of org_text Returns: str: html with styling for the tokens """ if len(tokens) == 0: print(f"Empty tokens for: {org_text}") return "" cur_token_id = 0 cur_token = self.normalize_token_str(tokens[cur_token_id]) # Loop through each character next_start = 0 last_end = 0 spans = [] while next_start < len(org_text): candidate = org_text[next_start : next_start + len(cur_token)] # The tokenizer performs lowercasing; so check against lowercase if candidate.lower() == cur_token: if last_end != next_start: # There was token-less text (probably whitespace) # in the middle spans.append( self.style_span(org_text[last_end:next_start], ["non-token"]) ) odd_or_even = "even-token" if cur_token_id % 2 == 0 else "odd-token" spans.append(self.style_span(candidate, ["token", odd_or_even])) next_start += len(cur_token) last_end = next_start cur_token_id += 1 if cur_token_id >= len(tokens): break cur_token = self.normalize_token_str(tokens[cur_token_id]) else: next_start += 1 if last_end != len(org_text): spans.append(self.style_span(org_text[last_end:next_start], ["non-token"])) return spans def cells_to_html( self, cell_vals: List[List[str]], cell_tokens: Dict, row_id_start: int = 0, cell_element: str = "td", cumulative_cnt: int = 0, table_html: str = "", ) -> str: for row_id, row in enumerate(cell_vals, start=row_id_start): row_html = "" row_token_cnt = 0 for col_id, cell in enumerate(row, start=1): cur_cell_tokens = cell_tokens[(row_id, col_id)] span_htmls = self.text_to_html(cell, cur_cell_tokens) cell_html = "".join(span_htmls) row_html += f"<{cell_element}>{cell_html}" row_token_cnt += len(cur_cell_tokens) cumulative_cnt += row_token_cnt cnt_html = ( f'' f'{self.style_span(str(cumulative_cnt), ["non-token", "count"])}' "" f'' f'{self.style_span(f"<+{row_token_cnt}", ["non-token", "count"])}' "" ) row_html = cnt_html + row_html table_html += f"{row_html}" return table_html, cumulative_cnt def __call__(self, table: pd.DataFrame) -> Any: tokenized = self.tokenizer(table) cell_tokens = defaultdict(list) for id_ind, input_id in enumerate(tokenized["input_ids"]): input_id = int(input_id) # 'prev_label', 'column_rank', 'inv_column_rank', 'numeric_relation' # not required segment_id, col_id, row_id, *_ = tokenized["token_type_ids"][id_ind] token_text = self.tokenizer._convert_id_to_token(input_id) if int(segment_id) == 1: cell_tokens[(row_id, col_id)].append(token_text) table_html, cumulative_cnt = self.cells_to_html( cell_vals=[table.columns], cell_tokens=cell_tokens, row_id_start=0, cell_element="th", cumulative_cnt=0, table_html="", ) table_html, cumulative_cnt = self.cells_to_html( cell_vals=table.values, cell_tokens=cell_tokens, row_id_start=1, cell_element="td", cumulative_cnt=cumulative_cnt, table_html=table_html, ) top_label = self.style_span("#Tokens", ["count"]) top_label_cnt = self.style_span(f"(Total: {cumulative_cnt})", ["count"]) table_html = ( '' f'{top_label}' f'{top_label_cnt}' "" f"{table_html}" ) table_html = f"{table_html}
" return HTMLBody(table_html)