# Extraction-artifact header repaired: original file metadata was
# "File size: 18,874 bytes, commit 2fed580" followed by a line-number column.
from typing import Any, Dict, Optional, List
import torch
from transformers import GenerationMixin
from transformers import AutoTokenizer
import re
import traceback
class WebGenerationMixin(GenerationMixin):
    """GenerationMixin variant that threads web/HTML-specific decoding state
    (a structural `web_attention_mask` and the incrementally built `html_tree`)
    through the autoregressive generation loop."""
    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Advance `model_kwargs` by one decoding step.

        Mirrors the upstream transformers implementation, with two additions
        for decoder-only models: a lazily-initialized causal
        `web_attention_mask`, and the `html_tree` carried over from `outputs`.

        Returns the updated `model_kwargs` dict (mutated in place).
        """
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state
        # update token_type_ids with last value (repeat the final type id for
        # the newly generated token)
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
        if not is_encoder_decoder:
            # Lazily initialize the structural web attention mask as a causal
            # (lower-triangular) mask. Built with `new_ones` so it inherits
            # attention_mask's device and dtype — `torch.ones` would allocate
            # on the CPU and break generation when the model runs on CUDA.
            if 'web_attention_mask' not in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                seq_len = attention_mask.shape[-1]
                model_kwargs['web_attention_mask'] = torch.tril(
                    attention_mask.new_ones((seq_len, seq_len))
                ).unsqueeze(0)
            # update attention mask: extend by one position for the new token
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
            # carry the incrementally parsed HTML tree to the next step
            model_kwargs['html_tree'] = outputs.html_tree
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )
        # advance the cache position by one token
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        return model_kwargs
    def _reorder_cache(self, past_key_values, beam_idx):
        """Beam-search cache reordering is model-specific; subclasses must override."""
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )
class TreeNode:
    """A node in an incrementally parsed HTML token tree.

    Each node holds at most one of: an open/end tag token pair, a
    self-closing tag, or a run of text tokens. Alongside the tokens
    themselves, the absolute token-index range of every part is recorded
    so attention ranges can be reconstructed later.
    """
    def __init__(self, content: list, idx: int):
        self.open_tag: List[str] = content
        self.end_tag: Optional[List[str]] = None
        self.self_closing_tag: Optional[List[str]] = None
        self.text = ""
        self.name: Optional[str] = None
        self.parent: Optional['TreeNode'] = None  # string form: forward reference
        self.open_tag_range: Optional[List[int]] = None
        self.end_tag_range: Optional[List[int]] = None
        self.text_range = [-1, -1]
        self.self_closing_tag_range = [-1, -1]
        self.idx: int = idx
        self.children: List['TreeNode'] = []

    def partially_open(self):
        """True when the open tag has started ('<' seen) but not yet closed ('>' missing)."""
        if not self.open_tag:
            return False
        started = any('<' in tok for tok in self.open_tag)
        finished = any('>' in tok for tok in self.open_tag)
        return started and not finished

    def add_child(self, child):
        """Attach `child` under this node; `child` must not already be attached anywhere."""
        assert child.parent is None, "Child already has a parent"
        assert child not in self.children, "Child is already in children list"
        child.parent = self
        self.children.append(child)

    def get_range(self):
        """Token-index positions covered by this node's own tokens (children excluded)."""
        if self.text:
            return list(range(*self.text_range))
        if self.self_closing_tag:
            return list(range(*self.self_closing_tag_range))
        positions = []
        for span in (self.open_tag_range, self.end_tag_range):
            if span:
                positions.extend(range(*span))
        return positions

    def __repr__(self):
        return f"Node(name='{self.open_tag}', idx = {self.idx})"

    def print_tree(self, level=0, input_ids = None, tokenizer = None):
        """Pretty-print the subtree rooted at this node (the root call is framed by dashes)."""
        if level == 0:
            print("--------")
        indent = " " * level
        if self.text:
            label = tokenizer.convert_tokens_to_string(self.text).strip()
            print(f"{indent}{label}, level = {level} ")
        elif self.self_closing_tag:
            label = tokenizer.convert_tokens_to_string(self.self_closing_tag).strip()
            print(f"{indent}{label}, level = {level} ")
        elif self.open_tag:
            label = tokenizer.convert_tokens_to_string(self.open_tag).strip()
            print(f"{indent}{label}, level = {level} ")
            for node in self.children:
                node.print_tree(level + 1, input_ids, tokenizer)
            if self.end_tag:
                label = tokenizer.convert_tokens_to_string(self.end_tag).strip()
                print(f"{indent}{label}, level = {level} ")
        else:
            # Tag-less node (e.g. the root): just recurse into the children.
            for node in self.children:
                node.print_tree(level + 1, input_ids, tokenizer)
        if level == 0:
            print("--------")

    def get_tree(self, level=0, input_ids = None, tokenizer=None):
        """Return the subtree as an indented multi-line string (same layout as print_tree)."""
        indent = " " * level
        pieces = []
        if self.text:
            pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.text).strip()} \n")
        elif self.self_closing_tag:
            pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.self_closing_tag).strip()} \n")
        elif self.open_tag:
            pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.open_tag).strip()} \n")
            for node in self.children:
                pieces.append(node.get_tree(level + 1, input_ids, tokenizer))
            if self.end_tag:
                pieces.append(f"{indent}{tokenizer.convert_tokens_to_string(self.end_tag).strip()} \n")
        else:
            # Tag-less node: emit only the children.
            for node in self.children:
                pieces.append(node.get_tree(level + 1, input_ids, tokenizer))
        return "".join(pieces)
