from logging import getLogger

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

logger = getLogger(__name__)


def get_input_token_count(text: str, tokenizer) -> int:
    """Return the number of tokens the tokenizer produces for the given text."""
    tokens = tokenizer.tokenize(text)
    return len(tokens)


def get_document_splits_from_text(text: str) -> list[Document]:
    """Wrap the text in a Document and split it into overlapping chunks."""
    document = Document(page_content=text)
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", "?", " "],
        chunk_size=15000,
        chunk_overlap=50,
    )
    split_documents = text_splitter.split_documents([document])
    logger.info(f"Splitting Document: Total Chunks: {len(split_documents)}")
    return split_documents


def prepare_for_summarize(text: str, tokenizer):
    """Decide whether the text can be summarized in one pass or must be chunked.

    Texts under 12,000 tokens are returned unchanged with length_type "short";
    longer texts are split into Document chunks with length_type "long".
    """
    no_input_tokens = get_input_token_count(text, tokenizer)
    if no_input_tokens < 12000:
        text_to_summarize = text
        length_type = "short"
        return text_to_summarize, length_type
    else:
        text_to_summarize = get_document_splits_from_text(text)
        length_type = "long"
        return text_to_summarize, length_type
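

# Minimal usage sketch (an assumption, not part of the module's API): pairing
# prepare_for_summarize with a Hugging Face tokenizer. The model name and the
# sample text below are illustrative placeholders only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
    sample_text = "LangChain can split long transcripts into overlapping chunks. " * 20
    text_to_summarize, length_type = prepare_for_summarize(sample_text, tokenizer)
    # "short" inputs come back as the original string; "long" inputs come back
    # as a list of Document chunks produced by get_document_splits_from_text.
    print(length_type, type(text_to_summarize))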