import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast


class AdaptiveRMSNorm(nn.Module):
    """
    Adaptive RMSNorm layer whose per-channel scaling is modulated by a conditioning input.
    """

    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        # Standard learnable RMSNorm gain (same shape as the pretrained norm weights).
        self.weight = nn.Parameter(torch.ones(normalized_shape))

        # Projects the conditioning vector to an input-dependent gain.
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)

    def forward(self, x, adapt_input):
        # adapt_input: (batch, adaptive_dim) -> per-sample gain, broadcast over the sequence dim.
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)

        # RMS normalization: divide by the root-mean-square of the features (not the L2 norm).
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)

        return self.weight * norm_x * gamma


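# Quick sanity check (illustrative sizes only, not part of the model): AdaptiveRMSNorm keeps the
# input shape and applies a gain derived from a per-sample conditioning vector.
_x = torch.randn(2, 5, 16)
_gated = AdaptiveRMSNorm(normalized_shape=16, adaptive_dim=16)(_x, _x.mean(dim=1))
assert _gated.shape == _x.shape

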
class TokenMixing(nn.Module):
    """
    Token Mixing layer that performs a depthwise convolution across the sequence dimension.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # groups=hidden_size makes the convolution depthwise: each channel is mixed only
        # along the sequence axis, never across channels.
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,
            groups=hidden_size,
        )

    def forward(self, x):
        # (batch, seq, hidden) -> (batch, hidden, seq) for Conv1d, then back.
        x = x.transpose(1, 2)
        x = self.token_mixing(x)
        x = x.transpose(1, 2)
        return x


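# Quick sanity check (illustrative sizes only): the depthwise convolution mixes information
# along the sequence axis while preserving the (batch, seq, hidden) shape.
assert TokenMixing(hidden_size=16)(torch.randn(2, 5, 16)).shape == (2, 5, 16)

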
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block that adaptively recalibrates channel-wise features.
    """

    def __init__(self, hidden_size, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // reduction, hidden_size, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Squeeze: average over the sequence dimension; excite: per-channel gate in (0, 1).
        y = x.mean(dim=1)
        y = self.fc(y)
        y = y.unsqueeze(1)
        return x * y


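# Quick sanity check (illustrative sizes only): the SE block rescales each channel by a gate
# in (0, 1), so the output shape matches the input.
assert SEBlock(hidden_size=16, reduction=4)(torch.randn(2, 5, 16)).shape == (2, 5, 16)

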
class DifferentialSelfAttention(nn.Module):
    """
    Self-attention layer with a differential attention mechanism: two softmax attention maps
    are computed from split query/key halves and subtracted with a learned weight lambda.
    Supports grouped-query key/value projections, past_key_value caching, and additive
    attention masks.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        assert self.head_dim * self.num_heads == self.hidden_size, \
            "hidden_size must be divisible by num_attention_heads"

        # Grouped-query attention: keys/values use fewer heads than queries, which matches
        # the shapes of the pretrained Llama checkpoint.
        self.num_kv_heads = getattr(config, "num_key_value_heads", self.num_heads)
        self.kv_dim = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

        # Learnable reparameterization of the differential weight lambda.
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_init = nn.Parameter(torch.tensor(0.5))

        self.sub_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        # position_ids is accepted for interface compatibility; this attention variant does not
        # apply rotary position embeddings.
        batch_size, seq_length, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Append cached keys/values from previous decoding steps.
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        present_key_value = (key_states, value_states) if use_cache else None

        # Expand grouped key/value heads so every query head has a matching key/value head.
        if self.num_kv_heads != self.num_heads:
            repeat = self.num_heads // self.num_kv_heads
            key_states = key_states.repeat_interleave(repeat, dim=1)
            value_states = value_states.repeat_interleave(repeat, dim=1)

        # Differential attention: split queries and keys into two halves and build two maps.
        q1, q2 = torch.chunk(query_states, 2, dim=-1)
        k1, k2 = torch.chunk(key_states, 2, dim=-1)

        attn_scores1 = torch.matmul(q1, k1.transpose(-2, -1))
        attn_scores2 = torch.matmul(q2, k2.transpose(-2, -1))

        if attention_mask is not None:
            # The mask is expected to be additive (0 for visible positions, a large negative
            # value for masked ones).
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            attention_mask = attention_mask.to(dtype=attn_scores1.dtype)
            attn_scores1 = attn_scores1 + attention_mask
            attn_scores2 = attn_scores2 + attention_mask

        attn_probs1 = nn.functional.softmax(attn_scores1, dim=-1, dtype=torch.float32).to(attn_scores1.dtype)
        attn_probs2 = nn.functional.softmax(attn_scores2, dim=-1, dtype=torch.float32).to(attn_scores2.dtype)

        # lambda is reparameterized through two pairs of learned vectors plus an offset.
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Subtracting the second map cancels common-mode attention noise.
        attn_probs = attn_probs1 - lambda_full * attn_probs2

        attn_output = torch.matmul(attn_probs, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        attn_output = self.sub_layer_norm(attn_output)

        attn_probs_return = attn_probs if output_attentions else None

        return attn_output, present_key_value, attn_probs_return


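# Minimal sketch (toy LlamaConfig, random weights; for illustration only): the layer preserves
# the hidden size and returns a key/value cache when use_cache=True.
_cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
_attn = DifferentialSelfAttention(_cfg)
_out, _cache, _ = _attn(torch.randn(2, 8, 64), use_cache=True)
assert _out.shape == (2, 8, 64) and _cache[0].shape == (2, 2, 8, 16)

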
class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Modified Llama decoder layer incorporating DifferentialSelfAttention,
    AdaptiveRMSNorm, TokenMixing, and an SEBlock. The feed-forward MLP is
    reused from the original pretrained layer.
    """

    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size

        # Replace the original self-attention with the differential variant.
        self.self_attn = DifferentialSelfAttention(config)

        # Keep the original (pretrained) feed-forward network.
        self.mlp = original_layer.mlp

        # Adaptive RMSNorms conditioned on a per-sequence summary vector.
        self.input_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )

        # Extra sequence-mixing and channel-recalibration blocks.
        self.token_mixing = TokenMixing(self.hidden_size)
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        # Conditioning vector for the adaptive norms: mean over the sequence dimension.
        adapt_input = hidden_states.mean(dim=1)

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states, adapt_input)

        attn_output, present_key_value, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )

        hidden_states = residual + attn_output

        # Token mixing with a residual connection.
        token_mixed = self.token_mixing(hidden_states)
        hidden_states = hidden_states + token_mixed

        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)

        # Note: unlike the stock pre-norm layer, the MLP residual branches off *after*
        # the post-attention norm.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)

        # Channel-wise recalibration of the MLP output.
        hidden_states = self.se_block(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if use_cache:
            outputs += (present_key_value,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


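# Minimal sketch (toy sizes; the SimpleNamespace is a hypothetical stand-in for a pretrained
# layer, since only its .mlp attribute is reused): one modified layer maps
# (batch, seq, hidden) to the same shape.
import types
_stub = types.SimpleNamespace(mlp=nn.Sequential(nn.Linear(64, 128), nn.SiLU(), nn.Linear(128, 64)))
_layer = ModifiedLlamaDecoderLayer(
    _stub, LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
)
assert _layer(torch.randn(2, 6, 64))[0].shape == (2, 6, 64)

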
class ModifiedLlamaModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)

        # Wrap every stock decoder layer in the modified layer (reusing its pretrained MLP).
        self.layers = nn.ModuleList([
            ModifiedLlamaDecoderLayer(layer, config)
            for layer in self.layers
        ])

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.size()
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # Build a single additive mask: causal mask plus the (optional) padding mask.
        min_value = torch.finfo(hidden_states.dtype).min
        past_length = past_key_values[0][0].size(2) if past_key_values[0] is not None else 0
        causal_mask = torch.triu(
            torch.full(
                (seq_length, past_length + seq_length),
                min_value,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            ),
            diagonal=past_length + 1,
        )[None, None, :, :]
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            attention_mask = (1.0 - attention_mask.to(dtype=hidden_states.dtype)) * min_value
            attention_mask = attention_mask + causal_mask
        else:
            attention_mask = causal_mask

        next_decoder_cache = [] if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for decoder_layer, layer_past in zip(self.layers, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache.append(layer_outputs[1])

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = (hidden_states,)
            if use_cache:
                outputs += (next_decoder_cache,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache if use_cache else None,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=all_attentions if output_attentions else None,
        )


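# Minimal sketch (toy config, random weights): a forward pass through the rewired model,
# checking the BaseModelOutputWithPast contract and the per-layer key/value cache.
_toy_cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2,
                       intermediate_size=128, num_hidden_layers=2, vocab_size=128)
_toy_out = ModifiedLlamaModel(_toy_cfg)(input_ids=torch.randint(0, 128, (1, 6)), use_cache=True)
assert _toy_out.last_hidden_state.shape == (1, 6, 64)
assert len(_toy_out.past_key_values) == 2

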
# Load the configuration of the pretrained base model.
config = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

# Build a causal LM around the modified backbone.
modified_model = LlamaForCausalLM(config)
modified_model.model = ModifiedLlamaModel(config)

# Copy the pretrained weights into the matching submodules (strict=False ignores keys that
# exist on only one side); the newly added modules (fc_gamma, token_mixing, se_block, lambda
# parameters, sub_layer_norm) remain randomly initialized and need fine-tuning.
pretrained_model = LlamaForCausalLM.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')
modified_model.load_state_dict(pretrained_model.state_dict(), strict=False)

# Save the modified model and the tokenizer.
output_dir = "./BSC-LT-salamandra-2b-instruct-saved_model"
modified_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World', legacy=False)
tokenizer.save_pretrained(output_dir)

print(f"Model and tokenizer saved to {output_dir}")


import time


def chat_with_model(prompt_text, stop_token, model, tokenizer):
    # stop_token is kept for interface compatibility; generation stops at eos_token_id below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    start_time = time.time()
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt").to(device)

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_new_tokens=512,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=30,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )

    generated_sequence = output_sequences[0].tolist()
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    # Count only the newly generated tokens, not the prompt tokens.
    num_tokens = output_sequences.shape[-1] - encoded_prompt.shape[-1]

    # Strip the prompt from the decoded text so only the model's reply is returned.
    response_text = text[len(prompt_text):].strip()
    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total time: {total_time:.3f} seconds")
    tokens_per_second = num_tokens / total_time
    print(f"Tokens per second: {tokens_per_second:.3f}")
    return response_text


input_text = "Hello, how are you?"
stop_token = tokenizer.eos_token_id

response = chat_with_model(input_text, stop_token, modified_model, tokenizer)
print("Model response:", response)