class TreeBuilder():
    """Incrementally parses a stream of decoded HTML tokens into a TreeNode
    tree and reports which token positions the next token should attend to.

    Tokens accumulate in `self.buffer` and are flushed into tree nodes as
    soon as tag boundaries ('<', '</', '>') are recognized; the absolute
    token-index range of every tag/text span is tracked so attention ranges
    can be reconstructed from the tree structure.
    """
    def __init__(self, tokenizer: AutoTokenizer = None, root: TreeNode = None, cur_node: TreeNode = None):
        # NOTE(review): the `root` and `cur_node` parameters are accepted but
        # ignored — a fresh root is always created. Confirm no caller relies
        # on passing these in.
        self.tokenizer = tokenizer
        self.root = TreeNode(None, 0)
        self.cur_node = self.root
        # Decoded tokens not yet committed to any tree node.
        self.buffer = []
        # Absolute token index (within the full generated sequence) of buffer[0].
        self.buffer_start_index = 0
        # Monotonically increasing id handed out to each new TreeNode.
        self.idx = 0
        # Tokens preceding the first '<' (e.g. a prompt prefix); positions in
        # this list always receive full attention.
        self.full_attention_list= None
        self.web_attention_mask = None
        self.input_ids = None
        # HTML "void" elements: tags with no closing counterpart, treated the
        # same as self-closing tags when encountered as open tags.
        self.void_elements = [
            "area",
            "base",
            "br",
            "col",
            "embed",
            "hr",
            "img",
            "input",
            "link",
            "meta",
            "param",
            "source",
            "track",
            "wbr"
        ]
    def is_empty(self):
        # NOTE(review): __init__ always assigns a TreeNode to self.root, so
        # this is effectively always False (also uses `== None` vs `is None`).
        return self.root == None
    def in_buffer(self, text):
        """Return True if any buffered token contains `text` as a substring."""
        if len(self.buffer) == 0:
            return False
        return any(text in s for s in self.buffer)
    def find_buffer(self, text):
        """Return the index of the first buffered token containing `text`, or -1."""
        # Iterate over the list of strings with their indices
        for index, s in enumerate(self.buffer):
            if text in s:
                return index
        return -1
    # Function to extract xxx from <xxx> or <xxx yyy>
    def extract_open_tag_name(self,buffer):
        """Return the tag name of the first open tag decoded from `buffer`, or None."""
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'<\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None
    def extract_close_tag_name(self,buffer):
        """Return the tag name of the first closing tag (</xxx>) decoded from `buffer`, or None."""
        # if isinstance(input_string, list):
        #     input_string = "".join(input_string).replace('Ċ', '\n').replace('Ġ', ' ').replace('ĉ', '\t')
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'</\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None
    def is_not_empty_buffer(self):
        """True if the buffer decodes to something other than pure whitespace."""
        return self.tokenizer.convert_tokens_to_string(self.buffer).strip() != ''
    def get_parent_and_siblings_attention_range(self):
        """Token positions of the parent's open tag plus every completed
        sibling of the current node (their tag pairs, text, or self-closing tag)."""
        attn_range = []
        if self.cur_node.parent:
            parent = self.cur_node.parent
            if parent.open_tag_range:
                attn_range += list(range(*parent.open_tag_range))
            for child in parent.children:
                if child is not self.cur_node:
                    if child.open_tag and child.end_tag:
                        attn_range += list(range(*child.open_tag_range))
                        attn_range += list(range(*child.end_tag_range))
                    elif child.text:
                        attn_range += list(range(*child.text_range))
                    elif child.self_closing_tag:
                        attn_range += list(range(*child.self_closing_tag_range))
                    else:
                        # A sibling with no recorded content — should be unreachable.
                        raise Exception(f"??? line 151, get p and s attention range")
        return attn_range
    def update_buffer(self, cur_decoded_token):
        """Consume newly decoded tokens and return the attention range.

        Appends `cur_decoded_token` (a list of token strings) to the buffer,
        commits any tag/text spans that are now complete to the tree, then
        returns the list of absolute token positions the next token should
        attend to: the full-attention prefix + parent/sibling tags +
        the current node's own tokens + the still-uncommitted buffer.
        """
        # open tag situations
        assert isinstance(cur_decoded_token,list), f"{cur_decoded_token}"
        self.buffer+=cur_decoded_token
        assert isinstance(cur_decoded_token[0],str)
        # print(self.buffer)
        try:
            # dealing with end tag
            # Branch 1: a complete closing tag '</...>' sits in the buffer
            # (the '</' token must not come after the '>' token).
            if self.in_buffer('</' ) and self.in_buffer('>') and self.find_buffer('</') <= self.find_buffer('>'):
                close_tag_name = self.extract_close_tag_name(self.buffer)
                if self.cur_node.open_tag and not self.cur_node.end_tag:
                    # The closing tag must match the currently open element.
                    assert close_tag_name == self.extract_open_tag_name(self.cur_node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{self.cur_node.open_tag}---"
                elif self.cur_node.text or self.cur_node.self_closing_tag or self.cur_node.end_tag:
                    # A closing tag cannot terminate a text/self-closing/already
                    # closed node — dump the tree for debugging and bail out.
                    content = None
                    if self.cur_node.text: content = self.cur_node.text
                    elif self.cur_node.self_closing_tag: content = self.cur_node.self_closing_tag
                    elif self.cur_node.end_tag: content = self.cur_node.end_tag
                    self.root.print_tree(0,None,self.tokenizer)
                    raise Exception(f"This should never happen\n {content}, buffer is {self.buffer}")
                    # assert close_tag_name == extract_open_tag_name(self.cur_node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{self.cur_node.open_tag}---"
                else:
                    raise Exception(f"having end tag without having an open tag\n {self.cur_node.text}")
                # Commit the closing tag (everything up to and including '>')
                # and slide the buffer window past it.
                self.cur_node.end_tag = self.buffer[:self.find_buffer('>')+1]
                self.cur_node.end_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                self.buffer_start_index += self.find_buffer('>')+1
                self.buffer = self.buffer[self.find_buffer('>')+1:]
            # dealing with open tag
            # Branch 2: a '</' has appeared but the closing tag is not yet
            # complete — flush any pending text/tag content that precedes it
            # and move back up to the parent.
            elif self.in_buffer('</'):
                if self.cur_node.open_tag and not self.cur_node.end_tag:
                    # Closing tag for the current element is still streaming in.
                    pass
                elif self.cur_node.text or self.cur_node.self_closing_tag or (self.cur_node.open_tag and self.cur_node.end_tag):
                    cur_end_tag_index = self.find_buffer('</')
                    # import pdb;pdb.set_trace()
                    # Append the tokens before '</' to whichever span the
                    # current node is accumulating, and extend its range.
                    if self.cur_node.text:
                        self.cur_node.text += self.buffer[:cur_end_tag_index]
                        self.cur_node.text_range[1] += len(self.buffer[:cur_end_tag_index])
                    elif self.cur_node.self_closing_tag:
                        self.cur_node.self_closing_tag += self.buffer[:cur_end_tag_index]
                        self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_end_tag_index])
                    else:
                        self.cur_node.end_tag += self.buffer[:cur_end_tag_index]
                        self.cur_node.end_tag_range[1] += len(self.buffer[:cur_end_tag_index])
                    self.buffer_start_index += len(self.buffer[:cur_end_tag_index])
                    self.buffer =self.buffer[cur_end_tag_index:]
                    # The incoming '</' closes the parent element.
                    self.cur_node = self.cur_node.parent
                else:
                    raise Exception(f"having end tag without having an open tag\n {self.cur_node.text} {self.cur_node} {self.cur_node.parent.open_tag}")
            # Branch 3: a complete opening tag '<...>' is in the buffer.
            elif self.in_buffer('<') and self.in_buffer('>'):
                # in the case of self_closing tag
                if self.in_buffer('/>'):
                    self.cur_node.open_tag = None
                    self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">")+1]
                    self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                else:
                    open_tag_name = self.extract_open_tag_name(self.buffer)
                    # Void elements (br, img, ...) behave like self-closing tags.
                    if open_tag_name in self.void_elements:
                        self.cur_node.open_tag = None
                        self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">")+1]
                        self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                    else:
                        self.cur_node.open_tag = self.buffer[:self.find_buffer(">")+1]
                        self.cur_node.open_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                # Slide the buffer window past the committed tag.
                self.buffer_start_index += self.find_buffer('>')+1
                self.buffer = self.buffer[self.find_buffer(">")+1:]
            # Branch 4: a '<' has appeared, starting a new (still incomplete) tag.
            elif self.in_buffer('<'):
                if self.full_attention_list is None:
                    # First '<' ever seen: everything before it becomes the
                    # full-attention prefix.
                    self.full_attention_list = self.buffer[:-1]
                    self.buffer = self.buffer[-1:]
                    self.buffer_start_index = len(self.full_attention_list)
                else:
                    cur_open_tag_index = self.find_buffer('<')
                    # full open tag, indicating a pair of open and close tags, or a single open tag
                    if not self.cur_node.partially_open() and self.cur_node.open_tag:
                        if self.cur_node.end_tag:
                            # Trailing tokens extend the just-closed end tag.
                            self.cur_node.end_tag += self.buffer[:cur_open_tag_index]
                            self.cur_node.end_tag_range[1] += len(self.buffer[:cur_open_tag_index])
                            self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
                            self.buffer =self.buffer[cur_open_tag_index:]
                            # Closed element: the new tag is a sibling, attached
                            # to the current node's parent.
                            child_node = TreeNode(self.buffer, self.idx)
                            if self.cur_node.parent:
                                self.cur_node.parent.add_child(child_node)
                            else:
                                raise Exception(f"This should never happen, a html element with full open tag should have a parent, {self.cur_node.open_tag}")
                            self.idx += 1
                            self.cur_node = child_node
                        else:
                            # Open element: the new tag is its first child.
                            child_node = TreeNode(self.buffer, self.idx)
                            self.cur_node.add_child(child_node)
                            self.idx += 1
                            self.cur_node = child_node
                    elif self.cur_node.text or self.cur_node.self_closing_tag:
                        # Flush the tokens preceding '<' into the current
                        # text/self-closing node, then start a sibling node.
                        if self.cur_node.text:
                            self.cur_node.text += self.buffer[:cur_open_tag_index]
                            self.cur_node.text_range[1] += len(self.buffer[:cur_open_tag_index])
                        elif self.cur_node.self_closing_tag:
                            self.cur_node.self_closing_tag += self.buffer[:cur_open_tag_index]
                            self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_open_tag_index])
                        self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
                        self.buffer =self.buffer[cur_open_tag_index:]
                        child_node = TreeNode(self.buffer, self.idx)
                        self.cur_node.parent.add_child(child_node)
                        self.idx += 1
                        self.cur_node = child_node
            # if the current node has an open tag, and we are encountering texts, we create a new text node, and move down a level
            elif (self.cur_node.open_tag or self.cur_node.self_closing_tag) and not self.in_buffer('<') and self.is_not_empty_buffer():
                child_node = TreeNode(None, self.idx)
                child_node.text = self.buffer
                child_node.text_range[0] = self.buffer_start_index
                child_node.text_range[1] = self.buffer_start_index + len(self.buffer)
                # A closed/self-closing node gets a sibling text node; an open
                # node gets a child text node.
                if self.cur_node.end_tag or self.cur_node.self_closing_tag:
                    self.cur_node.parent.add_child(child_node)
                else:
                    self.cur_node.add_child(child_node)
                self.idx += 1
                self.cur_node = child_node
                self.buffer_start_index += len(self.buffer)
                self.buffer = []
            # if the current node does not have an open tag, but we are encountering text, we add to the exisitng text node
            elif self.cur_node.text and not self.in_buffer('<') and self.is_not_empty_buffer():
                self.cur_node.text += self.buffer
                assert self.cur_node.text_range[0] != -1 and self.cur_node.text_range[1] != -1, f"self.cur_node.text_range[0] and [1] should not be -1 but: {self.cur_node.text_range[0]}, {self.cur_node.text_range[1]}"
                self.cur_node.text_range[1] += len(self.buffer)
                self.buffer_start_index += len(self.buffer)
                self.buffer =[]
        except Exception as e:
            # NOTE(review): format_exc() return value is discarded (the
            # traceback is never printed/logged), and re-raising as a bare
            # Exception loses the original exception type.
            traceback.format_exc()
            raise Exception(e)
        # Attention range: before the first tag, attend to the whole buffer;
        # afterwards, attend to the prefix + structural context + current
        # node's tokens + the uncommitted buffer tail.
        if self.full_attention_list is None:
            attn_range = list(range(len(self.buffer)))
        else:
            attn_range = list(range(len(self.full_attention_list))) + self.get_parent_and_siblings_attention_range() + self.cur_node.get_range() + [i + self.buffer_start_index for i in list(range(len(self.buffer)))]
        return attn_range
|