|
from typing import Any, Dict, Optional, List |
|
import torch |
|
from transformers import GenerationMixin |
|
from transformers import AutoTokenizer |
|
import re |
|
import traceback |
|
|
|
|
|
class WebGenerationMixin(GenerationMixin):
    """GenerationMixin variant that, besides the standard per-step kwarg
    bookkeeping, threads two extra pieces of state through decoding:
    a ``web_attention_mask`` (structure-aware attention over the HTML) and
    the incrementally built ``html_tree`` carried on the model outputs.
    """

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Refresh ``model_kwargs`` after one decoding step.

        Mirrors the upstream transformers implementation (cache, masks,
        token_type_ids, cache_position) and additionally propagates
        ``web_attention_mask`` and ``html_tree`` for decoder-only models.
        """
        # Carry the KV cache forward via the transformers helper.
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        # Some architectures (e.g. RNN-style) expose a recurrent state.
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # Extend token_type_ids by repeating the last column for the new token.
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        if not is_encoder_decoder:
            # Lazily seed the web attention mask as a causal (lower-triangular)
            # mask over the prompt; resulting shape is (1, seq, seq) with the
            # same dtype as the regular attention mask.
            if 'web_attention_mask' not in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs['web_attention_mask'] = torch.tril(torch.ones((attention_mask.shape[-1], attention_mask.shape[-1]), dtype = attention_mask.dtype)).unsqueeze(0)

            # Grow the regular attention mask by one position for the new token.
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

            # Propagate the incrementally built HTML tree to the next step.
            # NOTE(review): unlike `state` above this access is unguarded, so
            # `outputs` is presumably guaranteed to expose `html_tree` in
            # decoder-only mode — confirm against the model's forward().
            model_kwargs['html_tree'] = outputs.html_tree

        else:
            # Encoder-decoder models: grow the decoder-side mask instead.
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )

        # Advance cache_position so it points at the single newly added token.
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        return model_kwargs

    def _reorder_cache(self, past_key_values, beam_idx):
        # Cache layout is model-specific, so beam search support requires the
        # concrete model class to override this.
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )
|
|
|
|
|
class TreeNode():
    """A node of an incrementally built HTML token tree.

    Each node carries exactly one kind of payload: an open tag (optionally
    paired with an end tag later), a self-closing tag, or a run of text
    tokens.  Alongside every payload the node records the half-open
    [start, end) token-index range that payload occupies in the generated
    sequence, so attention spans can be reconstructed from the tree.
    """

    def __init__(self, content: list, idx: int):
        # Payloads, all stored as lists of token strings.
        self.open_tag: List[str] = content
        self.end_tag: Optional[List[str]] = None
        self.self_closing_tag: Optional[List[str]] = None
        self.text = ""

        self.name: Optional[str] = None
        self.parent: Optional['TreeNode'] = None

        # Half-open token-index ranges; [-1, -1] means "not assigned yet".
        self.open_tag_range: Optional[List[int]] = None
        self.end_tag_range: Optional[List[int]] = None
        self.text_range = [-1, -1]
        self.self_closing_tag_range = [-1, -1]

        self.idx: int = idx
        self.children: List['TreeNode'] = []

    def partially_open(self):
        """Return True while the open tag has started ('<' seen) but has not
        yet been terminated by a '>' token."""
        tokens = self.open_tag
        if not tokens:
            return False
        started = any('<' in tok for tok in tokens)
        finished = any('>' in tok for tok in tokens)
        return started and not finished

    def add_child(self, child):
        """Attach *child* under this node and set its parent pointer."""
        assert child.parent is None, "Child already has a parent"
        assert child not in self.children, "Child is already in children list"
        child.parent = self
        self.children.append(child)

    def get_range(self):
        """Return the flat list of token indices covered by this node's own
        payload (text, self-closing tag, or open+end tag pair)."""
        if self.text:
            return list(range(*self.text_range))
        if self.self_closing_tag:
            return list(range(*self.self_closing_tag_range))
        indices = []
        if self.open_tag_range:
            indices.extend(range(*self.open_tag_range))
        if self.end_tag_range:
            indices.extend(range(*self.end_tag_range))
        return indices

    def __repr__(self):
        return f"Node(name='{self.open_tag}', idx = {self.idx})"

    def print_tree(self, level=0, input_ids = None, tokenizer = None):
        """Print the subtree, one payload per line, indented by depth; the
        top-level call is framed by '--------' rules."""
        if level == 0:
            print("--------")
        indent = " " * level
        if self.text:
            rendered = tokenizer.convert_tokens_to_string(self.text).strip()
            print(f"{indent}{rendered}, level = {level} ")
        elif self.self_closing_tag:
            rendered = tokenizer.convert_tokens_to_string(self.self_closing_tag).strip()
            print(f"{indent}{rendered}, level = {level} ")
        elif self.open_tag:
            rendered = tokenizer.convert_tokens_to_string(self.open_tag).strip()
            print(f"{indent}{rendered}, level = {level} ")
            for child in self.children:
                child.print_tree(level + 1, input_ids, tokenizer)
            if self.end_tag:
                rendered = tokenizer.convert_tokens_to_string(self.end_tag).strip()
                print(f"{indent}{rendered}, level = {level} ")
        else:
            # Payload-less node (e.g. a bare root): just recurse.
            for child in self.children:
                child.print_tree(level + 1, input_ids, tokenizer)
        if level == 0:
            print("--------")

    def get_tree(self, level=0, input_ids = None, tokenizer=None):
        """Return the subtree rendering as a string (same layout as
        print_tree, without the '--------' rules)."""
        indent = " " * level
        pieces = []
        if self.text:
            pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.text).strip()} \n")
        elif self.self_closing_tag:
            pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.self_closing_tag).strip()} \n")
        elif self.open_tag:
            pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.open_tag).strip()} \n")
            for child in self.children:
                pieces.append(child.get_tree(level + 1, input_ids, tokenizer))
            if self.end_tag:
                pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.end_tag).strip()} \n")
        else:
            for child in self.children:
                pieces.append(child.get_tree(level + 1, input_ids, tokenizer))
        return "".join(pieces)
|
|
|
|
|
class TreeBuilder():
    """Incrementally builds a `TreeNode` tree from a stream of generated HTML
    tokens and, for every new token, derives which earlier token positions it
    should attend to (the always-visible prefix, ancestor open tags, finished
    siblings, the current node, and the not-yet-attached buffer).

    Tokens are expected to arrive in generation order through
    `update_buffer`, typically one token per call.
    """

    def __init__(self, tokenizer: Optional["AutoTokenizer"] = None, root: "TreeNode" = None, cur_node: "TreeNode" = None):
        """Create a builder.

        `root` and `cur_node` are accepted for backward compatibility but are
        ignored: the builder always starts from a fresh, empty root node.
        """
        self.tokenizer = tokenizer
        self.root = TreeNode(None, 0)
        self.cur_node = self.root
        self.buffer = []               # tokens not yet attached to a node
        self.buffer_start_index = 0    # absolute index of buffer[0] in the stream
        self.idx = 0                   # running id handed to newly created nodes
        # Tokens generated before the first '<'; attended to by every position.
        self.full_attention_list = None
        self.web_attention_mask = None
        self.input_ids = None
        # HTML void elements never take a closing tag; treated as self-closing.
        self.void_elements = [
            "area",
            "base",
            "br",
            "col",
            "embed",
            "hr",
            "img",
            "input",
            "link",
            "meta",
            "param",
            "source",
            "track",
            "wbr"
        ]

    def is_empty(self):
        """Return True when the builder has no root node."""
        # Identity check (`is None`) instead of the original `== None`.
        return self.root is None

    def in_buffer(self, text):
        """Return True if any buffered token contains the substring `text`."""
        if len(self.buffer) == 0:
            return False
        return any(text in s for s in self.buffer)

    def find_buffer(self, text):
        """Return the index of the first buffered token containing `text`,
        or -1 when absent."""
        for index, s in enumerate(self.buffer):
            if text in s:
                return index
        return -1

    def extract_open_tag_name(self, buffer):
        """Return the tag name of the first opening tag decoded from
        `buffer` (a list of token strings), or None if there is none."""
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'<\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None

    def extract_close_tag_name(self, buffer):
        """Return the tag name of the first closing tag decoded from
        `buffer`, or None if there is none."""
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'</\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None

    def is_not_empty_buffer(self):
        """Return True when the buffer decodes to non-whitespace content."""
        return self.tokenizer.convert_tokens_to_string(self.buffer).strip() != ''

    def get_parent_and_siblings_attention_range(self):
        """Return the token indices of the current node's parent open tag plus
        all *completed* siblings (open+close pair, text, or self-closing)."""
        attn_range = []
        if self.cur_node.parent:
            parent = self.cur_node.parent
            if parent.open_tag_range:
                attn_range += list(range(*parent.open_tag_range))
            for child in parent.children:
                if child is not self.cur_node:
                    if child.open_tag and child.end_tag:
                        attn_range += list(range(*child.open_tag_range))
                        attn_range += list(range(*child.end_tag_range))
                    elif child.text:
                        attn_range += list(range(*child.text_range))
                    elif child.self_closing_tag:
                        attn_range += list(range(*child.self_closing_tag_range))
                    else:
                        # A sibling in no recognized state is a builder bug.
                        raise Exception(f"??? line 151, get p and s attention range")
        return attn_range

    def update_buffer(self, cur_decoded_token):
        """Consume newly decoded token strings and return the attention range.

        `cur_decoded_token` is a non-empty list of token strings appended to
        the pending buffer.  Depending on the structural markers now present
        ('</', '<', '>', plain text), the current node is extended or closed,
        or a new child/sibling node is created.  Returns the list of absolute
        token indices the next generated token should attend to.

        Raises AssertionError on malformed input (e.g. a closing tag that does
        not match the open element) and Exception on impossible tree states.
        """
        assert isinstance(cur_decoded_token, list), f"{cur_decoded_token}"
        self.buffer += cur_decoded_token
        assert isinstance(cur_decoded_token[0], str)

        try:
            # --- Case 1: buffer holds a complete closing tag "</ ... >". ---
            if self.in_buffer('</') and self.in_buffer('>') and self.find_buffer('</') <= self.find_buffer('>'):
                close_tag_name = self.extract_close_tag_name(self.buffer)

                if self.cur_node.open_tag and not self.cur_node.end_tag:
                    # The closing tag must match the currently open element.
                    assert close_tag_name == self.extract_open_tag_name(self.cur_node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{self.cur_node.open_tag}---"
                elif self.cur_node.text or self.cur_node.self_closing_tag or self.cur_node.end_tag:
                    # A full close tag can never complete in one step while the
                    # cursor sits on an already-finished node.
                    content = None
                    if self.cur_node.text: content = self.cur_node.text
                    elif self.cur_node.self_closing_tag: content = self.cur_node.self_closing_tag
                    elif self.cur_node.end_tag: content = self.cur_node.end_tag
                    self.root.print_tree(0, None, self.tokenizer)
                    raise Exception(f"This should never happen\n {content}, buffer is {self.buffer}")
                else:
                    raise Exception(f"having end tag without having an open tag\n {self.cur_node.text}")

                # Record the closing tag on the current element and advance.
                self.cur_node.end_tag = self.buffer[:self.find_buffer('>')+1]
                self.cur_node.end_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                self.buffer_start_index += self.find_buffer('>')+1
                self.buffer = self.buffer[self.find_buffer('>')+1:]

            # --- Case 2: a closing tag has started ('</') but no '>' yet. ---
            elif self.in_buffer('</'):
                if self.cur_node.open_tag and not self.cur_node.end_tag:
                    # The tag closes the current element; wait for the '>'.
                    pass
                elif self.cur_node.text or self.cur_node.self_closing_tag or (self.cur_node.open_tag and self.cur_node.end_tag):
                    cur_end_tag_index = self.find_buffer('</')

                    # Flush tokens preceding '</' into the finished current
                    # node, then pop up to the parent the tag will close.
                    if self.cur_node.text:
                        self.cur_node.text += self.buffer[:cur_end_tag_index]
                        self.cur_node.text_range[1] += len(self.buffer[:cur_end_tag_index])
                    elif self.cur_node.self_closing_tag:
                        self.cur_node.self_closing_tag += self.buffer[:cur_end_tag_index]
                        self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_end_tag_index])
                    else:
                        self.cur_node.end_tag += self.buffer[:cur_end_tag_index]
                        self.cur_node.end_tag_range[1] += len(self.buffer[:cur_end_tag_index])
                    self.buffer_start_index += len(self.buffer[:cur_end_tag_index])
                    self.buffer =self.buffer[cur_end_tag_index:]
                    self.cur_node = self.cur_node.parent
                else:
                    raise Exception(f"having end tag without having an open tag\n {self.cur_node.text} {self.cur_node} {self.cur_node.parent.open_tag}")

            # --- Case 3: a complete opening or self-closing tag "< ... >". ---
            elif self.in_buffer('<') and self.in_buffer('>'):
                if self.in_buffer('/>'):
                    # Explicit self-closing tag.
                    self.cur_node.open_tag = None
                    self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">")+1]
                    self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                else:
                    open_tag_name = self.extract_open_tag_name(self.buffer)
                    if open_tag_name in self.void_elements:
                        # Void elements never get a closing tag.
                        self.cur_node.open_tag = None
                        self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">")+1]
                        self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                    else:
                        self.cur_node.open_tag = self.buffer[:self.find_buffer(">")+1]
                        self.cur_node.open_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]

                self.buffer_start_index += self.find_buffer('>')+1
                self.buffer = self.buffer[self.find_buffer(">")+1:]

            # --- Case 4: an opening '<' has appeared but no '>' yet. ---
            elif self.in_buffer('<'):
                if self.full_attention_list is None:
                    # First tag ever: everything before the '<' becomes the
                    # always-attended prefix.
                    self.full_attention_list = self.buffer[:-1]
                    self.buffer = self.buffer[-1:]
                    self.buffer_start_index = len(self.full_attention_list)
                else:
                    cur_open_tag_index = self.find_buffer('<')

                    if not self.cur_node.partially_open() and self.cur_node.open_tag:
                        if self.cur_node.end_tag:
                            # Element fully closed: trailing tokens belong to
                            # its end tag; the new tag becomes its sibling.
                            self.cur_node.end_tag += self.buffer[:cur_open_tag_index]
                            self.cur_node.end_tag_range[1] += len(self.buffer[:cur_open_tag_index])
                            self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
                            self.buffer =self.buffer[cur_open_tag_index:]
                            child_node = TreeNode(self.buffer, self.idx)
                            if self.cur_node.parent:
                                self.cur_node.parent.add_child(child_node)
                            else:
                                raise Exception(f"This should never happen, a html element with full open tag should have a parent, {self.cur_node.open_tag}")
                            self.idx += 1
                            self.cur_node = child_node
                        else:
                            # Element still open: the new tag is its child.
                            child_node = TreeNode(self.buffer, self.idx)
                            self.cur_node.add_child(child_node)
                            self.idx += 1
                            self.cur_node = child_node
                    elif self.cur_node.text or self.cur_node.self_closing_tag:
                        # Flush leading tokens into the finished node, then add
                        # the new element as its sibling.
                        if self.cur_node.text:
                            self.cur_node.text += self.buffer[:cur_open_tag_index]
                            self.cur_node.text_range[1] += len(self.buffer[:cur_open_tag_index])
                        elif self.cur_node.self_closing_tag:
                            self.cur_node.self_closing_tag += self.buffer[:cur_open_tag_index]
                            self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_open_tag_index])

                        self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
                        self.buffer =self.buffer[cur_open_tag_index:]
                        child_node = TreeNode(self.buffer, self.idx)
                        self.cur_node.parent.add_child(child_node)
                        self.idx += 1
                        self.cur_node = child_node

            # --- Case 5: plain text following an open/self-closing tag. ---
            elif (self.cur_node.open_tag or self.cur_node.self_closing_tag) and not self.in_buffer('<') and self.is_not_empty_buffer():
                child_node = TreeNode(None, self.idx)
                child_node.text = self.buffer
                child_node.text_range[0] = self.buffer_start_index
                child_node.text_range[1] = self.buffer_start_index + len(self.buffer)

                if self.cur_node.end_tag or self.cur_node.self_closing_tag:
                    # Completed node: the text becomes its sibling.
                    self.cur_node.parent.add_child(child_node)
                else:
                    # Still-open element: the text becomes its child.
                    self.cur_node.add_child(child_node)

                self.idx += 1
                self.cur_node = child_node
                self.buffer_start_index += len(self.buffer)
                self.buffer = []

            # --- Case 6: additional text for the current text node. ---
            elif self.cur_node.text and not self.in_buffer('<') and self.is_not_empty_buffer():
                self.cur_node.text += self.buffer
                assert self.cur_node.text_range[0] != -1 and self.cur_node.text_range[1] != -1, f"self.cur_node.text_range[0] and [1] should not be -1 but: {self.cur_node.text_range[0]}, {self.cur_node.text_range[1]}"
                self.cur_node.text_range[1] += len(self.buffer)
                self.buffer_start_index += len(self.buffer)
                self.buffer =[]

        except Exception:
            # Log the full traceback, then re-raise the original exception
            # unchanged.  (The previous `traceback.format_exc()` discarded its
            # result, and `raise Exception(e)` destroyed both the exception
            # type and its traceback.)
            print(traceback.format_exc())
            raise

        # Attention span for the next token: before the first tag, just the
        # buffered prefix; afterwards prefix + parent/finished siblings +
        # current node + the still-unattached buffer tokens.
        if self.full_attention_list is None:
            attn_range = list(range(len(self.buffer)))
        else:
            attn_range = (
                list(range(len(self.full_attention_list)))
                + self.get_parent_and_siblings_attention_range()
                + self.cur_node.get_range()
                + list(range(self.buffer_start_index, self.buffer_start_index + len(self.buffer)))
            )
        return attn_range
|
|