Husnain committed on
Commit
97134c0
1 Parent(s): dbc5434

⚡ [Enhance] Use nous-mixtral-8x7b as default model

Files changed (1)
  1. networks/huggingface_streamer.py +7 -30
networks/huggingface_streamer.py CHANGED
@@ -2,18 +2,11 @@ import json
 import re
 import requests
 
-
 from tclogger import logger
-from transformers import AutoTokenizer
-
-from constants.models import (
-    MODEL_MAP,
-    STOP_SEQUENCES_MAP,
-    TOKEN_LIMIT_MAP,
-    TOKEN_RESERVED,
-)
+from constants.models import MODEL_MAP, STOP_SEQUENCES_MAP
 from constants.envs import PROXIES
 from messagers.message_outputer import OpenaiStreamOutputer
+from messagers.token_checker import TokenChecker
 
 
 class HuggingfaceStreamer:
@@ -21,33 +14,21 @@ class HuggingfaceStreamer:
         if model in MODEL_MAP.keys():
             self.model = model
         else:
-            self.model = "mixtral-8x7b"
+            self.model = "nous-mixtral-8x7b"
         self.model_fullname = MODEL_MAP[self.model]
         self.message_outputer = OpenaiStreamOutputer(model=self.model)
 
-        if self.model == "gemma-1.1-7b":
-            # this is not wrong, as repo `google/gemma-1.1-7b-it` is gated and must authenticate to access it
-            # so I use mistral-7b as a fallback
-            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP["mistral-7b"])
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
-
     def parse_line(self, line):
         line = line.decode("utf-8")
         line = re.sub(r"data:\s*", "", line)
         data = json.loads(line)
+        content = ""
         try:
             content = data["token"]["text"]
         except:
             logger.err(data)
         return content
 
-    def count_tokens(self, text):
-        tokens = self.tokenizer.encode(text)
-        token_count = len(tokens)
-        logger.note(f"Prompt Token Count: {token_count}")
-        return token_count
-
     def chat_response(
         self,
         prompt: str = None,
@@ -80,16 +61,12 @@ class HuggingfaceStreamer:
         top_p = max(top_p, 0.01)
         top_p = min(top_p, 0.99)
 
-        token_limit = int(
-            TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
-        )
-        if token_limit <= 0:
-            raise ValueError("Prompt exceeded token limit!")
+        checker = TokenChecker(input_str=prompt, model=self.model)
 
         if max_new_tokens is None or max_new_tokens <= 0:
-            max_new_tokens = token_limit
+            max_new_tokens = checker.get_token_redundancy()
        else:
-            max_new_tokens = min(max_new_tokens, token_limit)
+            max_new_tokens = min(max_new_tokens, checker.get_token_redundancy())
 
         # References:
         #   huggingface_hub/inference/_client.py:
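
With this commit, any model name that is not a key of MODEL_MAP falls back to "nous-mixtral-8x7b" instead of "mixtral-8x7b". A quick illustration of the fallback, assuming the constructor takes the model name as its only argument (as the `if model in MODEL_MAP.keys()` check suggests); "some-unknown-model" is a placeholder, not a real MODEL_MAP key:

    from networks.huggingface_streamer import HuggingfaceStreamer

    # An unrecognized model name now falls back to the new default.
    streamer = HuggingfaceStreamer(model="some-unknown-model")
    print(streamer.model)  # -> "nous-mixtral-8x7b"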
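The messagers/token_checker.py module itself is not part of this commit; only its call sites are visible above. A minimal sketch of the interface the diff relies on, assuming TokenChecker simply wraps the tokenizer-based budget computation that was removed from this file (the constructor arguments input_str and model and the method get_token_redundancy come from the calls above; the body is an assumption, not the real module):

    # Hypothetical sketch of messagers/token_checker.py (not the actual implementation).
    from transformers import AutoTokenizer

    from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED


    class TokenChecker:
        def __init__(self, input_str: str, model: str):
            self.input_str = input_str
            self.model = model
            # Assumed: load the tokenizer for the model, as the removed count_tokens() did.
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP[model])

        def count_tokens(self) -> int:
            # Number of tokens in the prompt.
            return len(self.tokenizer.encode(self.input_str))

        def get_token_redundancy(self) -> int:
            # Remaining budget = context limit - reserved tokens - prompt tokens,
            # mirroring the token_limit computation removed from chat_response().
            # The removed code raised ValueError when this was <= 0; whether the
            # real TokenChecker does the same is not visible in this diff.
            return int(TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens())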