Safetensors
vmistral
custom_code
Waffle_VLM_WebSight / generation_utils.py
jiang719's picture
Upload folder using huggingface_hub
2fed580 verified
from typing import Any, Dict, Optional, List
import torch
from transformers import GenerationMixin
from transformers import AutoTokenizer
import re
import traceback
class WebGenerationMixin(GenerationMixin):
    """``GenerationMixin`` variant that threads extra per-step state through ``model_kwargs``.

    In addition to the cache / mask bookkeeping, every decoder-only step copies
    ``outputs.html_tree`` into ``model_kwargs`` and, while no
    ``web_attention_mask`` entry exists yet, seeds one as a lower-triangular
    matrix sized by the current attention mask.
    """

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Update ``model_kwargs`` in place after one generation step and return it.

        Args:
            outputs: Model output of the step just executed. ``html_tree`` is
                read from it for decoder-only models; ``state`` is read when
                present and not ``None``.
            model_kwargs: Keyword arguments carried to the next forward pass.
            is_encoder_decoder: Selects which attention mask gets extended.
            standardize_cache_format: Forwarded to
                ``_extract_past_from_model_output``.

        Returns:
            The same ``model_kwargs`` dict, mutated.

        Raises:
            KeyError: Decoder-only call without ``"attention_mask"`` while
                ``"web_attention_mask"`` has not been created yet.
            AttributeError: Decoder-only call whose ``outputs`` has no
                ``html_tree`` attribute.
        """
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat(
                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
            )

        if not is_encoder_decoder:
            # Seed the web attention mask once: a (1, L, L) lower-triangular
            # matrix, L being the last dimension of the attention mask.
            # NOTE(review): the tensor is created on the default device, not on
            # attention_mask.device -- confirm the consumer moves it if needed.
            if 'web_attention_mask' not in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                mask_len = attention_mask.shape[-1]
                model_kwargs['web_attention_mask'] = torch.tril(
                    torch.ones((mask_len, mask_len), dtype=attention_mask.dtype)
                ).unsqueeze(0)

            # update attention mask: append one column of ones for the new token
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

            # NOTE(review): the original nesting of this line was lost; it is
            # placed at the decoder-only level. Both placements behave the same
            # whenever "attention_mask" is present.
            model_kwargs['html_tree'] = outputs.html_tree
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )

        # Advance the cache position by one, keeping only the last entry.
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

        return model_kwargs

    def _reorder_cache(self, past_key_values, beam_idx):
        """Always raise: beam-search cache reordering is not implemented here."""
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )
class TreeNode():
    """One node of the token-level HTML tree.

    A node holds exactly one kind of content at a time -- an open tag
    (optionally paired with an end tag), a self-closing tag, or text -- stored
    as lists of tokens, together with the ``[start, stop)`` index range each
    piece occupies.
    """

    def __init__(self, content: Optional[List[str]], idx: int):
        """Create a node whose open tag is ``content`` (``None`` for no tag yet).

        Args:
            content: Token list stored as ``open_tag``. The list is stored by
                reference, not copied.
            idx: Identifier stored on the node.
        """
        self.open_tag: Optional[List[str]] = content
        self.end_tag: Optional[List[str]] = None
        self.self_closing_tag: Optional[List[str]] = None
        # Starts as an empty string; TreeBuilder replaces it with a token list.
        self.text = ""
        self.name: Optional[str] = None
        self.parent: Optional['TreeNode'] = None  # Use 'TreeNode' as a string for forward reference
        self.open_tag_range: Optional[List[int]] = None
        self.end_tag_range: Optional[List[int]] = None
        # [-1, -1] marks "not assigned yet".
        self.text_range = [-1, -1]
        self.self_closing_tag_range = [-1, -1]
        self.idx: int = idx
        self.children: List['TreeNode'] = []  # List of TreeNode instances

    def partially_open(self):
        """Return True when the open tag contains ``<`` but no ``>`` yet."""
        if not self.open_tag:
            return False
        has_lt = any('<' in s for s in self.open_tag)
        has_gt = any('>' in s for s in self.open_tag)
        return has_lt and not has_gt

    def add_child(self, child):
        """Append ``child`` to this node and set its ``parent``.

        Raises:
            AssertionError: ``child`` already has a parent or is already listed.
        """
        assert child.parent is None, "Child already has a parent"
        assert child not in self.children, "Child is already in children list"
        child.parent = self
        self.children.append(child)

    def get_range(self):
        """Return the indices covered by this node's own content.

        Text takes precedence, then a self-closing tag; otherwise the open-tag
        and end-tag ranges (whichever are set) are concatenated.
        """
        if self.text:
            return list(range(*self.text_range))
        if self.self_closing_tag:
            return list(range(*self.self_closing_tag_range))
        attn_range = []
        if self.open_tag_range:
            attn_range += list(range(*self.open_tag_range))
        if self.end_tag_range:
            attn_range += list(range(*self.end_tag_range))
        return attn_range

    def __repr__(self):
        return f"Node(name='{self.open_tag}', idx = {self.idx})"

    def _iter_lines(self, level, tokenizer):
        """Yield ``(level, rendered_text)`` in pre-order; an end tag follows its children."""

        def render(tokens):
            return tokenizer.convert_tokens_to_string(tokens).strip()

        if self.text:
            yield level, render(self.text)
        elif self.self_closing_tag:
            yield level, render(self.self_closing_tag)
        elif self.open_tag:
            yield level, render(self.open_tag)
            for child in self.children:
                yield from child._iter_lines(level + 1, tokenizer)
            if self.end_tag:
                yield level, render(self.end_tag)
        else:
            # Content-less node (e.g. a fresh root): only its children show up.
            for child in self.children:
                yield from child._iter_lines(level + 1, tokenizer)

    def print_tree(self, level=0, input_ids=None, tokenizer=None):
        """Print the subtree, one line per piece of content, indented by level.

        A ``--------`` rule is printed before and after when ``level`` is 0.
        ``input_ids`` is accepted but not used. ``tokenizer`` must provide
        ``convert_tokens_to_string`` whenever there is content to print.
        """
        if level == 0:
            print("--------")
        for depth, rendered in self._iter_lines(level, tokenizer):
            indent = " " * depth
            print(f"{indent}{rendered}, level = {depth} ")
        if level == 0:
            print("--------")

    def get_tree(self, level=0, input_ids=None, tokenizer=None):
        """Return the subtree as a string, one ``"<indent><content> \\n"`` line per piece.

        ``input_ids`` is accepted but not used. ``tokenizer`` must provide
        ``convert_tokens_to_string`` whenever there is content to render.
        """
        pieces = []
        for depth, rendered in self._iter_lines(level, tokenizer):
            indent = " " * depth
            pieces.append(f"{indent}{rendered} \n")
        return "".join(pieces)
class TreeBuilder():
    """Incrementally builds a ``TreeNode`` tree from a stream of decoded tokens.

    Tokens are appended to ``buffer``; whenever the buffer shows a complete
    open / end / self-closing tag or a run of text, that part is moved onto a
    node and ``buffer_start_index`` advances by the number of tokens consumed.
    """

    def __init__(self, tokenizer: AutoTokenizer = None, root: TreeNode = None, cur_node: TreeNode = None):
        """Create a builder with a fresh, content-less root node.

        Args:
            tokenizer: Object providing ``convert_tokens_to_string``.
            root: NOTE(review): accepted but not used; a new root is always
                created.
            cur_node: NOTE(review): accepted but not used; the builder starts
                at the new root.
        """
        self.tokenizer = tokenizer
        self.root = TreeNode(None, 0)
        self.cur_node = self.root
        self.buffer = []
        # Index of buffer[0] in the overall token stream seen so far.
        self.buffer_start_index = 0
        self.idx = 0
        # Tokens seen before the first '<' token; None until that token arrives.
        self.full_attention_list = None
        self.web_attention_mask = None
        self.input_ids = None
        # Tag names stored as self-closing even without a trailing '/>'.
        self.void_elements = [
            "area",
            "base",
            "br",
            "col",
            "embed",
            "hr",
            "img",
            "input",
            "link",
            "meta",
            "param",
            "source",
            "track",
            "wbr"
        ]

    def is_empty(self):
        """Return True when the builder has no root node."""
        return self.root is None

    def in_buffer(self, text):
        """Return True when any buffered token contains ``text``."""
        return any(text in s for s in self.buffer)

    def find_buffer(self, text):
        """Return the index of the first buffered token containing ``text``, or -1."""
        return next((index for index, s in enumerate(self.buffer) if text in s), -1)

    # Function to extract xxx from <xxx> or <xxx yyy>
    def extract_open_tag_name(self, buffer):
        """Return the tag name of the first open tag in ``buffer``, or None."""
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'<\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None

    def extract_close_tag_name(self, buffer):
        """Return the tag name of the first end tag in ``buffer``, or None."""
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'</\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None

    def is_not_empty_buffer(self):
        """Return True when the buffer decodes to something other than whitespace."""
        return self.tokenizer.convert_tokens_to_string(self.buffer).strip() != ''

    def get_parent_and_siblings_attention_range(self):
        """Return the indices of the current node's parent open tag and of its siblings.

        Returns an empty list when the current node has no parent.

        Raises:
            Exception: A sibling has neither a complete open/end tag pair, nor
                text, nor a self-closing tag.
        """
        attn_range = []
        parent = self.cur_node.parent
        if parent:
            if parent.open_tag_range:
                attn_range += list(range(*parent.open_tag_range))
            for child in parent.children:
                if child is self.cur_node:
                    continue
                if child.open_tag and child.end_tag:
                    attn_range += list(range(*child.open_tag_range))
                    attn_range += list(range(*child.end_tag_range))
                elif child.text:
                    attn_range += list(range(*child.text_range))
                elif child.self_closing_tag:
                    attn_range += list(range(*child.self_closing_tag_range))
                else:
                    raise Exception("??? line 151, get p and s attention range")
        return attn_range

    def update_buffer(self, cur_decoded_token):
        """Feed newly decoded tokens to the builder and return the index list for this step.

        Args:
            cur_decoded_token: Non-empty list of token strings.

        Returns:
            A list of ints. Before the first ``<`` token has been seen it is
            ``range(len(buffer))``. Afterwards it is the prefix indices, the
            parent/sibling indices, the current node's own indices and the
            indices of the tokens still buffered, concatenated in that order.
            Presumably these are the positions the next token may attend to;
            confirm against the caller.

        Raises:
            AssertionError: ``cur_decoded_token`` is not a list, or its first
                element is not a ``str``.
            Exception: Any error raised while processing the buffer, re-raised
                as a plain ``Exception`` chained to the original.
        """
        assert isinstance(cur_decoded_token, list), f"{cur_decoded_token}"
        # Extend in place on purpose: while a tag is still partially open, the
        # current node's ``open_tag`` is this very list object and has to grow
        # together with the buffer.
        self.buffer += cur_decoded_token
        assert isinstance(cur_decoded_token[0], str)
        try:
            self._process_buffer()
        except Exception as e:
            # Same exception type as before; the cause is chained explicitly so
            # the original traceback is kept.
            raise Exception(e) from e

        if self.full_attention_list is None:
            return list(range(len(self.buffer)))
        start = self.buffer_start_index
        return (
            list(range(len(self.full_attention_list)))
            + self.get_parent_and_siblings_attention_range()
            + self.cur_node.get_range()
            + list(range(start, start + len(self.buffer)))
        )

    def _process_buffer(self):
        """Pick the first matching buffer situation and apply it; the order matters."""
        has_end_start = self.in_buffer('</')
        has_lt = self.in_buffer('<')
        has_gt = self.in_buffer('>')
        node = self.cur_node
        if has_end_start and has_gt and self.find_buffer('</') <= self.find_buffer('>'):
            # dealing with end tag
            self._finish_end_tag()
        elif has_end_start:
            self._leave_completed_node()
        elif has_lt and has_gt:
            self._finish_open_tag()
        elif has_lt:
            self._start_tag()
        # From here on the buffer is known to contain no '<'.
        elif (node.open_tag or node.self_closing_tag) and self.is_not_empty_buffer():
            # if the current node has an open tag, and we are encountering
            # texts, we create a new text node, and move down a level
            self._start_text_node()
        elif node.text and self.is_not_empty_buffer():
            # if the current node is already a text node, the new text is
            # added to it
            self._extend_text_node()

    def _finish_end_tag(self):
        """Store a complete end tag on the current node and drop it from the buffer."""
        node = self.cur_node
        close_tag_name = self.extract_close_tag_name(self.buffer)
        if node.open_tag and not node.end_tag:
            assert close_tag_name == self.extract_open_tag_name(node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{node.open_tag}---"
        elif node.text or node.self_closing_tag or node.end_tag:
            content = node.text or node.self_closing_tag or node.end_tag
            self.root.print_tree(0, None, self.tokenizer)
            raise Exception(f"This should never happen\n {content}, buffer is {self.buffer}")
        else:
            raise Exception(f"having end tag without having an open tag\n {node.text}")
        end = self.find_buffer('>') + 1
        node.end_tag = self.buffer[:end]
        node.end_tag_range = [self.buffer_start_index, self.buffer_start_index + end]
        self.buffer_start_index += end
        self.buffer = self.buffer[end:]

    def _leave_completed_node(self):
        """On an unfinished end tag, give preceding tokens to a finished node and move to its parent."""
        node = self.cur_node
        if node.open_tag and not node.end_tag:
            # This end tag belongs to the current node; wait for its '>'.
            return
        if node.text or node.self_closing_tag or (node.open_tag and node.end_tag):
            cut = self.find_buffer('</')
            prefix = self.buffer[:cut]
            if node.text:
                node.text += prefix
                node.text_range[1] += len(prefix)
            elif node.self_closing_tag:
                node.self_closing_tag += prefix
                node.self_closing_tag_range[1] += len(prefix)
            else:
                node.end_tag += prefix
                node.end_tag_range[1] += len(prefix)
            self.buffer_start_index += len(prefix)
            self.buffer = self.buffer[cut:]
            self.cur_node = node.parent
        else:
            raise Exception(f"having end tag without having an open tag\n {self.cur_node.text} {self.cur_node} {self.cur_node.parent.open_tag}")

    def _finish_open_tag(self):
        """Store a complete open or self-closing tag on the current node and drop it from the buffer."""
        node = self.cur_node
        end = self.find_buffer('>') + 1
        tag_tokens = self.buffer[:end]
        tag_range = [self.buffer_start_index, self.buffer_start_index + end]
        # '/>' anywhere in the buffer, or a void element name, marks the tag
        # as self-closing.
        if self.in_buffer('/>') or self.extract_open_tag_name(self.buffer) in self.void_elements:
            node.open_tag = None
            node.self_closing_tag = tag_tokens
            node.self_closing_tag_range = tag_range
        else:
            node.open_tag = tag_tokens
            node.open_tag_range = tag_range
        self.buffer_start_index += end
        self.buffer = self.buffer[end:]

    def _start_tag(self):
        """Handle a buffer that contains '<' but no '>' yet."""
        if self.full_attention_list is None:
            # First '<' ever: everything before the last token becomes the
            # always-attended prefix.
            self.full_attention_list = self.buffer[:-1]
            self.buffer = self.buffer[-1:]
            self.buffer_start_index = len(self.full_attention_list)
            return

        node = self.cur_node
        cut = self.find_buffer('<')
        # full open tag, indicating a pair of open and close tags, or a single open tag
        if not node.partially_open() and node.open_tag:
            if node.end_tag:
                # Current element is closed: the new tag becomes its sibling.
                prefix = self.buffer[:cut]
                node.end_tag += prefix
                node.end_tag_range[1] += len(prefix)
                self.buffer_start_index += len(prefix)
                self.buffer = self.buffer[cut:]
                child_node = TreeNode(self.buffer, self.idx)
                if node.parent:
                    node.parent.add_child(child_node)
                else:
                    raise Exception(f"This should never happen, a html element with full open tag should have a parent, {self.cur_node.open_tag}")
            else:
                # Current element is still open: the new tag becomes its child.
                # The buffer is not trimmed here, so tokens before '<' stay
                # part of the child's tag.
                child_node = TreeNode(self.buffer, self.idx)
                node.add_child(child_node)
            self.idx += 1
            self.cur_node = child_node
        elif node.text or node.self_closing_tag:
            prefix = self.buffer[:cut]
            if node.text:
                node.text += prefix
                node.text_range[1] += len(prefix)
            elif node.self_closing_tag:
                node.self_closing_tag += prefix
                node.self_closing_tag_range[1] += len(prefix)
            self.buffer_start_index += len(prefix)
            self.buffer = self.buffer[cut:]
            child_node = TreeNode(self.buffer, self.idx)
            node.parent.add_child(child_node)
            self.idx += 1
            self.cur_node = child_node

    def _start_text_node(self):
        """Turn the whole buffer into a new text node and make it current."""
        node = self.cur_node
        child_node = TreeNode(None, self.idx)
        child_node.text = self.buffer
        child_node.text_range[0] = self.buffer_start_index
        child_node.text_range[1] = self.buffer_start_index + len(self.buffer)
        if node.end_tag or node.self_closing_tag:
            # Text after a finished element is a sibling of that element.
            node.parent.add_child(child_node)
        else:
            node.add_child(child_node)
        self.idx += 1
        self.cur_node = child_node
        self.buffer_start_index += len(self.buffer)
        # Rebind instead of clearing: the text node keeps the old list.
        self.buffer = []

    def _extend_text_node(self):
        """Append the whole buffer to the current text node."""
        node = self.cur_node
        node.text += self.buffer
        assert node.text_range[0] != -1 and node.text_range[1] != -1, f"self.cur_node.text_range[0] and [1] should not be -1 but: {self.cur_node.text_range[0]}, {self.cur_node.text_range[1]}"
        node.text_range[1] += len(self.buffer)
        self.buffer_start_index += len(self.buffer)
        self.buffer = []