import torch import torch.nn as nn from transformers import PreTrainedModel, PretrainedConfig from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM from .utils import * from .kv_cache import initialize_past_key_values from .choices import mc_sim_7b_63 from transformers import AutoTokenizer import os from huggingface_hub import hf_hub_download from .cnets import Model from .configs import EConfig class ResBlock(nn.Module): """ A Residual Block module. This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection. Args: hidden_size (int): The size of the hidden layers in the block. """ def __init__(self, hidden_size): super().__init__() self.linear = nn.Linear(hidden_size, hidden_size) # Initialize as an identity mapping torch.nn.init.zeros_(self.linear.weight) # Use SiLU activation to keep consistent with the Llama model self.act = nn.SiLU() def forward(self, x): """ Forward pass of the ResBlock. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Output after the residual connection and activation. """ return x + self.act(self.linear(x)) class EaModel(nn.Module): def __init__( self, base_model, base_model_name_or_path, ea_model_path, ): super().__init__() self.base_model = base_model self.config = base_model.config self.hidden_size = base_model.lm_head.weight.shape[-1] self.vocab_size = base_model.lm_head.weight.shape[0] self.base_model_name_or_path = base_model_name_or_path self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) config = EConfig.from_pretrained(ea_model_path) self.ea_layer = Model(config) device = base_model.model.layers[-1].self_attn.q_proj.weight.device self.ea_layer.to(torch.float16).to(device) self.ea_layer.init_tree() def get_tokenizer(self): """Get the tokenizer of the base model. Returns: Tokenizer: The tokenizer of the base model. """ return self.tokenizer @classmethod def from_pretrained( cls, base_model_path=None, ea_model_path=None, **kwargs, ): base_model = KVLlamaForCausalLM.from_pretrained( base_model_path, **kwargs ) model = cls( base_model, base_model_path, ea_model_path ) ea_layer_state_dict = torch.load(os.path.join(ea_model_path,"pytorch_model.bin"), map_location=base_model.device) model.ea_layer.load_state_dict(ea_layer_state_dict, strict=False) return model def forward( self, input_ids=None, attention_mask=None, labels=None, past_key_values=None, output_orig=False, position_ids=None, init=True, logits_processor=None ): with torch.inference_mode(): # Pass input through the base model outputs = self.base_model.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, ) if output_orig: orig = self.base_model.lm_head(outputs[0]) hidden_states = outputs[0].clone() if init: if logits_processor is not None: logits=orig[:, -1] logits=logits_processor(None,logits) probabilities = torch.nn.functional.softmax(logits, dim=1) token=torch.multinomial(probabilities, 1) else: token = torch.argmax(orig[:,-1]) token=token[None,None] input_ids=torch.cat((input_ids,token.to(input_ids.device)),dim=1) # Clone the output hidden states ea_logits = self.ea_layer.topK_genrate(hidden_states,input_ids,self.base_model.lm_head,logits_processor) if output_orig: return ea_logits, outputs, orig,hidden_states,token return ea_logits,hidden_states,token else: if output_orig: return outputs,orig,hidden_states @torch.no_grad() def eagenerate( self, input_ids, temperature=0.0, top_p=0.0, top_k=0.0, max_new_tokens=512, max_length=2048, tree_choices=mc_sim_7b_63, ): if temperature>1e-5: logits_processor=prepare_logits_processor(temperature=temperature,top_p=top_p,top_k=top_k) else: logits_processor=None assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" # Avoid modifying the input_ids in-place input_ids = input_ids.clone() self.ea_layer.reset_kv() if hasattr(self, "tree_choices") and self.tree_choices == tree_choices: tree_buffers = self.tree_buffers else: tree_buffers = generate_tree_buffers( tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device ) self.tree_buffers = tree_buffers self.tree_choices = tree_choices # Initialize the past key and value states if hasattr(self, "past_key_values"): past_key_values = self.past_key_values past_key_values_data = self.past_key_values_data current_length_data = self.current_length_data # Reset the past key and value states current_length_data.zero_() else: ( past_key_values, past_key_values_data, current_length_data, ) = initialize_past_key_values(self.base_model) self.past_key_values = past_key_values self.past_key_values_data = past_key_values_data self.current_length_data = current_length_data input_len = input_ids.shape[1] reset_tree_mode(self) tree_logits, logits, hidden_state, sample_token = initialize_tree( input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor ) new_token = 0 for idx in range(max_length): candidates, cart_candidates_prob, tree_candidates = generate_candidates( tree_logits, tree_buffers["tree_indices"], tree_buffers["retrieve_indices"], sample_token, logits_processor ) logits, hidden_state_new, outputs = tree_decoding( self, tree_candidates, past_key_values, tree_buffers["tree_position_ids"], input_ids, tree_buffers["retrieve_indices"], ) best_candidate, accept_length, sample_p = evaluate_posterior( logits, candidates, logits_processor, cart_candidates_prob ) input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs( input_ids, candidates, best_candidate, accept_length, tree_buffers["retrieve_indices"], logits_processor, logits, tree_logits, new_token, past_key_values_data, current_length_data, self, hidden_state, hidden_state_new, sample_p ) if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): return input_ids if new_token > max_new_tokens: return input_ids if input_ids.shape[1] > max_length: return input_ids @torch.no_grad() def ea_generate( self, input_ids, temperature=0.0, top_p=0.0, top_k=0.0, max_steps=512, tree_choices=mc_sim_7b_63, ): if temperature > 1e-5: logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) else: logits_processor = None assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" # Avoid modifying the input_ids in-place input_ids = input_ids.clone() self.ea_layer.reset_kv() if hasattr(self, "tree_choices") and self.tree_choices == tree_choices: tree_buffers = self.tree_buffers else: tree_buffers = generate_tree_buffers( tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device ) self.tree_buffers = tree_buffers self.tree_choices = tree_choices # Initialize the past key and value states if hasattr(self, "past_key_values"): past_key_values = self.past_key_values past_key_values_data = self.past_key_values_data current_length_data = self.current_length_data # Reset the past key and value states current_length_data.zero_() else: ( past_key_values, past_key_values_data, current_length_data, ) = initialize_past_key_values(self.base_model) self.past_key_values = past_key_values self.past_key_values_data = past_key_values_data self.current_length_data = current_length_data input_len = input_ids.shape[1] reset_tree_mode(self) tree_logits, logits, hidden_state, sample_token = initialize_tree( input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor ) new_token = 0 for idx in range(max_steps): candidates, cart_candidates_prob, tree_candidates = generate_candidates( tree_logits, tree_buffers["tree_indices"], tree_buffers["retrieve_indices"], sample_token, logits_processor ) logits, hidden_state_new, outputs = tree_decoding( self, tree_candidates, past_key_values, tree_buffers["tree_position_ids"], input_ids, tree_buffers["retrieve_indices"], ) best_candidate, accept_length, sample_p = evaluate_posterior( logits, candidates, logits_processor, cart_candidates_prob ) input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs( input_ids, candidates, best_candidate, accept_length, tree_buffers["retrieve_indices"], logits_processor, logits, tree_logits, new_token, past_key_values_data, current_length_data, self, hidden_state, hidden_state_new, sample_p ) yield input_ids if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): break if new_token > 1024: break if input_ids.shape[1] > 1960: break @torch.no_grad() def naive_generate( self, input_ids, temperature=0.0, top_p=0.0, top_k=0.0, max_steps=512, tree_choices=mc_sim_7b_63, ): if temperature > 1e-5: logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) else: logits_processor = None assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" # Avoid modifying the input_ids in-place input_ids = input_ids.clone() self.ea_layer.reset_kv() if hasattr(self, "tree_choices") and self.tree_choices == tree_choices: tree_buffers = self.tree_buffers else: tree_buffers = generate_tree_buffers( tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device ) self.tree_buffers = tree_buffers self.tree_choices = tree_choices # Initialize the past key and value states if hasattr(self, "past_key_values"): past_key_values = self.past_key_values past_key_values_data = self.past_key_values_data current_length_data = self.current_length_data # Reset the past key and value states current_length_data.zero_() else: ( past_key_values, past_key_values_data, current_length_data, ) = initialize_past_key_values(self.base_model) self.past_key_values = past_key_values self.past_key_values_data = past_key_values_data self.current_length_data = current_length_data input_len = input_ids.shape[1] reset_tree_mode(self) outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True) new_token = 0 for idx in range(max_steps): input_id = outputs.logits[:, -1:].argmax(dim=-1) outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values) input_ids = torch.cat([input_ids, input_id], dim=-1) yield input_ids if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): break if new_token > 1024: break if input_ids.shape[1] > 1960: break