import torch

TOPK = 10  # topk for sparse tree


def pad_path(path, length, pad_value=-2):
    """
    Pad the given path list with a specific value up to a specified length.

    Args:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path but padded to the desired length.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    If the given path is already at least as long as the specified length,
    no padding occurs and an unpadded copy of the path is returned (list
    concatenation always builds a new list).
    """
    # A non-positive repeat count yields an empty list, so over-long paths
    # come back unpadded.
    return path + [pad_value] * (length - len(path))


class node:
    """
    A single node of the choice tree.

    Attributes set by ``__init__``:
    - parent: the parent ``node`` (``None`` for the root).
    - value: the last element of the path this node was built from.
    - depth: 0 for a node without parent, otherwise ``parent.depth + 1``.
    - children: list of child nodes, filled in as children are constructed.
    - dict_key: the key under which ``Tree`` stores this node.

    The ``index`` attribute is NOT set here. ``Tree.indexnode`` assigns it
    afterwards, and only to nodes that have at least one child.
    """

    def __init__(self, parent=None, value=None, dict_key=None):
        self.parent = parent
        self.value = value
        if parent is not None:
            self.depth = parent.depth + 1
            # Register with the parent so that is_leaf() works on it.
            parent.children.append(self)
        else:
            self.depth = 0
        self.children = []
        self.dict_key = dict_key

    def is_leaf(self):
        """Return True when this node has no children."""
        return not self.children

    def all_index(self):
        """
        Return the ``index`` of every node on the path from the depth-1
        ancestor down to this node (inclusive), in top-down order.

        Raises AttributeError when called on a node without a parent, or on
        a node that was never given an ``index`` (i.e. a leaf).
        """
        if self.parent.parent is None:
            # Direct child of the root: the path starts here.
            return [self.index]
        return self.parent.all_index() + [self.index]


class Tree:
    """
    Tree built from a list of paths.

    Each path is a list whose last element is the node's value and whose
    prefix (the path without its last element) identifies the parent. Every
    non-empty proper prefix of a path must itself be present in the input,
    otherwise construction raises KeyError.
    """

    def __init__(self, tree_list):
        # Shorter paths first, so a parent is always created before its
        # children; ties are broken lexicographically.
        ordered_paths = sorted(tree_list, key=lambda x: (len(x), x))
        self.root = node()
        self.node_dic = {}
        for path in ordered_paths:
            key = tuple(path)
            parent = self.root if len(path) == 1 else self.node_dic[key[:-1]]
            self.node_dic[key] = node(parent=parent, value=path[-1], dict_key=key)
        self.indexnode()

    def max_depth(self):
        """Return the largest node depth, or 0 (the root's depth) for an empty tree."""
        return max((item.depth for item in self.node_dic.values()), default=0)

    def num_node_wchild(self):
        """Return the number of non-root nodes that have at least one child."""
        return sum(1 for item in self.node_dic.values() if not item.is_leaf())

    def get_node_wchild(self):
        """Return the non-root nodes that have children, in insertion order."""
        return [item for item in self.node_dic.values() if not item.is_leaf()]

    def indexnode(self):
        """Assign consecutive ``index`` values (from 0) to nodes with children."""
        for cur_index, cur_node in enumerate(self.get_node_wchild()):
            cur_node.index = cur_index


def generate_tree_buffers(tree_choices, device="cuda"):
    """
    Build the per-depth buffers describing the nodes of a choice tree that
    have children.

    Only nodes with at least one child take part. They are ordered by depth
    and, within a depth, by the lexicographic order of their paths. This is
    the same order in which ``Tree`` indexes them.

    Args:
    - tree_choices (list of list of int): Paths of the tree. Every non-empty
      proper prefix of a path must also be present.
    - device (str or torch.device, default="cuda"): Device the tensors are
      moved to.

    Returns:
    - dict with one entry per key, each a list with one element per depth
      level (levels 1 .. max_depth - 1):
      - "attn_mask": float tensors of shape (1, 1, n_level, n_cum), where
        n_level is the number of nodes on the level and n_cum the number of
        nodes on this and all shallower levels. A row holds ones at the
        indices of the node itself and of its ancestors.
      - "tree_indices": long tensors of length n_level; entry j is
        ``node.value + TOPK * k`` where k counts how many times the parent
        changed among the nodes before and including j on that level.
      - "position_ids": long tensors of length n_level, all zeros.
      - "repeat_nums": plain lists with the lengths of the runs of
        consecutive nodes sharing the same parent.

      When no node has children (including an empty ``tree_choices``) every
      entry is an empty list.
    """
    tree = Tree(tree_choices)
    nodes_wc = tree.get_node_wchild()
    tree_len = len(nodes_wc)

    # Nodes with children live on depths 1 .. max_depth - 1. Each of these
    # depths holds at least one such node, because the deepest node has an
    # ancestor on every shallower level.
    depth_counts = [0] * (tree.max_depth() - 1)
    for cur_node in nodes_wc:
        depth_counts[cur_node.depth - 1] += 1

    # Row i marks node i itself (identity) plus all of its ancestors.
    tree_attn_mask = torch.eye(tree_len, tree_len)
    for row, cur_node in enumerate(nodes_wc):
        tree_attn_mask[row, cur_node.all_index()] = 1

    # For every level keep only that level's rows, and only the columns of
    # the nodes seen up to and including this level.
    tree_attn_mask_list = []
    level_end = 0
    for count in depth_counts:
        level_end += count
        tree_attn_mask_list.append(
            tree_attn_mask[level_end - count:level_end, :level_end]
        )

    tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts]
    repeat_nums = [[] for _ in depth_counts]
    start = 0
    for level, count in enumerate(depth_counts):
        bias = 0        # number of parent changes seen so far on this level
        run_start = 0   # position where the current same-parent run began
        prev_parent = None
        for j, cur_node in enumerate(nodes_wc[start:start + count]):
            if j == 0:
                prev_parent = cur_node.parent
            elif cur_node.parent is not prev_parent:
                bias += 1
                prev_parent = cur_node.parent
                repeat_nums[level].append(j - run_start)
                run_start = j
            tree_indices_list[level][j] = cur_node.value + TOPK * bias
        # Close the last run of the level.
        repeat_nums[level].append(count - run_start)
        start += count

    # Position ids are left as zeros for every level.
    position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts]

    def _on_device(tensors):
        # clone() detaches the mask slices from the shared full mask before
        # the move; an empty list simply stays empty.
        return [t.clone().to(device) for t in tensors]

    return {
        "attn_mask": _on_device(
            [m.unsqueeze(0).unsqueeze(0) for m in tree_attn_mask_list]
        ),
        "tree_indices": _on_device(tree_indices_list),
        "position_ids": _on_device(position_ids),
        "repeat_nums": repeat_nums,
    }


def reset_past_key_values(passed_key_values):
    """
    Resets the current lengths in the passed key-values to zero.

    For every layer entry, the ``current_length`` of its first two elements
    is filled with zero in place. The docstring of the original code states
    this is meant for evaluating a baseline model.

    Args:
    - passed_key_values (sequence): One entry per layer; each entry must be
      indexable at 0 and 1 with objects exposing a ``current_length`` that
      supports in-place ``fill_`` (presumably a key cache and a value cache
      - confirm against the cache implementation).

    Returns:
    - passed_key_values: The same object that was passed in, after the reset.
    """
    for layer_entry in passed_key_values:
        for j in range(2):
            layer_entry[j].current_length.fill_(0)
    return passed_key_values


if __name__ == "__main__":
    from choices import mc_sim_7b_63

    a = generate_tree_buffers(mc_sim_7b_63)
    print(a)