jingyaogong committed
Commit
c63b0c9
1 Parent(s): 6d52f6f

Upload model.py

Files changed (1)
  1. model.py +33 -29
model.py CHANGED
@@ -1,6 +1,8 @@
 import math
 import struct
 import inspect
+import time
+
 from .LMConfig import LMConfig
 from typing import Any, Optional, Tuple
 import numpy as np
@@ -80,26 +82,15 @@ class Attention(nn.Module):
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
 
-        if not self.flash:
-            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
-            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
-            mask = torch.triu(mask, diagonal=1)
-            self.register_buffer("mask", mask)
+        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)
 
-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, use_kv_cache=False):
+    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
         bsz, seqlen, _ = x.shape
-        if use_kv_cache and self.eval():
-            if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1:
-                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-            else:
-                token = x[:, -1:, :]
-                xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
-                xk = torch.cat((self.k_cache, self.wk(token)), dim=1)
-                xv = torch.cat((self.v_cache, self.wv(token)), dim=1)
 
-            self.k_cache, self.v_cache = xk, xv
-        else:
-            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -107,6 +98,13 @@ class Attention(nn.Module):
 
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
 
+        # More efficient kv_cache implementation
+        if kv_cache and self.eval():
+            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
+                xk = torch.cat((self.k_cache, xk), dim=1)
+                xv = torch.cat((self.v_cache, xv), dim=1)
+            self.k_cache, self.v_cache = xk, xv
+
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
 
@@ -114,13 +112,12 @@
         xk = xk.transpose(1, 2)
         xv = xv.transpose(1, 2)
 
-        if self.flash:
+        if self.flash and seqlen != 1:
             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
                                                                       dropout_p=self.dropout if self.training else 0.0,
                                                                       is_causal=True)
         else:
             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            assert hasattr(self, 'mask')
             scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
@@ -304,8 +301,8 @@ class TransformerBlock(nn.Module):
             dropout=args.dropout,
         )
 
-    def forward(self, x, pos_cis, use_kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, use_kv_cache)
+    def forward(self, x, pos_cis, kv_cache=False):
+        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
@@ -351,18 +348,21 @@ class Transformer(PreTrainedModel):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
     def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-                use_kv_cache=False, **keyargs):
+                kv_cache=False, **keyargs):
+        current_idx = 0
         if 'input_ids' in keyargs:
            tokens = keyargs['input_ids']
        if 'attention_mask' in keyargs:
            targets = keyargs['attention_mask']
+        if 'current_idx' in keyargs:
+            current_idx = int(keyargs['current_idx'])
 
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        pos_cis = self.pos_cis[:seqlen]
+        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
         for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, use_kv_cache)
+            h = layer(h, pos_cis, kv_cache)
 
         h = self.norm(h)
 
@@ -375,20 +375,24 @@ class Transformer(PreTrainedModel):
 
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('last_loss', self.last_loss)
-
         return self.OUT
 
     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
-                 use_kv_cache=True):
+    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
+        # rp: repetition_penalty
         index = idx.shape[1]
+        init_inference = True
         while idx.shape[1] < max_new_tokens - 1:
-            inference_res = self(idx, use_kv_cache=use_kv_cache)
+            if init_inference or not kv_cache:
+                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+            else:
+                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
+
             logits = inference_res.logits
             logits = logits[:, -1, :]
 
             for token in set(idx.tolist()[0]):
-                logits[:, token] /= repetition_penalty
+                logits[:, token] /= rp
 
             if temperature == 0.0:
                 _, idx_next = torch.topk(logits, k=1, dim=-1)
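For context on the decoding pattern this commit moves to, below is a minimal, self-contained sketch of per-layer KV caching: a multi-token prefill pass stores the keys and values, and every later step feeds only the newest token, which is the role of the `seqlen == 1` branch and the `current_idx` offset in the diff above. The sketch assumes PyTorch >= 2.0 for `scaled_dot_product_attention`; `ToyAttention`, its dimensions, and the driver code are illustrative assumptions, not the repository's API.

import torch
import torch.nn.functional as F


class ToyAttention(torch.nn.Module):
    """Single attention layer with an optional K/V cache (illustrative only)."""

    def __init__(self, dim: int = 64, n_heads: int = 4):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq = torch.nn.Linear(dim, dim, bias=False)
        self.wk = torch.nn.Linear(dim, dim, bias=False)
        self.wv = torch.nn.Linear(dim, dim, bias=False)
        self.k_cache = None  # (bsz, cached_len, n_heads, head_dim)
        self.v_cache = None

    def forward(self, x: torch.Tensor, kv_cache: bool = False) -> torch.Tensor:
        bsz, seqlen, dim = x.shape
        xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)

        if kv_cache and not self.training:
            # Single-token step: append the new K/V to the cached history
            # instead of recomputing projections for the whole prefix.
            if seqlen == 1 and self.k_cache is not None:
                xk = torch.cat((self.k_cache, xk), dim=1)
                xv = torch.cat((self.v_cache, xv), dim=1)
            self.k_cache, self.v_cache = xk, xv

        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
        # A lone query token may attend to every cached position, so the causal
        # mask is only needed for the multi-token prefill pass.
        out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=(seqlen != 1))
        return out.transpose(1, 2).reshape(bsz, seqlen, dim)


if __name__ == "__main__":
    attn = ToyAttention().eval()
    prompt = torch.randn(1, 5, 64)
    _ = attn(prompt, kv_cache=True)                    # prefill: cache 5 positions
    step = attn(torch.randn(1, 1, 64), kv_cache=True)  # decode: only the new token
    print(step.shape, attn.k_cache.shape)              # (1, 1, 64) and (1, 6, 4, 16)

The same reasoning explains why the patched attention only takes the flash path when `seqlen != 1`: during single-token decode the one query attends to the full cached history, so a causal mask would be wrong, and the manual branch with the (effectively empty) 1x1 mask slice is used instead.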