def read_and_split_file(filename, chunk_size=1200, chunk_overlap=200):
    """Read a text file and split it into overlapping character chunks."""
    with open(filename, 'r') as f:
        text = f.read()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=[" ", ",", "\n"]
    )
    # st.write(f'Financial report char len: {len(text)}')
    texts = text_splitter.create_documents([text])
    return texts


if __name__ == '__main__':
    # Comments and ideas to implement:
    # 1. Try sending a list of inputs to the Inference API.

    import streamlit as st
    from sys import exit
    from pprint import pprint
    from collections import Counter
    from itertools import zip_longest
    from random import choice
    import requests
    from re import sub
    from rouge import Rouge
    from time import sleep, perf_counter
    import os
    from textwrap import wrap
    from multiprocessing import Pool, freeze_support
    from tqdm import tqdm
    from stqdm import stqdm

    from langchain.document_loaders import TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.schema.document import Document
    # from langchain.schema import Document
    from langchain.chat_models import ChatOpenAI
    from langchain.llms import OpenAI
    from langchain.schema import AIMessage, HumanMessage, SystemMessage
    from langchain.prompts import PromptTemplate

    from datasets import Dataset, load_dataset
    from sklearn.preprocessing import LabelEncoder
    from test_models.train_classificator import MLP
    from safetensors.torch import load_model, save_model
    from sentence_transformers import SentenceTransformer
    from torch.utils.data import DataLoader, TensorDataset
    import torch.nn.functional as F
    import torch
    import torch.nn as nn
    import sys

    # Make the local model packages importable.
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/')))
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), 'test_models/financial-roberta')))

    st.set_page_config(
        page_title="Financial advisor",
        page_icon="💳💰",
        layout="wide",
    )

    # st.session_state.summarized = False

    with st.sidebar:
        # Streamlit "magic": bare string literals are rendered as Markdown.
        "# How to use 🔍"
        """
        ✨ This is a holiday version of the web UI, with a touch of magic 🌍: unwrap label
        predictions for a company from the text of its financial report! 📊✨
        The prediction enchantment is performed with a sophisticated embedding-classifier approach. 🚀🔮
        """

    # HTML snippet used to render a centered heading via st.markdown(..., unsafe_allow_html=True).
    center_style = "<div style='text-align: center'><h2>{}</h2></div>"
" st.markdown(center_style.format('Load the financial report'), unsafe_allow_html=True) upload_types = ['Text input', 'File upload'] upload_captions = ['Paste the text', 'Upload a text file'] upload_type = st.radio('Select how to upload the financial report', upload_types, captions=upload_captions) match upload_type: case 'Text input': financial_report_text = st.text_area('Something', label_visibility='collapsed', placeholder='Financial report as TEXT') case 'File upload': uploaded_files = st.file_uploader("Choose a a text file", type=['.txt', '.docx'], label_visibility='collapsed', accept_multiple_files=True) if not bool(uploaded_files): st.stop() financial_report_text = '' for uploaded_file in uploaded_files: if uploaded_file.name.endswith("docx"): document = Document(uploaded_file) document.save('./utils/texts/' + uploaded_file.name) document = Document(uploaded_file.name) financial_report_text += "".join([paragraph.text for paragraph in document.paragraphs]) + '\n' else: financial_report_text += "".join([line.decode() for line in uploaded_file]) + '\n' # with open('./utils/texts/financial_report_text.txt', 'w') as file: # file.write(financial_report_text) if st.button('Get label'): with st.spinner("Thinking..."): text_splitter = RecursiveCharacterTextSplitter( chunk_size=3200, chunk_overlap=200, length_function = len, separators=[" ", ",", "\n"] ) # st.write(f'Financial report char len: {len(financial_report_text)}') documents = text_splitter.create_documents([financial_report_text]) # st.write(f'Num chunks: {len(documents)}') texts = [document.page_content for document in documents] # st.write(f'Each chunk char length: {[len(text) for text in texts]}') # predicted_label = get_label_prediction(texts) from test_models.create_setfit_model import model with torch.no_grad(): model.model_head.eval() predicted_labels = model(texts) # st.write(predicted_labels) predicted_labels_counter = Counter(predicted_labels) predicted_label = predicted_labels_counter.most_common(1)[0][0] font_style = 'The predicted label is **{}**.' st.markdown(font_style.format(predicted_label), unsafe_allow_html=True)