starsaround committed on
Commit
5071898
1 Parent(s): 5238878

Upload memory_func.py

Browse files

deal with memory capacity

Files changed (1) hide show
  1. models_for_langchain/memory_func.py +26 -0
models_for_langchain/memory_func.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.llms.base import LLM
2
+ from langchain.memory import ConversationBufferWindowMemory
3
+ from transformers import GPT2TokenizerFast
4
+ from langchain.schema.messages import get_buffer_string
5
+
6
# Lazily-created GPT-2 tokenizer shared by every get_num_tokens() call.
_GPT2_TOKENIZER = None


def get_num_tokens(text):
    """Return the number of GPT-2 tokens in ``text``.

    The tokenizer is created with ``GPT2TokenizerFast.from_pretrained("gpt2")``
    on the first call and reused afterwards, so repeated calls (one per
    message when measuring a conversation buffer) do not reload it.

    Args:
        text: The string to tokenize.

    Returns:
        The length of the token list produced by ``tokenizer.tokenize(text)``.
    """
    global _GPT2_TOKENIZER
    if _GPT2_TOKENIZER is None:
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained("gpt2")
    return len(_GPT2_TOKENIZER.tokenize(text))
9
+
10
def get_memory_num_tokens(memory):
    """Return the total token count of the messages stored in ``memory``.

    Each message in ``memory.chat_memory.messages`` is rendered on its own
    with ``get_buffer_string([message])`` and measured with
    ``get_num_tokens``; the per-message counts are summed.

    Args:
        memory: An object exposing ``chat_memory.messages`` as an iterable
            of messages accepted by ``get_buffer_string``.

    Returns:
        The summed token count, or 0 when there are no messages.
    """
    # A generator expression avoids building a throwaway list just to sum it.
    return sum(
        get_num_tokens(get_buffer_string([message]))
        for message in memory.chat_memory.messages
    )
13
+
14
def validate_memory_len(memory, max_token_limit=2000):
    """Drop the oldest messages until the buffer fits ``max_token_limit``.

    Messages are removed in place from the front of
    ``memory.chat_memory.messages`` until the summed per-message token count
    is no longer greater than ``max_token_limit`` or no messages remain.

    Args:
        memory: An object exposing ``chat_memory.messages`` as a mutable
            list of messages accepted by ``get_buffer_string``.
        max_token_limit: Maximum total token count allowed to remain.

    Returns:
        The same ``memory`` object, with its message list trimmed in place.
    """
    buffer = memory.chat_memory.messages

    # Tokenize every message exactly once. Each message is measured on its
    # own, so the total after dropping a prefix is the original total minus
    # the dropped counts -- no need to re-tokenize the whole buffer per pop.
    token_counts = [get_num_tokens(get_buffer_string([m])) for m in buffer]
    remaining_tokens = sum(token_counts)

    drop_count = 0
    # The bound on drop_count stops the scan once the buffer is exhausted,
    # so an unsatisfiable limit cannot pop from an empty list.
    while drop_count < len(token_counts) and remaining_tokens > max_token_limit:
        remaining_tokens -= token_counts[drop_count]
        drop_count += 1

    if drop_count:
        # Single in-place prefix deletion instead of repeated pop(0).
        del buffer[:drop_count]
    return memory
22
+
23
if __name__ == '__main__':
    # Smoke test: print the GPT-2 token count of a short sample string,
    # going through get_num_tokens so the script exercises the same code
    # path the memory helpers use.
    text = '''Hi'''
    print(get_num_tokens(text))