Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- llama_cpp_python_streamingllm.py +49 -23
llama_cpp_python_streamingllm.py
CHANGED
@@ -31,7 +31,7 @@ def get_complete_UTF8(all_text):
|
|
31 |
class StreamingLLM(Llama):
|
32 |
def __init__(self, model_path: str, **kwargs):
|
33 |
super().__init__(model_path, **kwargs)
|
34 |
-
self.
|
35 |
|
36 |
def str_detokenize(self, tokens) -> str:
|
37 |
return get_complete_UTF8(self.detokenize(tokens))
|
@@ -39,38 +39,63 @@ class StreamingLLM(Llama):
|
|
39 |
def kv_cache_seq_trim(self):
|
40 |
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
41 |
|
42 |
-
def
|
|
|
|
|
|
|
|
|
43 |
self.venv.append(0)
|
44 |
-
|
|
|
45 |
|
46 |
-
def venv_disband(self):
|
47 |
if len(self.venv) <= 1:
|
48 |
-
return
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
def venv_remove(self,
|
54 |
-
if
|
55 |
-
|
56 |
-
if
|
57 |
-
return
|
58 |
-
|
59 |
-
|
60 |
-
self.
|
61 |
-
self.
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
def venv_pop_token(self):
|
70 |
self.n_tokens -= 1
|
71 |
self.venv[-1] -= 1
|
72 |
self.kv_cache_seq_trim()
|
73 |
|
|
|
|
|
|
|
|
|
74 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
75 |
if n_past < 0:
|
76 |
n_past = self.n_tokens
|
@@ -274,6 +299,7 @@ class StreamingLLM(Llama):
|
|
274 |
n_tokens)
|
275 |
self.n_tokens = n_tokens.contents.value
|
276 |
self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
|
|
|
277 |
return retn
|
278 |
|
279 |
def save_session(self, filepath: str):
|
|
|
31 |
class StreamingLLM(Llama):
|
32 |
def __init__(self, model_path: str, **kwargs):
|
33 |
super().__init__(model_path, **kwargs)
|
34 |
+
self._venv_init()
|
35 |
|
36 |
def str_detokenize(self, tokens) -> str:
|
37 |
return get_complete_UTF8(self.detokenize(tokens))
|
|
|
39 |
def kv_cache_seq_trim(self):
|
40 |
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
41 |
|
42 |
+
def _venv_init(self):
|
43 |
+
self.venv = [0]
|
44 |
+
self.venv_idx_map = []
|
45 |
+
|
46 |
+
def venv_create(self, name: str):
|
47 |
self.venv.append(0)
|
48 |
+
self.venv_idx_map.append(name)
|
49 |
+
return name
|
50 |
|
51 |
+
def venv_disband(self, name_set):
|
52 |
if len(self.venv) <= 1:
|
53 |
+
return name_set
|
54 |
+
name_set = {x for x in name_set if x in self.venv_idx_map}
|
55 |
+
if not name_set:
|
56 |
+
return name_set
|
57 |
+
while self.venv_idx_map:
|
58 |
+
if self.venv_idx_map[0] in name_set:
|
59 |
+
self.venv_idx_map.pop(0) # 删除
|
60 |
+
tmp = self.venv.pop(1) # 对应的 venv 移入上一层
|
61 |
+
self.venv[0] += tmp
|
62 |
+
else:
|
63 |
+
break
|
64 |
+
return name_set
|
65 |
|
66 |
+
def venv_remove(self, name: str):
|
67 |
+
if len(self.venv) <= 1:
|
68 |
+
return name
|
69 |
+
if name not in self.venv_idx_map:
|
70 |
+
return name
|
71 |
+
venv_idx = self.venv_idx_map.index(name) + 1
|
72 |
+
while self.venv_idx_map:
|
73 |
+
self.venv_idx_map.pop(venv_idx - 1) # 删除
|
74 |
+
if venv_idx == len(self.venv) - 1:
|
75 |
+
# 最后一层
|
76 |
+
self.n_tokens -= min(self.venv.pop(), self.n_tokens)
|
77 |
+
self.kv_cache_seq_trim()
|
78 |
+
break
|
79 |
+
else:
|
80 |
+
# 非最后一层
|
81 |
+
n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
|
82 |
+
n_discard = self.venv.pop(venv_idx)
|
83 |
+
self.kv_cache_seq_ltrim(n_keep, n_discard)
|
84 |
+
try:
|
85 |
+
venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
|
86 |
+
except ValueError: # 没有了
|
87 |
+
break
|
88 |
+
return name
|
89 |
|
90 |
def venv_pop_token(self):
|
91 |
self.n_tokens -= 1
|
92 |
self.venv[-1] -= 1
|
93 |
self.kv_cache_seq_trim()
|
94 |
|
95 |
+
@property
|
96 |
+
def venv_info(self):
|
97 |
+
return str((self.n_tokens, self.venv, self.venv_idx_map))
|
98 |
+
|
99 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
100 |
if n_past < 0:
|
101 |
n_past = self.n_tokens
|
|
|
299 |
n_tokens)
|
300 |
self.n_tokens = n_tokens.contents.value
|
301 |
self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
|
302 |
+
self._venv_init()
|
303 |
return retn
|
304 |
|
305 |
def save_session(self, filepath: str):
|