"""EAGLE speculative-decoding wrapper: pairs a base causal LM with a draft layer."""
import copy
import json
import os
import time

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedModel

from .cnets import Model
from .configs import EConfig
from .kv_cache import initialize_past_key_values
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_mixtral_kv import MixtralForCausalLM as KVMixtralForCausalLM
from .utils import *
class EaModel(nn.Module): | |
def __init__( | |
self, | |
base_model, | |
base_model_name_or_path, | |
ea_model_path, | |
total_token, | |
depth, | |
top_k, | |
threshold, | |
ea_layer_state_dict | |
): | |
super().__init__() | |
self.base_model = base_model | |
self.config = base_model.config | |
self.hidden_size = base_model.lm_head.weight.shape[-1] | |
self.vocab_size = base_model.lm_head.weight.shape[0] | |
self.base_model_name_or_path = base_model_name_or_path | |
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path,use_fast=False) | |
config = EConfig.from_pretrained(ea_model_path) | |
with open(ea_model_path,"r") as f: | |
con=json.loads(f.read()) | |
try: | |
bias=con["bias"] | |
except: | |
bias=True | |
print("draft init") | |
self.ea_layer = Model(config,bias=bias,total_tokens=total_token,depth=depth,top_k=top_k,threshold=threshold) | |
print("draft init end") | |
low_memory=False | |
device = base_model.model.layers[-1].self_attn.q_proj.weight.device | |
if device!=base_model.lm_head.weight.device: | |
self.ea_layer.diff_device = True | |
if not low_memory: | |
# self.ea_layer.head=nn.Linear(base_model.lm_head.in_features,base_model.lm_head.out_features,bias=False) | |
# self.ea_layer.head.weight=copy.deepcopy(base_model.lm_head.weight) | |
# self.ea_layer.head.to(device) | |
self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device) | |
else: | |
self.ea_layer.layer_device = device | |
else: | |
self.ea_layer.diff_device = False | |
self.ea_layer.load_state_dict(ea_layer_state_dict, strict=True) | |
self.ea_layer.to(self.base_model.dtype).to(device) | |
self.ea_layer.init_tree() | |
def get_tokenizer(self): | |
"""Get the tokenizer of the base model. | |
Returns: | |
Tokenizer: The tokenizer of the base model. | |
""" | |
return self.tokenizer | |
def from_pretrained( | |
cls, | |
Type="LLaMA", | |
base_model_path=None, | |
ea_model_path=None, | |
total_token=59, | |
depth=5, | |
top_k=10, | |
threshold=1.0, | |
**kwargs, | |
): | |
#assert Type=="LLaMA" or "Mixtral" | |
Type=AutoConfig.from_pretrained(base_model_path).architectures[0] | |
if Type=='LlamaForCausalLM': | |
base_model = KVLlamaForCausalLM.from_pretrained( | |
base_model_path, **kwargs | |
) | |
else: | |
base_model = KVMixtralForCausalLM.from_pretrained( | |
base_model_path, **kwargs | |
) | |
base_model.cuda() | |
configpath=os.path.join(ea_model_path,"config.json") | |
if not os.path.exists(configpath): | |
configpath = hf_hub_download(ea_model_path, "config.json") | |
load_model_path=os.path.join(ea_model_path, "pytorch_model.bin") | |
if not os.path.exists(load_model_path): | |
load_model_path=hf_hub_download(ea_model_path, "pytorch_model.bin") | |
ea_layer_state_dict = torch.load(load_model_path, | |
map_location="cpu") | |
model = cls( | |
base_model, | |
base_model_path, | |
configpath, | |
total_token, | |
depth, | |
top_k, | |
threshold, | |
ea_layer_state_dict | |
) | |
if total_token==-1: | |
device = model.base_model.model.layers[0].self_attn.q_proj.weight.device | |
cans=[40,48,50,56,60] | |
x=[1,1.05,1.07,1.1,1.13] | |
times=[] | |
for i in range(len(cans)): | |
length = cans[i] | |
input_ids = torch.randint(0, model.config.vocab_size - 200, (1, length)).to(device) | |
torch.cuda.synchronize() | |
start_time = time.time() | |
for _ in range(20): | |
torch.cuda.synchronize() | |
with torch.no_grad(): | |
outputs = model.base_model(input_ids) | |
torch.cuda.synchronize() | |
torch.cuda.synchronize() | |
end_time = time.time() | |
times.append((end_time - start_time) / x[i]) | |
total_token=cans[times.index(min(times))] | |
model.ea_layer.total_tokens=total_token-1 | |
return model | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
past_key_values=None, | |
output_orig=False, | |
position_ids=None, | |
): | |
with torch.inference_mode(): | |
# Pass input through the base model | |
outputs = self.base_model.model( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
past_key_values=past_key_values, | |
position_ids=position_ids, | |
) | |
if output_orig: | |
orig = self.base_model.lm_head(outputs[0]) | |
hidden_states = outputs[0] | |
# if init: | |
# if logits_processor is not None: | |
# logits = orig[:, -1] | |
# logits = logits_processor(None, logits) | |
# probabilities = torch.nn.functional.softmax(logits, dim=1) | |
# token = torch.multinomial(probabilities, 1) | |
# else: | |
# token = torch.argmax(orig[:, -1]) | |
# token = token[None, None] | |
# input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) | |
# # Clone the output hidden states | |
# | |
# draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head) | |
# if output_orig: | |
# return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token | |
# return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token | |
# else: | |
if output_orig: | |
return outputs, orig, hidden_states | |
else: | |
return outputs, hidden_states | |
def eagenerate( | |
self, | |
input_ids, | |
temperature=0.0, | |
top_p=0.0, | |
top_k=0.0, | |
max_new_tokens=512, | |
max_length=2048, | |
log=False, | |
is_llama3=False, | |
): | |
if is_llama3: | |
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") | |
max_length=max_length-self.ea_layer.total_tokens-10 | |
if temperature > 1e-5: | |
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) | |
else: | |
logits_processor = None | |
#assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" | |
# Avoid modifying the input_ids in-place | |
padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device) | |
input_ids = input_ids.clone() | |
self.ea_layer.reset_kv() | |
# Initialize the past key and value states | |
if hasattr(self, "past_key_values"): | |
past_key_values = self.past_key_values | |
past_key_values_data = self.past_key_values_data | |
current_length_data = self.current_length_data | |
# Reset the past key and value states | |
current_length_data.zero_() | |
else: | |
( | |
past_key_values, | |
past_key_values_data, | |
current_length_data, | |
) = initialize_past_key_values(self.base_model) | |
self.past_key_values = past_key_values | |
self.past_key_values_data = past_key_values_data | |
self.current_length_data = current_length_data | |
input_len = input_ids.shape[1] | |
reset_tree_mode(self) | |
draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree( | |
input_ids, self, past_key_values, logits_processor | |
) | |
new_token = 0 | |
for idx in range(max_length): | |
#with Timer("all"): | |
self.base_model.model.tree_mask = tree_mask | |
draft_tokens=draft_tokens.to(input_ids.device) | |
#with Timer("tree_decoding"): | |
logits, hidden_state_new, outputs = tree_decoding( | |
self, | |
draft_tokens, | |
past_key_values, | |
tree_position_ids, | |
input_ids, | |
retrieve_indices, | |
) | |
#retrieve_indices=tree_buffers["retrieve_indices"] | |
#logits = logits[0, retrieve_indices] | |
draft_tokens=torch.cat((draft_tokens,padding),dim=1) | |
candidates=draft_tokens[0,retrieve_indices] | |
best_candidate, accept_length, sample_p = evaluate_posterior( | |
logits, candidates, logits_processor | |
) | |
# print(accept_length) | |
#with Timer("update_inference_inputs"): | |
input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs( | |
input_ids, | |
candidates, | |
best_candidate, | |
accept_length, | |
retrieve_indices, | |
logits_processor, | |
new_token, | |
past_key_values_data, | |
current_length_data, | |
self, | |
hidden_state_new, | |
sample_p | |
) | |
if is_llama3: | |
if stop_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if new_token > max_new_tokens: | |
break | |
if input_ids.shape[1] > max_length: | |
break | |
if not log: | |
return input_ids | |
else: | |
return input_ids, new_token, idx | |
def naivegenerate( | |
self, | |
input_ids, | |
temperature=0.0, | |
top_p=0.0, | |
top_k=0.0, | |
max_new_tokens=512, | |
max_length=2048, | |
log=False, | |
is_llama3=False, | |
): | |
if is_llama3: | |
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") | |
max_length = max_length - self.ea_layer.total_tokens - 10 | |
if temperature > 1e-5: | |
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) | |
else: | |
logits_processor = None | |
# assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" | |
# Avoid modifying the input_ids in-place | |
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device) | |
input_ids = input_ids.clone() | |
self.ea_layer.reset_kv() | |
# Initialize the past key and value states | |
if hasattr(self, "past_key_values"): | |
past_key_values = self.past_key_values | |
past_key_values_data = self.past_key_values_data | |
current_length_data = self.current_length_data | |
# Reset the past key and value states | |
current_length_data.zero_() | |
else: | |
( | |
past_key_values, | |
past_key_values_data, | |
current_length_data, | |
) = initialize_past_key_values(self.base_model) | |
self.past_key_values = past_key_values | |
self.past_key_values_data = past_key_values_data | |
self.current_length_data = current_length_data | |
input_len = input_ids.shape[1] | |
reset_tree_mode(self) | |
outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True) | |
new_token = 0 | |
for idx in range(max_length): | |
if logits_processor is not None: | |
logits = outputs.logits[:, -1] | |
logits = logits_processor(None, logits) | |
probabilities = torch.nn.functional.softmax(logits, dim=-1) | |
input_id = torch.multinomial(probabilities, 1) | |
else: | |
input_id = outputs.logits[:, -1:].argmax(dim=-1) | |
outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values) | |
input_ids = torch.cat([input_ids, input_id], dim=-1) | |
new_token+=1 | |
if is_llama3: | |
if stop_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if new_token > max_new_tokens: | |
break | |
if input_ids.shape[1] > max_length: | |
break | |
if not log: | |
return input_ids | |
else: | |
return input_ids, new_token, idx | |
def ea_generate( | |
self, | |
input_ids, | |
temperature=0.0, | |
top_p=0.0, | |
top_k=0.0, | |
max_new_tokens=512, | |
max_length=2048, | |
log=False, | |
is_llama3=False, | |
): | |
if is_llama3: | |
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") | |
max_length=max_length-self.ea_layer.total_tokens-10 | |
if temperature > 1e-5: | |
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) | |
else: | |
logits_processor = None | |
#assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" | |
# Avoid modifying the input_ids in-place | |
padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device) | |
input_ids = input_ids.clone() | |
self.ea_layer.reset_kv() | |
# Initialize the past key and value states | |
if hasattr(self, "past_key_values"): | |
past_key_values = self.past_key_values | |
past_key_values_data = self.past_key_values_data | |
current_length_data = self.current_length_data | |
# Reset the past key and value states | |
current_length_data.zero_() | |
else: | |
( | |
past_key_values, | |
past_key_values_data, | |
current_length_data, | |
) = initialize_past_key_values(self.base_model) | |
self.past_key_values = past_key_values | |
self.past_key_values_data = past_key_values_data | |
self.current_length_data = current_length_data | |
input_len = input_ids.shape[1] | |
reset_tree_mode(self) | |
draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree( | |
input_ids, self, past_key_values, logits_processor | |
) | |
new_token = 0 | |
for idx in range(max_length): | |
#with Timer("all"): | |
self.base_model.model.tree_mask = tree_mask | |
draft_tokens=draft_tokens.to(input_ids.device) | |
#with Timer("tree_decoding"): | |
logits, hidden_state_new, outputs = tree_decoding( | |
self, | |
draft_tokens, | |
past_key_values, | |
tree_position_ids, | |
input_ids, | |
retrieve_indices, | |
) | |
#retrieve_indices=tree_buffers["retrieve_indices"] | |
#logits = logits[0, retrieve_indices] | |
draft_tokens=torch.cat((draft_tokens,padding),dim=1) | |
candidates=draft_tokens[0,retrieve_indices] | |
best_candidate, accept_length, sample_p = evaluate_posterior( | |
logits, candidates, logits_processor | |
) | |
# print(accept_length) | |
#with Timer("update_inference_inputs"): | |
input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs( | |
input_ids, | |
candidates, | |
best_candidate, | |
accept_length, | |
retrieve_indices, | |
logits_processor, | |
new_token, | |
past_key_values_data, | |
current_length_data, | |
self, | |
hidden_state_new, | |
sample_p | |
) | |
yield input_ids | |
if is_llama3: | |
if stop_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if new_token > max_new_tokens: | |
break | |
if input_ids.shape[1] > max_length: | |
break | |
def naive_generate( | |
self, | |
input_ids, | |
temperature=0.0, | |
top_p=0.0, | |
top_k=0.0, | |
max_new_tokens=512, | |
max_length=2048, | |
log=False, | |
is_llama3=False, | |
): | |
if is_llama3: | |
stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") | |
max_length = max_length - self.ea_layer.total_tokens - 10 | |
if temperature > 1e-5: | |
logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) | |
else: | |
logits_processor = None | |
# assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" | |
# Avoid modifying the input_ids in-place | |
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device) | |
input_ids = input_ids.clone() | |
self.ea_layer.reset_kv() | |
# Initialize the past key and value states | |
if hasattr(self, "past_key_values"): | |
past_key_values = self.past_key_values | |
past_key_values_data = self.past_key_values_data | |
current_length_data = self.current_length_data | |
# Reset the past key and value states | |
current_length_data.zero_() | |
else: | |
( | |
past_key_values, | |
past_key_values_data, | |
current_length_data, | |
) = initialize_past_key_values(self.base_model) | |
self.past_key_values = past_key_values | |
self.past_key_values_data = past_key_values_data | |
self.current_length_data = current_length_data | |
input_len = input_ids.shape[1] | |
reset_tree_mode(self) | |
outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True) | |
new_token = 0 | |
for idx in range(max_length): | |
if logits_processor is not None: | |
logits = outputs.logits[:, -1] | |
logits = logits_processor(None, logits) | |
probabilities = torch.nn.functional.softmax(logits, dim=-1) | |
input_id = torch.multinomial(probabilities, 1) | |
else: | |
input_id = outputs.logits[:, -1:].argmax(dim=-1) | |
outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values) | |
input_ids = torch.cat([input_ids, input_id], dim=-1) | |
new_token += 1 | |
yield input_ids | |
if is_llama3: | |
if stop_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): | |
break | |
if new_token > max_new_tokens: | |
break | |
if input_ids.shape[1] > max_length: | |
break | |