import copy
import random
import time
from typing import List, Tuple

import torch
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# TODO
# from transformers import LlamaTokenizer
# tokenizer=LlamaTokenizer.from_pretrained("/home/lyh/weights/hf/vicuna_v13/7B/")

TOPK = 10  # topk for sparse tree


class Timer:
    """Context manager that prints CUDA-synchronized wall-clock time for a block."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        torch.cuda.synchronize()
        self.start = time.perf_counter()

    def __exit__(self, exc_type, exc_value, traceback):
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - self.start
        print(f'{self.name} took {elapsed} seconds')


def prepare_logits_processor(
        temperature: float = 0.0,
        repetition_penalty: float = 0.0,
        top_p: float = 0.0,
        top_k: int = 0
) -> LogitsProcessorList:
    """Build the warper/processor chain for sampling; returns None for greedy decoding."""
    processor_list = LogitsProcessorList()
    if temperature > 1e-5:
        if temperature != 1.0:
            processor_list.append(TemperatureLogitsWarper(temperature))
        if repetition_penalty > 1.0:
            processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
        if 1e-8 <= top_p < 1.0:
            processor_list.append(TopPLogitsWarper(top_p))
        if top_k > 0:
            processor_list.append(TopKLogitsWarper(top_k))
        return processor_list
    # temperature <= 1e-5 means greedy decoding; callers test `logits_processor is None`.
    return None


# test_processor = prepare_logits_processor(
#     0.0, 0.0, -1, 1
# )
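
# Hedged usage sketch (not part of the original file): how sampling settings map to warpers.
# prepare_logits_processor(temperature=0.7, top_p=0.9) returns a LogitsProcessorList holding
# TemperatureLogitsWarper(0.7) and TopPLogitsWarper(0.9); prepare_logits_processor(temperature=0.0)
# returns None, which the rest of this file treats as the greedy-decoding case.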


def pad_path(path: List[int], length: int, pad_value: int = -2) -> List[int]:
    """
    Pad the given path list with a specific value up to a specified length.

    Parameters:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path but padded to the desired length.

    Example:
    >>> pad_path([1, 2, 3], 5)
    [1, 2, 3, -2, -2]

    Note:
    If the given path is already at least as long as the specified length,
    no padding is added and an unpadded copy of the path is returned.
    """
    # Append (length - len(path)) padding values; the multiplier collapses to an
    # empty list when the path is already long enough.
    return path + [pad_value] * (length - len(path))


def generate_tree_buffers(tree_choices, device="cuda"):
    def custom_sort(lst):
        # Padded entries (< 0) sort last; `maxitem` is assigned further down,
        # before custom_sort is first called.
        sort_keys = []
        for i in range(len(lst)):
            sort_keys.append(lst[i] if lst[i] >= 0 else maxitem)
        return sort_keys

    with Timer("sort"):
        sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
    tree_len = len(sorted_tree_choices) + 1

    # Initialize depth_counts to keep track of how many choices have a particular depth
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth

    # Attention mask: every node attends to itself, the root, and its ancestors.
    tree_attn_mask = torch.eye(tree_len, tree_len)
    tree_attn_mask[:, 0] = 1
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            # retrieve ancestor position
            if len(cur_tree_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
            tree_attn_mask[j + start + 1, ancestor_idx] = 1
        start += depth_counts[i]
    # Map each tree node to its index in the flattened per-level top-K expansion.
    tree_indices = torch.zeros(tree_len, dtype=torch.long)
    p_indices = [0 for _ in range(tree_len - 1)]
    b_indices = [[] for _ in range(tree_len - 1)]
    tree_indices[0] = 0
    start = 0
    bias = 0
    for i in range(len(depth_counts)):
        inlayer_bias = 0
        b = []
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            cur_parent = cur_tree_choice[:-1]
            if j != 0:
                if cur_parent != parent:
                    bias += 1
                    inlayer_bias += 1
                    parent = cur_parent
                    b = []
            else:
                parent = cur_parent
            tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1
            p_indices[start + j] = inlayer_bias
            if len(b) > 0:
                b_indices[start + j] = copy.deepcopy(b)
            else:
                b_indices[start + j] = []
            b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1)
        start += depth_counts[i]

    p_indices = [-1] + p_indices
    tree_position_ids = torch.zeros(tree_len, dtype=torch.long)
    start = 0
    for i in range(len(depth_counts)):
        tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
        start += depth_counts[i]
    # Collect one root-to-leaf index path per leaf, deepest leaves first.
    retrieve_indices_nest = []
    retrieve_paths = []
    for i in range(len(sorted_tree_choices)):
        cur_tree_choice = sorted_tree_choices[-i - 1]
        retrieve_indice = []
        if cur_tree_choice in retrieve_paths:
            continue
        else:
            for c in range(len(cur_tree_choice)):
                retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]))
                retrieve_paths.append(cur_tree_choice[:c + 1])
        retrieve_indices_nest.append(retrieve_indice)
    max_length = max([len(x) for x in retrieve_indices_nest])
    retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    retrieve_indices = retrieve_indices + 1
    retrieve_indices = torch.cat(
        [torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices],
        dim=1,
    )

    maxitem = retrieve_indices.max().item() + 5

    retrieve_indices = retrieve_indices.tolist()
    retrieve_indices = sorted(retrieve_indices, key=custom_sort)
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    # Aggregate the generated buffers into a dictionary
    tree_buffers = {
        "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0),
        "tree_indices": tree_indices,
        "tree_position_ids": tree_position_ids,
        "retrieve_indices": retrieve_indices,
    }

    # Move the tensors in the dictionary to the specified device
    tree_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor)
        else torch.tensor(v, device=device)
        for k, v in tree_buffers.items()
    }
    return tree_buffers
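
# Hedged usage sketch (illustrative only; `tiny_tree` is a made-up 4-node tree, not one of the
# shipped tree configurations). Each entry is a path of per-level top-k ranks:
#   tiny_tree = [(0,), (1,), (0, 0), (0, 1)]
#   buffers = generate_tree_buffers(tiny_tree, device="cuda")
# gives tree_len = 5 (4 nodes + root), so
#   buffers["tree_attn_mask"].shape == (1, 1, 5, 5)     # each node sees the root and its ancestors
#   buffers["tree_position_ids"]    -> [0, 1, 1, 2, 2]  # depth of each node
#   buffers["retrieve_indices"]     # one padded root-to-leaf index path per row
# Note the Timer above calls torch.cuda.synchronize(), so this assumes a CUDA-capable machine.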


def initialize_tree0(input_ids, model, past_key_values, logits_processor):
    draft_tokens, retrieve_indices, tree_mask, tree_position_ids, outputs, logits, hidden_state, sample_token = model(
        input_ids, past_key_values=past_key_values, output_orig=True, logits_processor=logits_processor
    )

    # if logits_processor is not None:
    #     logits = orig[:, -1]
    #     logits = logits_processor(None, logits)
    #     probabilities = torch.nn.functional.softmax(logits, dim=1)
    #     token = torch.multinomial(probabilities, 1)
    # else:
    #     token = torch.argmax(orig[:, -1])
    #     token = token[None, None]
    # input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
    # # Clone the output hidden states
    #
    # draft_tokens, retrieve_indices, tree_mask, tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head)
    # if output_orig:
    #     return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, outputs, orig, hidden_states, token
    # return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, hidden_states, token

    return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token


def initialize_tree(input_ids, model, past_key_values, logits_processor):
    # Prefill: run the base model over the prompt and keep the original logits.
    outputs, orig, hidden_states = model(
        input_ids, past_key_values=past_key_values, output_orig=True
    )

    # Sample (or greedily pick) the first new token from the last-position logits.
    if logits_processor is not None:
        logits = orig[:, -1]
        logits = logits_processor(None, logits)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        token = torch.multinomial(probabilities, 1)
    else:
        token = torch.argmax(orig[:, -1])
        token = token[None, None]
    input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)

    # Draft a tree of candidate continuations with the EAGLE head.
    draft_tokens, retrieve_indices, tree_mask, tree_position_ids = model.ea_layer.topK_genrate(
        hidden_states, input_ids, model.base_model.lm_head, logits_processor
    )
    return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, orig, hidden_states, token


def reset_tree_mode(
        model,
):
    model.base_model.model.tree_mask = None
    model.base_model.model.tree_mode = None


def reset_past_key_values(passed_key_values: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Resets the current lengths in the passed key-values to zero.

    This function is designed to be used during the evaluation of a baseline model.
    It iterates through each layer's key-values and sets their current lengths to zero,
    effectively resetting their state.

    Args:
    - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.

    Returns:
    - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
    """
    for i in range(len(passed_key_values)):
        for j in range(2):
            passed_key_values[i][j].current_length.fill_(0)
    return passed_key_values
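
# Hedged usage sketch: when evaluating the baseline (non-tree) model, a preallocated KV cache
# can be reused across prompts by zeroing its bookkeeping instead of reallocating it, e.g.
#   past_key_values = reset_past_key_values(past_key_values)
# This assumes the project's custom KV-cache objects, whose per-layer entries expose a
# `current_length` tensor; it is not the standard Hugging Face cache interface.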


def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor):
    sample_token = sample_token.to(tree_indices.device)

    candidates_logit = sample_token[0]
    candidates_tree_logits = tree_logits

    # Flatten the sampled token plus the draft-tree tokens into one candidate pool.
    candidates = torch.cat([candidates_logit, candidates_tree_logits.view(-1)], dim=-1)

    # Map each tree node to its token in the candidate pool.
    tree_candidates = candidates[tree_indices]

    # Append a -1 sentinel so the padded entries of retrieve_indices (-1) select it.
    tree_candidates_ext = torch.cat(
        [tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device) - 1], dim=0)

    # Lay out every root-to-leaf path as one row of candidate tokens.
    cart_candidates = tree_candidates_ext[retrieve_indices]

    # Unsqueeze the tree candidates for dimension consistency.
    tree_candidates = tree_candidates.unsqueeze(0)
    return cart_candidates, tree_candidates
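
# Tiny worked sketch (illustrative values, not from the source): with
#   retrieve_indices = [[0, 1, 3],
#                       [0, 2, -1]]
# and tree_candidates_ext = [t_root, t_a, t_b, t_c, -1], cart_candidates becomes
#   [[t_root, t_a, t_c],
#    [t_root, t_b, -1]]
# i.e. one padded root-to-leaf token path per row, ready for evaluate_posterior.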


def tree_decoding(
        model,
        tree_candidates,
        past_key_values,
        tree_position_ids,
        input_ids,
        retrieve_indices,
):
    # Offset the tree positions by the current sequence length.
    position_ids = tree_position_ids + input_ids.shape[1]

    # One forward pass scores every node of the draft tree under the tree attention mask.
    outputs, tree_logits, hidden_state = model(
        tree_candidates,
        output_orig=True,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    # Re-arrange the per-node logits into per-path rows matching cart_candidates.
    logits = tree_logits[0, retrieve_indices]
    return logits, hidden_state, outputs


def evaluate_posterior(
        logits: torch.Tensor,
        candidates: torch.Tensor,
        logits_processor,
):
    """
    Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.

    Depending on whether a logits_processor is supplied, the function either uses greedy
    verification (exact argmax match) or rejection-style sampling to decide how much of
    each candidate path is accepted.

    Args:
    - logits (torch.Tensor): Predicted logits of shape (num_paths, path_length, vocab_size).
    - candidates (torch.Tensor): Candidate token sequences, one root-to-leaf path per row.
    - logits_processor: Warper chain from prepare_logits_processor, or None for greedy decoding.

    Returns:
    - best_candidate (torch.Tensor): Index of the chosen best candidate.
    - accept_length (int): Number of accepted tokens after the already-known first token.
    - sample_p / logits: Distribution (or logits) used to sample the bonus token that follows the accepted prefix.
    """
    # Greedy decoding based on temperature value
    if logits_processor is None:
        # Find the tokens that match the maximum logits for each position in the sequence
        posterior_mask = (
                candidates[:, 1:].to(logits.device) == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length, logits[best_candidate, accept_length]
    else:
        accept_length = 1
        accept_cand = candidates[0][:1]
        best_candidate = 0
        for i in range(1, candidates.shape[1]):
            # Stop as soon as the previous position failed to accept a token.
            if i != accept_length:
                break
            adjustflag = False
            is_eq = (candidates[:, :accept_length] == accept_cand).all(dim=1)
            fi = torch.nonzero(is_eq, as_tuple=True)[0][0]
            gt_logits = logits[fi, i - 1][None]
            gt_logits = logits_processor(None, gt_logits)[0]
            gtp = torch.softmax(gt_logits, dim=0)
            candidates_set = []
            for j in range(candidates.shape[0]):
                if is_eq[j]:
                    x = candidates[j, i]
                    xi = x.item()
                    if xi in candidates_set or xi == -1:
                        continue
                    candidates_set.append(xi)
                    r = random.random()
                    px = gtp[xi]
                    qx = 1.0
                    acp = px / qx
                    if r <= acp:
                        accept_cand = torch.cat((accept_cand, x[None]), dim=0)
                        accept_length += 1
                        best_candidate = j
                        break
                    else:
                        # Rejected: zero out this token and renormalize before trying the next sibling.
                        gtp[xi] = 0
                        gtp = gtp / gtp.sum()
                        adjustflag = True
        if adjustflag and accept_length != candidates.shape[1]:
            sample_p = gtp
        else:
            gt_logits = logits[best_candidate, accept_length - 1]
            sample_p = torch.softmax(gt_logits, dim=0)
        return torch.tensor(best_candidate), accept_length - 1, sample_p
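
# Worked sketch of the greedy branch (illustrative values, not from the source): suppose two
# candidate paths are [t0, a, b, c] and [t0, a, d, e], and argmax over logits[:, :-1] predicts
# [a, b, x] for the first path and [a, y, z] for the second. The match masks are [1, 1, 0] and
# [1, 0, 0]; the cumulative product keeps only the matching prefix, so the per-path accept
# lengths are 2 and 1. Hence best_candidate = 0, accept_length = 2, and logits[0, 2] supplies
# the distribution for the bonus token that follows the accepted prefix.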


def update_inference_inputs(
        input_ids,
        candidates,
        best_candidate,
        accept_length,
        retrieve_indices,
        logits_processor,
        new_token,
        past_key_values_data_list,
        current_length_data,
        model,
        hidden_state_new,
        sample_p
):
    prev_input_len = input_ids.shape[1]
    # Map the best candidate indices to the original indices in the sequence
    select_indices = (
            retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
    )
    # Append the tokens from the best candidate to the input sequence
    input_ids = torch.cat(
        [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
    )
    # Update the past key values based on the selected tokens
    # Source tensor that contains relevant past information based on the selected candidate
    for past_key_values_data in past_key_values_data_list:
        tgt = past_key_values_data[..., select_indices.to(past_key_values_data.device), :]
        # Destination tensor where the relevant past information will be stored
        dst = past_key_values_data[..., prev_input_len: prev_input_len + tgt.shape[-2], :]
        # Copy relevant past information from the source to the destination
        dst.copy_(tgt, non_blocking=True)

    # Update the current length tensor (currently only support batch size is 1)
    current_length_data.fill_(prev_input_len + tgt.shape[-2])

    # Gather the hidden states of the accepted path for the next drafting step.
    retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices]
    accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1]
    # token=model.base_model.lm_head(accept_hidden_state_new[:,-1]).argmax()
    # token=token[None,None]

    # Sample (or greedily pick) the bonus token from the distribution returned by evaluate_posterior.
    prob = sample_p
    if logits_processor is not None:
        token = torch.multinomial(prob, 1)
        token = token[None]
    else:
        token = torch.argmax(prob)
        token = token[None, None]
    # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)

    # Draft the next tree of candidates from the accepted hidden states.
    draft_tokens, retrieve_indices, tree_mask, tree_position_ids = model.ea_layer.topK_genrate(
        accept_hidden_state_new,
        input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
        head=model.base_model.lm_head, logits_processor=logits_processor
    )

    new_token += accept_length + 1

    return input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, None, token
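
# Worked sketch of the KV-cache compaction above (illustrative numbers, not from the source):
# suppose the sequence so far has prev_input_len = 10 and the tree pass wrote its draft nodes
# into cache slots 10, 11, 12, ... If the accepted path occupies tree slots [0, 3, 7], then
# select_indices = [10, 13, 17]; those three rows are copied down to slots 10..12 and
# current_length_data is reset to 13, discarding the KV entries of the rejected nodes.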


if __name__ == "__main__":
    # Quick smoke test for prepare_logits_processor.
    logits = torch.randn(1, 5)
    tp = prepare_logits_processor(0.9, 0, 0.9, 0)
    l = tp(None, logits)
    if tp is None:
        print(tp)