Husnain committed
Commit a39e60f
1 Parent(s): 99c68b6

♻️ [Refactor] Move STOP_SEQUENCES_MAP and TOKEN_LIMIT_MAP to constants

Files changed (1):
  networks/message_streamer.py  +22 -36
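
Note: the new constants/models.py module that this commit imports from is not itself part of the diff. Judging from the class attributes removed below, it presumably defines the same maps and values at module level, roughly:

# constants/models.py -- presumed contents, reconstructed from the class
# attributes removed in the diff below (the actual file is not shown here;
# the commented-out candidate models are omitted for brevity)
MODEL_MAP = {
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",  # 72.62, fast [Recommended]
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",  # 65.71, fast
    "nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "gemma-7b": "google/gemma-7b-it",
    "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
}
STOP_SEQUENCES_MAP = {
    "mixtral-8x7b": "</s>",
    "mistral-7b": "</s>",
    "nous-mixtral-8x7b": "<|im_end|>",
    "openchat-3.5": "<|end_of_turn|>",
    "gemma-7b": "<eos>",
}
TOKEN_LIMIT_MAP = {
    "mixtral-8x7b": 32768,
    "mistral-7b": 32768,
    "nous-mixtral-8x7b": 32768,
    "openchat-3.5": 8192,
    "gemma-7b": 8192,
}
TOKEN_RESERVED = 100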
networks/message_streamer.py CHANGED

@@ -1,49 +1,37 @@
 import json
 import re
 import requests
+
 from tiktoken import get_encoding as tiktoken_get_encoding
+from transformers import AutoTokenizer
+
+from constants.models import (
+    MODEL_MAP,
+    STOP_SEQUENCES_MAP,
+    TOKEN_LIMIT_MAP,
+    TOKEN_RESERVED,
+)
 from messagers.message_outputer import OpenaiStreamOutputer
 from utils.logger import logger
 from utils.enver import enver
 
 
 class MessageStreamer:
-    MODEL_MAP = {
-        "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",  # 72.62, fast [Recommended]
-        "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",  # 65.71, fast
-        "nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
-        "gemma-7b": "google/gemma-7b-it",
-        # "openchat-3.5": "openchat/openchat-3.5-1210",  # 68.89, fast
-        # "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta",  # ❌ Too Slow
-        # "llama-70b": "meta-llama/Llama-2-70b-chat-hf",  # ❌ Require Pro User
-        # "codellama-34b": "codellama/CodeLlama-34b-Instruct-hf",  # ❌ Low Score
-        # "falcon-180b": "tiiuae/falcon-180B-chat",  # ❌ Require Pro User
-        "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    }
-    STOP_SEQUENCES_MAP = {
-        "mixtral-8x7b": "</s>",
-        "mistral-7b": "</s>",
-        "nous-mixtral-8x7b": "<|im_end|>",
-        "openchat-3.5": "<|end_of_turn|>",
-        "gemma-7b": "<eos>",
-    }
-    TOKEN_LIMIT_MAP = {
-        "mixtral-8x7b": 32768,
-        "mistral-7b": 32768,
-        "nous-mixtral-8x7b": 32768,
-        "openchat-3.5": 8192,
-        "gemma-7b": 8192,
-    }
-    TOKEN_RESERVED = 100
 
     def __init__(self, model: str):
-        if model in self.MODEL_MAP.keys():
+        if model in MODEL_MAP.keys():
             self.model = model
         else:
             self.model = "default"
-        self.model_fullname = self.MODEL_MAP[self.model]
+        self.model_fullname = MODEL_MAP[self.model]
         self.message_outputer = OpenaiStreamOutputer()
-        self.tokenizer = tiktoken_get_encoding("cl100k_base")
+
+        if self.model == "gemma-7b":
+            # this is not wrong, as repo `google/gemma-7b-it` is gated and must authenticate to access it
+            # so I use mistral-7b as a fallback
+            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP["mistral-7b"])
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
 
     def parse_line(self, line):
         line = line.decode("utf-8")
@@ -94,9 +82,7 @@ class MessageStreamer:
         top_p = min(top_p, 0.99)
 
         token_limit = int(
-            self.TOKEN_LIMIT_MAP[self.model]
-            - self.TOKEN_RESERVED
-            - self.count_tokens(prompt) * 1.35
+            TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
         )
         if token_limit <= 0:
             raise ValueError("Prompt exceeded token limit!")
@@ -127,8 +113,8 @@ class MessageStreamer:
             "stream": True,
         }
 
-        if self.model in self.STOP_SEQUENCES_MAP.keys():
-            self.stop_sequences = self.STOP_SEQUENCES_MAP[self.model]
+        if self.model in STOP_SEQUENCES_MAP.keys():
+            self.stop_sequences = STOP_SEQUENCES_MAP[self.model]
         # self.request_body["parameters"]["stop_sequences"] = [
         #     self.STOP_SEQUENCES[self.model]
         # ]
@@ -178,7 +164,7 @@ class MessageStreamer:
                 logger.back(content, end="")
                 final_content += content
 
-        if self.model in self.STOP_SEQUENCES_MAP.keys():
+        if self.model in STOP_SEQUENCES_MAP.keys():
             final_content = final_content.replace(self.stop_sequences, "")
 
         final_content = final_content.strip()
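
Beyond the constant move, the commit also switches the tokenizer from tiktoken's generic cl100k_base encoding to each model's own AutoTokenizer, and the prompt-size estimate accordingly drops the * 1.35 safety factor: the token count is now exact for the target model rather than an approximation. The count_tokens() helper called in the token_limit expression is not shown in this diff; a minimal sketch of what such a helper could look like, assuming it simply wraps the tokenizer's encode():

from transformers import AutoTokenizer

def count_tokens(tokenizer, text: str) -> int:
    # Hypothetical helper matching the count_tokens() call in the diff:
    # returns the number of tokens this model's tokenizer produces for `text`.
    return len(tokenizer.encode(text))

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
prompt_tokens = count_tokens(tokenizer, "Hello, world!")
# Mirrors the new token_limit formula, using the values from the diff:
# TOKEN_LIMIT_MAP["mistral-7b"] - TOKEN_RESERVED - prompt tokens
token_limit = 32768 - 100 - prompt_tokens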