yangapku committed
Commit e7ecd2a
1 Parent(s): 86055d1

update cpu support, readme and convert_tokens_to_string

Files changed (3)
  1. README.md +5 -0
  2. modeling_qwen.py +63 -24
  3. tokenization_qwen.py +18 -12
README.md CHANGED
@@ -73,11 +73,16 @@ You can easily call the model with the following code:
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
 
+# Note: to guard against injection attacks, the tokenizer rejects special tokens in the input by default, so text containing tokens such as <|endoftext|> will raise an error.
+# To lift this restriction, pass `allowed_special`, which accepts the string "all" or a `set` of special tokens.
+# For example: tokens = tokenizer(text, allowed_special="all")
 tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
 # use bf16
 # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, bf16=True).eval()
 # use fp16
 # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu only
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="cpu", trust_remote_code=True).eval()
 # use fp32
 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True).eval()
 model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)  # you can specify different generation lengths, top_p and other related hyperparameters
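
The new README comments boil down to the following usage pattern. This is a minimal sketch: the sample input string and the ValueError assumption are illustrative; only the `allowed_special="all"` call itself is taken from the README.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)

text = "Hello<|endoftext|>"  # sample input containing a special token (illustrative)

# By default the tokenizer refuses special tokens in user input, so this call is expected to fail.
try:
    tokenizer(text)
except ValueError as err:  # tiktoken rejects disallowed special tokens (assumed error type)
    print("rejected:", err)

# Passing allowed_special lifts the restriction; it accepts the string "all" or an explicit set of special tokens.
tokens = tokenizer(text, allowed_special="all")
print(tokens["input_ids"])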
modeling_qwen.py CHANGED
@@ -15,6 +15,7 @@ from torch.cuda.amp import autocast
 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
 from transformers.generation.logits_process import LogitsProcessorList
+
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
     from transformers.generation.utils import GenerateOutput
@@ -38,15 +39,19 @@ try:
     use_flash_rotary = True
 except ImportError:
     use_flash_rotary = False
-    print("Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
-          "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary")
+    print(
+        "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
+        "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
+    )
 
 try:
     from flash_attn.ops.rms_norm import rms_norm
 except ImportError:
     rms_norm = None
-    print("Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
-          "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm")
+    print(
+        "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
+        "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
+    )
 
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
@@ -69,8 +74,10 @@ try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
     flash_attn_unpadded_func = None
-    print("Warning: import flash_attn fail, please install FlashAttention "
-          "https://github.com/Dao-AILab/flash-attention")
+    print(
+        "Warning: import flash_attn fail, please install FlashAttention "
+        "https://github.com/Dao-AILab/flash-attention"
+    )
 
 
 class FlashSelfAttention(torch.nn.Module):
@@ -177,8 +184,12 @@ class QWenAttention(nn.Module):
             config.hidden_size, self.projection_size, bias=not config.no_bias
         )
 
-        self.is_fp32 = not(config.bf16 or config.fp16)
-        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
+        self.is_fp32 = not (config.bf16 or config.fp16)
+        if (
+            self.use_flash_attn
+            and flash_attn_unpadded_func is not None
+            and not self.is_fp32
+        ):
             self.core_attention_flash = FlashSelfAttention(
                 causal=True, attention_dropout=config.attn_pdrop
             )
@@ -197,14 +208,15 @@ class QWenAttention(nn.Module):
             if self.rotary_ndims is not None
             else self.hidden_size_per_attention_head
         )
-        self.rotary_emb = RotaryEmbedding(
-            dim, base=config.rotary_emb_base
-        )
+        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
 
-        logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
+        logn_list = [
+            math.log(i, self.seq_length) if i > self.seq_length else 1
+            for i in range(1, 32768)
+        ]
         self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
         self._ntk_cached = 1.0
 
@@ -335,14 +347,20 @@ class QWenAttention(nn.Module):
         if layer_past:
             # layer past[0] shape: bs * seq_len * head_num * dim
             kv_seq_len += layer_past[0].shape[1]
-        if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1] and not self.training:
+        if (
+            self.use_dynamic_ntk
+            and kv_seq_len == hidden_states.size()[1]
+            and not self.training
+        ):
             context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
             ntk_alpha = 2 ** math.ceil(context_value) - 1
             ntk_alpha = max(ntk_alpha, 1)
             self._ntk_cached = ntk_alpha
         else:
             ntk_alpha = self._ntk_cached
-        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(hidden_states.device)
+        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
+            hidden_states.device
+        )
 
         if rotary_pos_emb is not None:
             if isinstance(rotary_pos_emb, tuple):
@@ -377,7 +395,12 @@ class QWenAttention(nn.Module):
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
             query = query * logn_tensor.expand_as(query)
 
-        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
+        if (
+            self.use_flash_attn
+            and flash_attn_unpadded_func is not None
+            and not self.is_fp32
+            and query.is_cuda
+        ):
             q, k, v = query, key, value
             context_layer = self.core_attention_flash(q, k, v)
 
@@ -398,7 +421,11 @@ class QWenAttention(nn.Module):
         attn_output = self.c_proj(context_layer)
         outputs = (attn_output, present)
         if output_attentions:
-            if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
+            if (
+                self.use_flash_attn
+                and flash_attn_unpadded_func is not None
+                and not self.is_fp32
+            ):
                 raise ValueError("Cannot output attentions while using flash-attn")
             else:
                 outputs += (attn_weight,)
@@ -750,7 +777,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         super().__init__(config)
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        assert not(config.bf16 and config.fp16), ("In config, bf16 and fp16 cannot both be true")
+        assert not (
+            config.bf16 and config.fp16
+        ), "In config, bf16 and fp16 cannot both be true"
         if config.bf16:
             self.transformer.bfloat16()
             self.lm_head.bfloat16()
@@ -929,21 +958,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[
+            Callable[[int, torch.Tensor], List[int]]
+        ] = None,
         synced_gpus: Optional[bool] = None,
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         # Process stop_words_ids.
-        stop_words_ids = kwargs.pop('stop_words_ids', None)
+        stop_words_ids = kwargs.pop("stop_words_ids", None)
         if stop_words_ids is None and generation_config is not None:
-            stop_words_ids = getattr(generation_config, 'stop_words_ids', None)
+            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
         if stop_words_ids is None:
-            stop_words_ids = getattr(self.generation_config, 'stop_words_ids', None)
+            stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
 
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
-                stop_words_ids=stop_words_ids, eos_token_id=self.generation_config.eos_token_id)
+                stop_words_ids=stop_words_ids,
+                eos_token_id=self.generation_config.eos_token_id,
+            )
             if logits_processor is None:
                 logits_processor = LogitsProcessorList([stop_words_logits_processor])
             else:
@@ -978,7 +1011,13 @@ class RotaryEmbedding(torch.nn.Module):
         seqlen = max_seq_len + offset
         if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
             base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim))
+            self.inv_freq = 1.0 / (
+                base
+                ** (
+                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
+                    / self.dim
+                )
+            )
             self._seq_len_cached = seqlen
             self._ntk_alpha_cached = ntk_alpha
             seq = torch.arange(seqlen, device=self.inv_freq.device)
@@ -1028,7 +1067,7 @@ class RMSNorm(torch.nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        if rms_norm is not None:
+        if rms_norm is not None and x.is_cuda:
             return rms_norm(x, self.weight, self.eps)
         else:
             output = self._norm(x.float()).type_as(x)
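
The modeling changes share one pattern: the fused FlashAttention and rms_norm kernels are taken only when the tensors are actually on CUDA, with a plain PyTorch fallback otherwise, which is what makes `device_map="cpu"` loading viable. Below is a minimal standalone sketch of that guard, mirroring the RMSNorm change above; the class name `GuardedRMSNorm` is illustrative and not part of the repository.

import torch

try:
    from flash_attn.ops.rms_norm import rms_norm as fused_rms_norm
except ImportError:
    fused_rms_norm = None  # typical on CPU-only installs


class GuardedRMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Pure-PyTorch RMSNorm, the same fallback formula as in modeling_qwen.py.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Use the fused kernel only if it is installed AND the input lives on a GPU;
        # otherwise fall back to the plain implementation so CPU inference still works.
        if fused_rms_norm is not None and x.is_cuda:
            return fused_rms_norm(x, self.weight, self.eps)
        output = self._norm(x.float()).type_as(x)
        return output * self.weight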
tokenization_qwen.py CHANGED
@@ -22,7 +22,6 @@ logger = logging.getLogger(__name__)
 
 VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
 
-
 class QWenTokenizer(PreTrainedTokenizer):
     """QWen tokenizer."""
 
@@ -126,6 +125,7 @@ class QWenTokenizer(PreTrainedTokenizer):
         self.mergeable_ranks = mergeable_ranks
         self.encoder = self.mergeable_ranks
         self.decoder = {v: k for k, v in self.encoder.items()}
+        self.decoder.update({v: k for k, v in self.special_tokens.items()})
         self.tokenizer = enc  # type: tiktoken.Encoding
         self.eod_id = self.tokenizer.eot_token
         self.im_start_id = special_tokens[IMSTART]
@@ -182,29 +182,32 @@ class QWenTokenizer(PreTrainedTokenizer):
             text (`str`):
                 The sequence to be encoded.
             kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific encode method. See details in
-                [`~PreTrainedTokenizerBase.__call__`]
+                Will be passed to the underlying model specific encode method.
+                Tiktoken allows users to allow the tokenization of special tokens with the following args:
+                `allowed_special`: set to 'all' or a `set` of special tokens.
+                `disallowed_special`: set to 'all' or a `Collection` of special tokens. NOT RECOMMENDED, AS IT MAY CONFLICT WITH `allowed_special`.
 
         Returns:
             `List[str]`: The list of tokens.
         """
         tokens = []
         text = unicodedata.normalize("NFC", text)
-        for t in self.tokenizer.encode_ordinary(text):
+
+        for t in self.tokenizer.encode(text, **kwargs):
             tokens.append(self.decoder[t])
+
         return tokens
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: List[bytes]) -> str:
         """
         Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we
         often want to remove sub-word tokenization artifacts at the same time.
         """
-        text = "".join(tokens)
-        text = bytearray([self.byte_decoder[c] for c in text]).decode(
-            "utf-8", errors=self.errors
-        )
-        return text
-
+        text = b""
+        for token in tokens:
+            text += token
+        return text.decode('utf-8')
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_vocab
@@ -216,7 +219,10 @@ class QWenTokenizer(PreTrainedTokenizer):
 
     def _convert_token_to_id(self, token: str) -> int:
         """Converts a token to an id using the vocab."""
-        return self.encoder.get(token.encode('UTF-8'), self.tokenizer.encode(self.unk_token, allowed_special='all')[0])
+        return self.encoder.get(
+            token.encode("UTF-8"),
+            self.tokenizer.encode(self.unk_token, allowed_special="all")[0],
+        )
 
     @property
     def all_special_tokens(self) -> List[str]:
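
After this change, `tokenize` returns the raw byte pieces (and special-token strings) looked up through `self.decoder`, and `convert_tokens_to_string` simply concatenates those bytes and decodes them as UTF-8. A hedged round-trip sketch; the sample text and the shown byte split are illustrative only.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)

tokens = tokenizer.tokenize("你好, world")
# e.g. [b'\xe4\xbd\xa0\xe5\xa5\xbd', b',', b' world'] -- the exact split depends on the BPE merges

text = tokenizer.convert_tokens_to_string(tokens)
print(text)  # "你好, world": the byte tokens are concatenated, then decoded as UTF-8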