wanicca committed
Commit ba9988f
1 Parent(s): 629f62e

Add world support

Files changed (4):
  1. app.py +3 -1
  2. rwkv_tokenizer.py +103 -0
  3. rwkv_vocab_v20230424.txt +0 -0
  4. utils.py +11 -4
app.py CHANGED
@@ -13,6 +13,7 @@ desc = f'''链接:<a href='https://colab.research.google.com/drive/1J1gLMMMA8G
  
  parser = argparse.ArgumentParser(prog = 'ChatGal RWKV')
  parser.add_argument('--share',action='store_true')
+ parser.add_argument("--world",type=bool, default=False)
  parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch11.pth")
  parser.add_argument('--model_path',type=str,default=None,help="local model path")
  parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
@@ -40,7 +41,8 @@ if torch.cuda.is_available() and torch.cuda.device_count()>0:
  else:
      model = RWKV(model=model_path, strategy='cpu bf16',**lora_kwargs)
  from utils import PIPELINE, PIPELINE_ARGS
- pipeline = PIPELINE(model, "20B_tokenizer.json")
+ tokenizer_file = "rwkv_vocab_v20230424" if args.world else "20B_tokenizer.json"
+ pipeline = PIPELINE(model, tokenizer_file)
  
  def infer(
      ctx,
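
Note that argparse's type=bool does not behave like a switch: bool() of any non-empty string is True, so "--world False" on the command line still enables world mode, and only omitting the flag keeps the default. A boolean flag is conventionally declared with action='store_true'. A minimal sketch of that alternative, not what this commit does:

    import argparse

    # Sketch only: store_true avoids the type=bool pitfall
    # (bool("False") evaluates to True in Python).
    parser = argparse.ArgumentParser(prog='ChatGal RWKV')
    parser.add_argument('--world', action='store_true',
                        help='use the rwkv_vocab_v20230424 world tokenizer')
    args = parser.parse_args(['--world'])
    assert args.world is True
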
rwkv_tokenizer.py ADDED
@@ -0,0 +1,103 @@
+ ########################################################################################################
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+ ########################################################################################################
+
+ class TRIE:
+     __slots__ = tuple("ch,to,values,front".split(","))
+     to:list
+     values:set
+     def __init__(self, front=None, ch=None):
+         self.ch = ch
+         self.to = [None for ch in range(256)]
+         self.values = set()
+         self.front = front
+
+     def __repr__(self):
+         fr = self
+         ret = []
+         while(fr!=None):
+             if(fr.ch!=None):
+                 ret.append(fr.ch)
+             fr = fr.front
+         return "<TRIE %s %s>"%(ret[::-1], self.values)
+
+     def add(self, key:bytes, idx:int=0, val=None):
+         if(idx == len(key)):
+             if(val is None):
+                 val = key
+             self.values.add(val)
+             return self
+         ch = key[idx]
+         if(self.to[ch] is None):
+             self.to[ch] = TRIE(front=self, ch=ch)
+         return self.to[ch].add(key, idx=idx+1, val=val)
+
+     def find_longest(self, key:bytes, idx:int=0):
+         u:TRIE = self
+         ch:int = key[idx]
+
+         while(u.to[ch] is not None):
+             u = u.to[ch]
+             idx += 1
+             if(u.values):
+                 ret = idx, u, u.values
+             if(idx==len(key)):
+                 break
+             ch = key[idx]
+         return ret
+
+ class TRIE_TOKENIZER():
+     def __init__(self, file_name):
+         self.idx2token = {}
+         sorted = [] # must be already sorted
+         with open(file_name, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+         for l in lines:
+             idx = int(l[:l.index(' ')])
+             x = eval(l[l.index(' '):l.rindex(' ')])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+             assert len(x) == int(l[l.rindex(' '):])
+             sorted += [x]
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k,v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         self.root = TRIE()
+         for t, i in self.token2idx.items():
+             _ = self.root.add(t, val=(t, i))
+
+     def encodeBytes(self, src:bytes):
+         idx:int = 0
+         tokens = []
+         while (idx < len(src)):
+             _idx:int = idx
+             idx, _, values = self.root.find_longest(src, idx)
+             assert(idx != _idx)
+             _, token = next(iter(values))
+             tokens.append(token)
+         return tokens
+
+     def decodeBytes(self, tokens):
+         return b''.join(map(lambda i: self.idx2token[i], tokens))
+
+     def encode(self, src):
+         return self.encodeBytes(src.encode("utf-8"))
+
+     def decode(self, tokens):
+         try:
+             return self.decodeBytes(tokens).decode('utf-8')
+         except:
+             return '\ufffd' # bad utf-8
+
+     def printTokens(self, tokens):
+         for i in tokens:
+             s = self.idx2token[i]
+             try:
+                 s = s.decode('utf-8')
+             except:
+                 pass
+             print(f'{repr(s)}{i}', end=' ')
+         print()
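
For orientation: encodeBytes walks the trie from each byte position and keeps the longest vocabulary match (find_longest remembers the last node whose values set is non-empty), so encoding is greedy longest-match over the UTF-8 bytes of the input. A minimal usage sketch, assuming the vocab file added in this commit sits in the working directory:

    from rwkv_tokenizer import TRIE_TOKENIZER

    # Greedy longest-match tokenization over UTF-8 bytes.
    tok = TRIE_TOKENIZER('./rwkv_vocab_v20230424.txt')
    ids = tok.encode('Hello world')
    assert tok.decode(ids) == 'Hello world'   # round-trips via bytes
    tok.printTokens(ids)                      # prints repr(token) and id pairs
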
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
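
Each line of this vocabulary file is parsed by TRIE_TOKENIZER.__init__ above as "<token id> <token as a Python str/bytes literal> <length in bytes>", and the loader's comment notes the entries must already be sorted. Illustrative lines with hypothetical values, not taken from the actual file:

    257 ' the' 4
    10000 b'\xe4\xbd\xa0' 3
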
 
utils.py CHANGED
@@ -22,6 +22,9 @@ class PIPELINE():
          if WORD_NAME == 'cl100k_base':
              import tiktoken
              self.tokenizer = tiktoken.get_encoding(WORD_NAME)
+         elif WORD_NAME == 'rwkv_vocab_v20230424':
+             from rwkv_tokenizer import TRIE_TOKENIZER
+             self.tokenizer = TRIE_TOKENIZER(f'./{WORD_NAME}.txt')
          else:
              from tokenizers import Tokenizer
              self.tokenizer = Tokenizer.from_file(WORD_NAME)
@@ -37,10 +40,14 @@ class PIPELINE():
          return context
  
      def encode(self, x):
-         if 'tiktoken' in str(type(self.tokenizer)):
-             return self.tokenizer.encode(x)
-         else:
-             return self.tokenizer.encode(x).ids
+         # if 'tiktoken' in str(type(self.tokenizer)):
+         #     return self.tokenizer.encode(x)
+         # else:
+         #     return self.tokenizer.encode(x).ids
+         encoded = self.tokenizer.encode(x)
+         if hasattr(encoded,"ids"):
+             encoded = encoded.ids
+         return encoded
  
      def decode(self, x):
          return self.tokenizer.decode(x)
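
The new encode body duck-types all three backends: tiktoken and TRIE_TOKENIZER return a plain list of ids, while Hugging Face tokenizers returns an Encoding object that carries the ids in an .ids attribute. A self-contained sketch of that dispatch; FakeEncoding is a hypothetical stand-in, just to exercise both branches:

    class FakeEncoding:
        # Stand-in for tokenizers.Encoding, which exposes .ids
        def __init__(self, ids):
            self.ids = ids

    def normalize(encoded):
        # Same duck-typing as PIPELINE.encode above.
        return encoded.ids if hasattr(encoded, "ids") else encoded

    assert normalize([1, 2, 3]) == [1, 2, 3]                 # tiktoken / TRIE_TOKENIZER
    assert normalize(FakeEncoding([1, 2, 3])) == [1, 2, 3]   # tokenizers.Encoding
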