"""Line-level entropy-guided decoding: after every newline the decoder re-ranks
the top-k next-token candidates by how little they shift attention entropy,
and decodes greedily in between."""

import warnings
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Dict, List, NamedTuple, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings("ignore", category=FutureWarning)


class DecoderState(Enum):
    GREEDY_UNTIL_NEWLINE = auto()
    SELECT_AFTER_NEWLINE = auto()
    TERMINATED = auto()


class CacheState(NamedTuple):
    past_key_values: Tuple
    last_position: int


@dataclass
class GenerationState:
    tokens: torch.Tensor
    attention_mask: torch.Tensor
    cache_state: Optional[CacheState] = None
    entropy_diffs: List[float] = field(default_factory=list)
    current_length: int = 0
    _token_buffer: Optional[torch.Tensor] = None
    _attn_buffer: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Pre-allocate buffers so extend() never re-allocates during generation.
        max_length = self.tokens.size(1) + 1024  # headroom for generated tokens
        self._token_buffer = torch.zeros(
            (1, max_length), dtype=self.tokens.dtype, device=self.tokens.device
        )
        self._attn_buffer = torch.ones(
            (1, max_length),
            dtype=self.attention_mask.dtype,
            device=self.attention_mask.device,
        )
        # Copy the prompt tokens and attention mask into the buffers.
        self._token_buffer[:, : self.tokens.size(1)] = self.tokens
        self._attn_buffer[:, : self.attention_mask.size(1)] = self.attention_mask

    def extend(self, new_token: torch.Tensor) -> None:
        """Append one token in place using the pre-allocated buffers."""
        current_len = self.tokens.size(1)
        if current_len >= self._token_buffer.size(1):
            raise RuntimeError(
                "Generation exceeded the pre-allocated buffer "
                f"({self._token_buffer.size(1)} positions)."
            )
        if new_token.dim() == 0:
            new_token = new_token.unsqueeze(0)
        self._token_buffer[:, current_len] = new_token
        self.tokens = self._token_buffer[:, : current_len + 1]
        self.attention_mask = self._attn_buffer[:, : current_len + 1]
        self.current_length += 1


class SpeculativeDecoder:
    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        device: Optional[torch.device] = None,
        max_new_tokens: int = 512,
        k: int = 3,
        use_cache: bool = True,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or next(model.parameters()).device
        self.max_new_tokens = max_new_tokens
        self.k = k
        self.use_cache = use_cache

        # Pre-compute constants. Note: some tokenizers split "\n" into several
        # tokens or use a context-dependent id; the first id is used here.
        self.newline_token = tokenizer.encode("\n", add_special_tokens=False)[0]
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Reusable (k, 1) mask covering the candidate tokens themselves.
        self.batch_attention_mask = torch.ones(
            k, 1, dtype=torch.long, device=self.device
        )

        # Prepare the model for inference.
        model.eval()
        # Enable Flash Attention if the model exposes such a hook
        # (most models don't; the hasattr guard makes this a no-op otherwise).
        if hasattr(model, "enable_flash_attention"):
            try:
                model.enable_flash_attention()
            except Exception as e:
                warnings.warn(f"Failed to enable Flash Attention: {e}")

    @staticmethod
    @torch.jit.script
    def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
        """JIT-compiled Shannon entropy (in bits) over the last dimension."""
        return -torch.sum(probs * torch.log2(probs + 1e-12), dim=-1)

    def set_k(self, k: int):
        self.k = k
        self.batch_attention_mask = torch.ones(
            k, 1, dtype=torch.long, device=self.device
        )

    def prepare_inputs(self, messages: List[Dict[str, str]]) -> torch.Tensor:
        """Tokenize a chat, using the tokenizer's chat template when it has one."""
        if getattr(self.tokenizer, "chat_template", None):
            input_text = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            input_text = (
                "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages)
                + "\nassistant:"
            )
        return self.tokenizer(
            input_text, return_tensors="pt", padding=False
        ).input_ids.to(self.device)

    def select_least_entropic_token(
        self, state: GenerationState
    ) -> Tuple[torch.Tensor, float]:
        """Pick, from the top-k next-token candidates, the one whose addition
        changes the attention entropy the least (single batched forward pass)."""
        with torch.no_grad(), torch.autocast(
            device_type=self.device.type, enabled=self.device.type == "cuda"
        ):
            # Logits for the current context (feed only the last token when a
            # cache is available).
            outputs = self.model(
                input_ids=state.tokens[:, -1:] if state.cache_state else state.tokens,
                attention_mask=state.attention_mask,
                past_key_values=state.cache_state.past_key_values
                if state.cache_state
                else None,
                use_cache=True,
            )
            state.cache_state = (
                CacheState(outputs.past_key_values, state.tokens.size(1))
                if self.use_cache
                else None
            )

            # Top-k candidate tokens for the next position.
            logits = outputs.logits[0, -1]
            _, top_k_indices = torch.topk(F.softmax(logits, dim=-1), self.k)
            batch_tokens = top_k_indices.unsqueeze(1)  # shape (k, 1)

            # Expand the cached prefix across the k candidates (views, no copy).
            # This assumes the cache can be read as per-layer (key, value) tuples.
            if state.cache_state:
                batch_past_kv = [
                    (
                        layer_past[0].expand(self.k, -1, -1, -1),
                        layer_past[1].expand(self.k, -1, -1, -1),
                    )
                    for layer_past in state.cache_state.past_key_values
                ]
                batch_inputs = batch_tokens
                # The mask must cover the cached prefix plus the new token.
                batch_attention_mask = torch.cat(
                    [
                        state.attention_mask.expand(self.k, -1),
                        self.batch_attention_mask,
                    ],
                    dim=-1,
                )
            else:
                batch_past_kv = None
                # Without a cache, re-feed the full prefix for every candidate.
                batch_inputs = torch.cat(
                    [state.tokens.expand(self.k, -1), batch_tokens], dim=1
                )
                batch_attention_mask = torch.ones_like(batch_inputs)

            # Single forward pass scores all k candidates at once.
            batch_outputs = self.model(
                input_ids=batch_inputs,
                attention_mask=batch_attention_mask,
                past_key_values=batch_past_kv,
                use_cache=True,
                output_attentions=True,
            )

            # Attention distribution of the candidate position at a middle layer.
            middle_layer = len(batch_outputs.attentions) // 2
            batch_attn_probs = F.softmax(
                batch_outputs.attentions[middle_layer][:, :, -1, :], dim=-1
            )

            # Entropy without vs. with the candidate's own position included.
            old_entropy = self.calculate_entropy(batch_attn_probs[:, :, :-1])
            new_entropy = self.calculate_entropy(batch_attn_probs)

            # Variance-normalized entropy change, averaged over heads.
            entropy_var = (
                torch.var(
                    torch.stack([old_entropy, new_entropy]), dim=0, keepdim=True
                )
                + 1e-6
            )
            diffs = ((new_entropy - old_entropy) / entropy_var).mean(dim=-1).squeeze(0)

            min_idx = diffs.argmin()
            return top_k_indices[min_idx].unsqueeze(0), diffs[min_idx].item()

    def greedy_decode(self, state: GenerationState) -> torch.Tensor:
        """Greedy next-token selection with cache reuse."""
        with torch.no_grad(), torch.autocast(
            device_type=self.device.type, enabled=self.device.type == "cuda"
        ):
            outputs = self.model(
                input_ids=state.tokens[:, -1:] if state.cache_state else state.tokens,
                attention_mask=state.attention_mask,
                past_key_values=state.cache_state.past_key_values
                if state.cache_state
                else None,
                use_cache=True,
            )
            state.cache_state = (
                CacheState(outputs.past_key_values, state.tokens.size(1))
                if self.use_cache
                else None
            )
            return outputs.logits[0, -1].argmax()

    def __call__(self, messages: List[Dict[str, str]]) -> Tuple[str, float]:
        """Main decoding loop driving the newline-triggered state machine."""
        input_ids = self.prepare_inputs(messages)
        state = GenerationState(
            tokens=input_ids, attention_mask=torch.ones_like(input_ids)
        )
        current_state = DecoderState.SELECT_AFTER_NEWLINE

        while (
            current_state != DecoderState.TERMINATED
            and state.current_length < self.max_new_tokens
        ):
            if current_state == DecoderState.SELECT_AFTER_NEWLINE:
                next_token, entropy_diff = self.select_least_entropic_token(state)
                state.entropy_diffs.append(entropy_diff)
                current_state = DecoderState.GREEDY_UNTIL_NEWLINE
            else:  # GREEDY_UNTIL_NEWLINE
                next_token = self.greedy_decode(state)

            # Transition checks apply to both branches so an EOS or newline
            # chosen during selection is handled as well.
            if next_token.item() == self.tokenizer.eos_token_id:
                current_state = DecoderState.TERMINATED
            elif next_token.item() == self.newline_token:
                current_state = DecoderState.SELECT_AFTER_NEWLINE

            state.extend(next_token)

        generated_ids = state.tokens[0, input_ids.size(1):]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Score: average entropy change across line starts, scaled by how much
        # of the generation budget was used.
        if state.entropy_diffs:
            avg_entropy_diff = torch.tensor(state.entropy_diffs).mean().item()
        else:
            avg_entropy_diff = 1.0
        completion_ratio = len(generated_ids) / self.max_new_tokens
        score = (1.0 / (avg_entropy_diff / 100 + 1e-12)) * completion_ratio
        # A negative average diff would make the fractional power below return
        # a complex number; clamp to zero so the score stays real.
        score = max(score, 0.0)
        return generated_text, round(score ** 0.33, 4)
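

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the decoder above). The
# checkpoint name "gpt2" is a placeholder assumption; any Hugging Face causal
# LM whose cache can be read as per-layer (key, value) tuples should work, and
# models without a chat template fall back to plain role-prefixed prompts.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = "gpt2"  # hypothetical placeholder; substitute your own model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    decoder = SpeculativeDecoder(model, tokenizer, max_new_tokens=128, k=3)
    messages = [{"role": "user", "content": "List three uses of entropy in NLP."}]

    text, score = decoder(messages)
    print(f"score={score}")
    print(text)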