bjoernp committed
Commit: 96e2479
Parent: 12fbf06

Update modeling_bitllama.py

Files changed (1):
  1. modeling_bitllama.py +17 -8
modeling_bitllama.py CHANGED
@@ -28,17 +28,21 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...modeling_attn_mask_utils import (
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
     _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask,
 )
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
-from ...utils import (
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -46,7 +50,8 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.import_utils import is_torch_fx_available
+from transformers.utils.import_utils import is_torch_fx_available
+
 from .configuration_llama import LlamaConfig
 
 
@@ -234,16 +239,19 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+
 def activation_quant(x):
     scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
     y = (x * scale).round().clamp_(-128, 127) / scale
     return y
 
+
 def weight_quant(w):
     scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
     u = (w * scale).round().clamp_(-1, 1) / scale
     return u
 
+
 class BitLinear(nn.Linear):
     def forward(self, x):
         w = self.weight
@@ -252,6 +260,7 @@ class BitLinear(nn.Linear):
         w_quant = w + (weight_quant(w) - w).detach()
         return F.linear(x_quant, w_quant)
 
+
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
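
For readers skimming the hunks above: activation_quant, weight_quant, and BitLinear implement BitNet-style fake quantization trained with a straight-through estimator (round/clamp in the forward pass, identity gradients in the backward pass). The snippet below is a minimal, self-contained sketch of how these pieces fit together, not part of the commit; the x_quant line is an assumption, since it lies outside the hunks shown, but it mirrors the w_quant pattern.

# Minimal sketch (not part of this commit) of the quantization helpers touched above.
import torch
import torch.nn.functional as F
from torch import nn


def activation_quant(x):
    # Per-token 8-bit fake quantization: snap activations onto a 256-level grid.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y


def weight_quant(w):
    # Ternary fake quantization: snap weights onto {-1, 0, +1} scaled by mean |w|.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u


class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass (the quantized delta is detached).
        x_quant = x + (activation_quant(x) - x).detach()  # assumed; this line sits outside the hunks shown
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x_quant, w_quant)


layer = BitLinear(8, 4, bias=False)
out = layer(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # gradients still flow despite round()/clamp() in the forward pass

with torch.no_grad():
    scale = 1.0 / layer.weight.abs().mean().clamp_(min=1e-5)
    # the effective weight grid is ternary: values drawn from {-1., 0., 1.}
    print(torch.unique((layer.weight * scale).round().clamp(-1, 1)))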