Josephgflowers committed on
Commit
eacb34e
1 Parent(s): ab7d24d

Upload LM-Diff.py

Files changed (1)
  1. LM-Diff.py +465 -0
LM-Diff.py ADDED
@@ -0,0 +1,465 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast

# Custom Modules

class AdaptiveRMSNorm(nn.Module):
    """
    Adaptive RMSNorm layer where the scaling parameter adapts based on input.
    """
    def __init__(self, normalized_shape, adaptive_dim, eps=1e-6):
        super(AdaptiveRMSNorm, self).__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        # Standard RMSNorm weight parameter
        self.weight = nn.Parameter(torch.ones(normalized_shape))

        # Adaptive scaling parameter
        self.fc_gamma = nn.Linear(adaptive_dim, normalized_shape)

    def forward(self, x, adapt_input):
        # Compute adaptive scaling factor gamma
        gamma = self.fc_gamma(adapt_input).unsqueeze(1)  # Shape: [batch_size, 1, hidden_size]

        # Compute RMSNorm: normalize by the root mean square of the features
        norm_x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Apply adaptive scaling
        return self.weight * norm_x * gamma

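# Usage note: adapt_input is a per-example conditioning vector of shape [batch_size, hidden_size];
# in ModifiedLlamaDecoderLayer below it is computed as hidden_states.mean(dim=1).
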
class TokenMixing(nn.Module):
    """
    Token Mixing layer that performs depthwise convolution across the sequence dimension.
    """
    def __init__(self, hidden_size):
        super(TokenMixing, self).__init__()
        self.token_mixing = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,  # symmetric kernel: position t mixes tokens t-1, t, and t+1
            groups=hidden_size  # Depthwise convolution
        )

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_size]
        x = x.transpose(1, 2)  # Shape: [batch_size, hidden_size, seq_length]
        x = self.token_mixing(x)
        x = x.transpose(1, 2)  # Shape back to [batch_size, seq_length, hidden_size]
        return x

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block that adaptively recalibrates channel-wise features.
    """
    def __init__(self, hidden_size, reduction=16):
        super(SEBlock, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size // reduction, hidden_size, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x shape: [batch_size, seq_length, hidden_size]
        y = x.mean(dim=1)  # Global average pooling over sequence length
        y = self.fc(y)  # Squeeze and Excitation
        y = y.unsqueeze(1)  # Shape: [batch_size, 1, hidden_size]
        return x * y  # Scale the original input

class DifferentialSelfAttention(nn.Module):
    """
    Self-Attention layer with Differential Attention Mechanism.
    Includes support for past_key_value and attention_mask handling.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size  # e.g., 2048 for the TinyLlama checkpoint used below
        self.num_heads = config.num_attention_heads  # e.g., 32
        self.head_dim = self.hidden_size // self.num_heads  # e.g., 64
        assert self.head_dim * self.num_heads == self.hidden_size, \
            "hidden_size must be divisible by num_attention_heads"

        self.scaling = self.head_dim ** -0.5

        # Linear layers for Q, K, V projections
        # k_proj and v_proj output hidden_size // 8 features so their weights line up with the
        # pre-trained model's grouped-query key/value projections
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)  # [hidden_size -> hidden_size]
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // 8)  # [hidden_size -> hidden_size // 8]
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // 8)  # [hidden_size -> hidden_size // 8]
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)  # [hidden_size -> hidden_size]

        # Learnable parameters for lambda computation
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_init = nn.Parameter(torch.tensor(0.5))  # Constant init (the DIFF Transformer paper uses a depth-dependent schedule)

        # Layer normalization
        self.sub_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        batch_size, seq_length, _ = hidden_states.size()

        # Linear projections
        query_states = self.q_proj(hidden_states) * self.scaling  # Shape: [batch_size, seq_length, hidden_size]
        key_states = self.k_proj(hidden_states)  # Shape: [batch_size, seq_length, hidden_size // 8]
        value_states = self.v_proj(hidden_states)  # Shape: [batch_size, seq_length, hidden_size // 8]

        # Reshape and split into multiple heads
        # Query states have shape: [batch_size, num_heads, seq_length, head_dim]
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Key and value states have shape: [batch_size, num_kv_heads, seq_length, head_dim]
        num_kv_heads = key_states.size(-1) // self.head_dim  # Grouped-query attention: fewer key/value heads than query heads
        key_states = key_states.view(batch_size, seq_length, num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, num_kv_heads, self.head_dim).transpose(1, 2)

        # Handle past key values for caching
        if past_key_value is not None:
            # past_key_value[0] and [1] have shape (batch_size, num_kv_heads, seq_len_prev, head_dim)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)  # Concat on seq_length dimension
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            present_key_value = (key_states, value_states)
        else:
            present_key_value = None

        # Update sequence length after concatenation
        kv_seq_length = key_states.size(2)

        # Repeat key/value heads so every query head has a matching key/value head
        key_states = key_states.repeat_interleave(self.num_heads // num_kv_heads, dim=1)
        value_states = value_states.repeat_interleave(self.num_heads // num_kv_heads, dim=1)

        # Split Q and K into two groups for differential attention
        q1, q2 = torch.chunk(query_states, 2, dim=-1)  # Each has shape: [batch_size, num_heads, seq_length, head_dim/2]
        k1, k2 = torch.chunk(key_states, 2, dim=-1)  # Each has shape: [batch_size, num_heads, kv_seq_length, head_dim/2]

        # Compute attention scores
        attn_scores1 = torch.matmul(q1, k1.transpose(-2, -1))  # [batch_size, num_heads, seq_length, kv_seq_length]
        attn_scores2 = torch.matmul(q2, k2.transpose(-2, -1))

        # Apply attention mask if provided
        if attention_mask is not None:
            # attention_mask should be of shape [batch_size, 1, seq_length, kv_seq_length]
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]  # Expand to [batch_size, 1, 1, kv_seq_length]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            attention_mask = attention_mask.to(dtype=attn_scores1.dtype)  # Ensure dtype matches
            attn_scores1 += attention_mask
            attn_scores2 += attention_mask

        # Compute attention probabilities
        attn_probs1 = nn.functional.softmax(attn_scores1, dim=-1, dtype=torch.float32).to(attn_scores1.dtype)
        attn_probs2 = nn.functional.softmax(attn_scores2, dim=-1, dtype=torch.float32).to(attn_scores2.dtype)

        # Compute lambda as per the DIFF Transformer paper
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Compute differential attention
        attn_probs = attn_probs1 - lambda_full * attn_probs2

        # Compute attention output
        attn_output = torch.matmul(attn_probs, value_states)  # [batch_size, num_heads, seq_length, head_dim]

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Apply layer normalization
        attn_output = self.sub_layer_norm(attn_output)

        if output_attentions:
            # Return attention probabilities if required
            attn_probs_return = attn_probs
        else:
            attn_probs_return = None

        return attn_output, present_key_value, attn_probs_return

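# Illustrative shape check for DifferentialSelfAttention (a minimal sketch; the toy batch size and
# sequence length are arbitrary assumptions, not values used elsewhere in this script):
#
#     cfg = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')
#     attn = DifferentialSelfAttention(cfg)
#     x = torch.randn(2, 16, cfg.hidden_size)
#     out, cache, _ = attn(x, use_cache=True)
#     assert out.shape == x.shape        # output keeps [batch_size, seq_length, hidden_size]
#     assert cache[0].shape[2] == 16     # cached keys cover the full sequence length
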
# Modified Decoder Layer

class ModifiedLlamaDecoderLayer(nn.Module):
    """
    Modified Llama Decoder Layer incorporating DifferentialSelfAttention,
    AdaptiveRMSNorm, TokenMixing, and SEBlock.
    """
    def __init__(self, original_layer, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.adaptive_dim = config.hidden_size  # Using hidden_size for adapt_input

        # Replace the self-attention layer with DifferentialSelfAttention
        self.self_attn = DifferentialSelfAttention(config)

        # Copy the original MLP layer
        self.mlp = original_layer.mlp

        # Replace RMSNorm layers with AdaptiveRMSNorm
        self.input_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = AdaptiveRMSNorm(
            self.hidden_size, self.adaptive_dim, eps=config.rms_norm_eps
        )

        # Add Token Mixing Layer
        self.token_mixing = TokenMixing(self.hidden_size)

        # Add SE Block
        self.se_block = SEBlock(self.hidden_size, reduction=16)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        # Compute adaptation input for AdaptiveRMSNorm
        adapt_input = hidden_states.mean(dim=1)  # Shape: [batch_size, hidden_size]

        residual = hidden_states

        # Input layer normalization with adaptive RMSNorm
        hidden_states = self.input_layernorm(hidden_states, adapt_input)

        # Self-attention with differential attention mechanism
        attn_output, present_key_value, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )

        hidden_states = residual + attn_output

        # Token Mixing
        token_mixed = self.token_mixing(hidden_states)
        hidden_states = hidden_states + token_mixed

        # Post-attention layer normalization with adaptive RMSNorm
        hidden_states = self.post_attention_layernorm(hidden_states, adapt_input)

        # MLP
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)

        # SE Block
        hidden_states = self.se_block(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if use_cache:
            outputs += (present_key_value,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

# Modified Model

class ModifiedLlamaModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)

        # Replace the decoder layers with modified layers
        self.layers = nn.ModuleList([
            ModifiedLlamaDecoderLayer(layer, config)
            for layer in self.layers
        ])

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,  # Capture any additional keyword arguments
    ):
        # Ensure default values are set
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Process inputs
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Initialize past_key_values if not provided
        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        # Embed tokens
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # Attention mask processing: combine the padding mask (if provided) with a causal mask
        past_length = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            attention_mask = attention_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Additive causal mask so each position only attends to itself and earlier positions
        causal_mask = torch.full((seq_length, past_length + seq_length), torch.finfo(hidden_states.dtype).min, device=hidden_states.device)
        causal_mask = torch.triu(causal_mask, diagonal=past_length + 1)[None, None, :, :].to(hidden_states.dtype)
        attention_mask = causal_mask if attention_mask is None else attention_mask + causal_mask

        # Main loop over layers
        next_decoder_cache = [] if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for idx, (decoder_layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Forward pass through the layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,  # Pass any additional keyword arguments
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache.append(layer_outputs[1])

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = (hidden_states,)
            if use_cache:
                outputs += (next_decoder_cache,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_attentions,)
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache if use_cache else None,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=all_attentions if output_attentions else None,
        )

# Load the pre-trained model

# Load the configuration from the pre-trained model
config = AutoConfig.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')

# Initialize the modified model
modified_model = LlamaForCausalLM(config)
modified_model.model = ModifiedLlamaModel(config)

# Load the pre-trained weights
pretrained_model = LlamaForCausalLM.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World')
modified_model.load_state_dict(pretrained_model.state_dict(), strict=False)

# Save the model and tokenizer
output_dir = "./BSC-LT-salamandra-2b-instruct-saved_model"
modified_model.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained('Josephgflowers/TinyLlama-v1.1-Cinders-World', legacy=False)
tokenizer.save_pretrained(output_dir)

print(f"Model and tokenizer saved to {output_dir}")

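# Note on reloading: save_pretrained above writes a standard Llama checkpoint, so calling
# LlamaForCausalLM.from_pretrained(output_dir) on its own rebuilds stock decoder layers and drops
# the custom parameters. A minimal sketch of restoring the modified architecture (assumes the
# classes above are importable and that the weights were saved as a single safetensors file):
#
#     from safetensors.torch import load_file
#     reloaded = LlamaForCausalLM(AutoConfig.from_pretrained(output_dir))
#     reloaded.model = ModifiedLlamaModel(reloaded.config)
#     reloaded.load_state_dict(load_file(f"{output_dir}/model.safetensors"), strict=False)
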
# Example Usage

import time

def chat_with_model(prompt_text, stop_token, model, tokenizer):
    # Note: stop_token is currently unused; generation stops at eos_token_id below
    # Encode the prompt text
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    start_time = time.time()
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt").to(device)

    # Generate response
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_new_tokens=512,
        temperature=0.2,
        repetition_penalty=1.2,
        top_k=30,
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,  # Ensure use_cache is True for generation
    )

    # Decode the generated sequence
    generated_sequence = output_sequences[0].tolist()
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    num_tokens = output_sequences.shape[-1]  # prompt + generated tokens

    response_text = text[len(prompt_text):].strip()
    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total time: {total_time:.3f} seconds")
    tokens_per_second = num_tokens / total_time
    print(f"Tokens per second: {tokens_per_second:.3f}")
    return response_text

# Example usage
input_text = "Hello, how are you?"
stop_token = tokenizer.eos_token_id  # Assuming EOS token as the stop token

response = chat_with_model(input_text, stop_token, modified_model, tokenizer)
print("Model response:", response)