Limour commited on
Commit
bce2a0f
1 Parent(s): 7fe785d

Upload 2 files

Browse files
Files changed (1) hide show
  1. llama_cpp_python_streamingllm.py +49 -23
llama_cpp_python_streamingllm.py CHANGED
@@ -31,7 +31,7 @@ def get_complete_UTF8(all_text):
31
  class StreamingLLM(Llama):
32
  def __init__(self, model_path: str, **kwargs):
33
  super().__init__(model_path, **kwargs)
34
- self.venv = [0]
35
 
36
  def str_detokenize(self, tokens) -> str:
37
  return get_complete_UTF8(self.detokenize(tokens))
@@ -39,38 +39,63 @@ class StreamingLLM(Llama):
39
  def kv_cache_seq_trim(self):
40
  self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
41
 
42
- def venv_create(self):
 
 
 
 
43
  self.venv.append(0)
44
- return len(self.venv) - 1
 
45
 
46
- def venv_disband(self):
47
  if len(self.venv) <= 1:
48
- return 0
49
- tmp = self.venv.pop()
50
- self.venv[-1] += tmp
51
- return len(self.venv) - 1
 
 
 
 
 
 
 
 
52
 
53
- def venv_remove(self, venv_idx=None):
54
- if venv_idx is None:
55
- venv_idx = len(self.venv) - 1
56
- if venv_idx <= 0 or venv_idx >= len(self.venv):
57
- return len(self.venv) - 1
58
- if venv_idx == len(self.venv) - 1:
59
- # 最后一层
60
- self.n_tokens -= min(self.venv.pop(), self.n_tokens)
61
- self.kv_cache_seq_trim()
62
- else:
63
- # 非最后一层
64
- n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
65
- n_discard = self.venv.pop(venv_idx)
66
- self.kv_cache_seq_ltrim(n_keep, n_discard)
67
- return len(self.venv) - 1
 
 
 
 
 
 
 
 
68
 
69
  def venv_pop_token(self):
70
  self.n_tokens -= 1
71
  self.venv[-1] -= 1
72
  self.kv_cache_seq_trim()
73
 
 
 
 
 
74
  def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
75
  if n_past < 0:
76
  n_past = self.n_tokens
@@ -274,6 +299,7 @@ class StreamingLLM(Llama):
274
  n_tokens)
275
  self.n_tokens = n_tokens.contents.value
276
  self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
 
277
  return retn
278
 
279
  def save_session(self, filepath: str):
 
31
  class StreamingLLM(Llama):
32
  def __init__(self, model_path: str, **kwargs):
33
  super().__init__(model_path, **kwargs)
34
+ self._venv_init()
35
 
36
  def str_detokenize(self, tokens) -> str:
37
  return get_complete_UTF8(self.detokenize(tokens))
 
39
  def kv_cache_seq_trim(self):
40
  self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
41
 
42
+ def _venv_init(self):
43
+ self.venv = [0]
44
+ self.venv_idx_map = []
45
+
46
+ def venv_create(self, name: str):
47
  self.venv.append(0)
48
+ self.venv_idx_map.append(name)
49
+ return name
50
 
51
+ def venv_disband(self, name_set):
52
  if len(self.venv) <= 1:
53
+ return name_set
54
+ name_set = {x for x in name_set if x in self.venv_idx_map}
55
+ if not name_set:
56
+ return name_set
57
+ while self.venv_idx_map:
58
+ if self.venv_idx_map[0] in name_set:
59
+ self.venv_idx_map.pop(0) # 删除
60
+ tmp = self.venv.pop(1) # 对应的 venv 移入上一层
61
+ self.venv[0] += tmp
62
+ else:
63
+ break
64
+ return name_set
65
 
66
+ def venv_remove(self, name: str):
67
+ if len(self.venv) <= 1:
68
+ return name
69
+ if name not in self.venv_idx_map:
70
+ return name
71
+ venv_idx = self.venv_idx_map.index(name) + 1
72
+ while self.venv_idx_map:
73
+ self.venv_idx_map.pop(venv_idx - 1) # 删除
74
+ if venv_idx == len(self.venv) - 1:
75
+ # 最后一层
76
+ self.n_tokens -= min(self.venv.pop(), self.n_tokens)
77
+ self.kv_cache_seq_trim()
78
+ break
79
+ else:
80
+ # 非最后一层
81
+ n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
82
+ n_discard = self.venv.pop(venv_idx)
83
+ self.kv_cache_seq_ltrim(n_keep, n_discard)
84
+ try:
85
+ venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
86
+ except ValueError: # 没有了
87
+ break
88
+ return name
89
 
90
  def venv_pop_token(self):
91
  self.n_tokens -= 1
92
  self.venv[-1] -= 1
93
  self.kv_cache_seq_trim()
94
 
95
+ @property
96
+ def venv_info(self):
97
+ return str((self.n_tokens, self.venv, self.venv_idx_map))
98
+
99
  def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
100
  if n_past < 0:
101
  n_past = self.n_tokens
 
299
  n_tokens)
300
  self.n_tokens = n_tokens.contents.value
301
  self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
302
+ self._venv_init()
303
  return retn
304
 
305
  def save_session(self, filepath: str):