themeetjani committed
Commit • 6060e42
1 Parent(s): 0e4ef18
Upload 10 files
- pages/AI_Chatbot.py +15 -0
- pages/Auto_Code_Generation.py +14 -0
- pages/Auto_Report_Generation.py +14 -0
- pages/Auto_Score_Generation.py +14 -0
- pages/core_risk.py +135 -0
- pages/jury_records.py +103 -0
- pages/text_clustering.py +14 -0
- pages/topic_classification.py +73 -0
- pages/tweet_classification.py +30 -0
- pages/untitled.txt +0 -0
pages/AI_Chatbot.py
ADDED
@@ -0,0 +1,15 @@
import streamlit as st

st.set_page_config(
    page_title="AI_Chatbot.py",
    page_icon="👋",
)

st.write("# AI Chatbot! 👋")

st.sidebar.success("Select a demo above.")

st.markdown(
    """
    **Work in progress!!!** """
)
pages/Auto_Code_Generation.py
ADDED
@@ -0,0 +1,14 @@
import streamlit as st

st.set_page_config(
    page_title="Auto_Code_Generation.py",
    page_icon="👋",
)

st.write("# Auto Code Generation! 👋")


st.markdown(
    """
    **Work in progress!!!** """
)
pages/Auto_Report_Generation.py
ADDED
@@ -0,0 +1,14 @@
import streamlit as st

st.set_page_config(
    page_title="Auto_Report_Generation.py",
    page_icon="👋",
)

st.write("# Auto Report Generation! 👋")


st.markdown(
    """
    **Work in progress!!!** """
)
pages/Auto_Score_Generation.py
ADDED
@@ -0,0 +1,14 @@
import streamlit as st

st.set_page_config(
    page_title="Auto_Score_Generation.py",
    page_icon="👋",
)

st.write("# Auto Score Generation! 👋")


st.markdown(
    """
    **Work in progress!!!** """
)
pages/core_risk.py
ADDED
@@ -0,0 +1,135 @@
import numpy as np
import torch
import transformers
import streamlit as st
from streamlit import session_state
import json
import torch.nn.functional as F
import boto3
import pandas as pd
bucket = 'data-ai-dev2'
from transformers import BertTokenizer, BertModel
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
import numpy
from numpy.random import seed
seed(1)
import emoji
import string
import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer  # PorterStemmer LancasterStemmer
from nltk.stem import WordNetLemmatizer
import re
stemmer = PorterStemmer()

# download the required NLTK data (only needed on the first run)
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('stopwords')

lemmatizer = WordNetLemmatizer()

from transformers import pipeline
stopwords = nltk.corpus.stopwords.words('english')


model = 'C:/Users/Meet/Downloads/core_risk/models/'
tokenizer = 'C:/Users/Meet/Downloads/core_risk/tokenizer/'


from transformers import pipeline

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, truncation=True, max_length=512)
def pre_processing_str_esg(df_col):
    df_col = df_col.lower()
    # defining the function to remove punctuation
    def remove_punctuation(text):
        punctuationfree = "".join([i for i in text if i not in string.punctuation])
        return punctuationfree
    # storing the punctuation-free text
    df_col = remove_punctuation(df_col)
    df_col = re.sub(r"http\S+", " ", df_col)

    def remove_stopwords(text):
        return " ".join([word for word in str(text).split() if word not in stopwords])
    # applying the function
    df_col = remove_stopwords(df_col)
    df_col = re.sub('[%s]' % re.escape(string.punctuation), ' ', df_col)
    df_col = df_col.replace("¶", "")
    df_col = df_col.replace("§", "")
    df_col = df_col.replace('’', ' ')
    df_col = df_col.replace('“', ' ')
    df_col = df_col.replace('-', ' ')
    REPLACE_BY_SPACE_RE = re.compile(r'[/(){}\[\]\|@,;]')
    BAD_SYMBOLS_RE = re.compile('[^0-9a-z #+_]')
    df_col = REPLACE_BY_SPACE_RE.sub(' ', df_col)
    df_col = BAD_SYMBOLS_RE.sub(' ', df_col)

    # df_col = re.sub('W*dw*','',df_col)
    df_col = re.sub('[0-9]+', ' ', df_col)
    df_col = re.sub(' +', ' ', df_col)

    def remove_emoji(string):
        emoji_pattern = re.compile("["
                                   u"\U0001F600-\U0001F64F"  # emoticons
                                   u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                                   u"\U0001F680-\U0001F6FF"  # transport & map symbols
                                   u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                                   u"\U00002702-\U000027B0"
                                   u"\U000024C2-\U0001F251"
                                   "]+", flags=re.UNICODE)
        return emoji_pattern.sub(r'', string)
    df_col = remove_emoji(df_col)

    return df_col

def pre_processing_str(df_col):
    # df_col = df_col.lower()
    if len(df_col.split()) >= 70:
        return pre_processing_str_esg(df_col)
    else:
        df_col = df_col.replace('#', '')
        df_col = df_col.replace('!', '')
        df_col = re.sub(r"http\S+", " ", df_col)

        df_col = re.sub('[0-9]+', ' ', df_col)
        df_col = re.sub(' +', ' ', df_col)
        def remove_emojis(text):
            return emoji.replace_emoji(text)
        df_col = remove_emojis(df_col)
        df_col = re.sub(r"(?:\@|https?\://)\S+", "", df_col)
        df_col = re.sub(r"[^\x20-\x7E]+", "", df_col)
        df_col = df_col.strip()
        return df_col


# Start of the API steps: when this runs behind Flask, the file name must match the app object (application.py with application = Flask(__name__)).

def process(text):
    text = pre_processing_str(text)

    try:
        if len(text) != 0:
            results = classifier(text, top_k=2)
        else:
            results = 'No Text'

        return {'output_16': results}
    except:
        return {'output_16': 'something went wrong'}

st.set_page_config(page_title="core_risk", page_icon="👋")
if 'topic_class' not in session_state:
    session_state['topic_class'] = ""

st.title("Topic Classifier")
text = st.text_area(label="Please write the text below",
                    placeholder="What does the tweet say?")
def classify(text):
    session_state['topic_class'] = process(text)


st.text_area("result", value=session_state['topic_class'])

st.button("Classify", on_click=classify, args=[text])
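Note: process() only cleans the incoming text and runs it through the fine-tuned text-classification pipeline; the model and tokenizer paths above point at the author's local C:/Users/Meet/... directories, which are not part of this repo. A minimal sketch of the same pipeline flow, using a public stand-in checkpoint (an assumption, not the model this Space actually loads):

# Sketch only: same text-classification flow with an assumed public stand-in model.
from transformers import pipeline

clf = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # stand-in id; the Space points at its own fine-tuned weights
    truncation=True,
    max_length=512,
)

results = clf("Regulators opened an investigation into the company's emissions data.", top_k=2)
print(results)  # list of {'label': ..., 'score': ...} dicts, which process() wraps as {'output_16': results}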
pages/jury_records.py
ADDED
@@ -0,0 +1,103 @@
# import the necessary packages
import streamlit as st
from streamlit import session_state
from langchain.document_loaders import WebBaseLoader, PyPDFLoader, TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
import os
from langchain.chat_models import ChatOpenAI
import openai
import json
# need to set the openai key here or set it as an environment variable
openai.api_key = "give api key"
model = ChatOpenAI(model='gpt-4', max_tokens=100, temperature=0)
st.set_page_config(page_title="jury_records", page_icon="👋")
# this function extracts the content from the url. here we use langchain's WebBaseLoader; any web-scraping function could be used instead.
def extract(link):
    res = []
    loader = WebBaseLoader(link)
    pages = loader.load()
    for i in pages:
        res.append(i.page_content.replace('\n', ''))
    a = " ".join(res)
    print(len(a))
    if len(a) > 0:
        return a
    else:
        return 'error'

# Summarize the content with gpt-4 using prompting.
def summarize(link):
    context = extract(link)
    if context != 'error':
        # print(context)
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=[
                {
                    "role": "system",
                    "content": f"Following context is given.{context}"},
                {
                    "role": "user",
                    "content": '''Summarize the content in detail. Follow these instructions while summarizing.\n Include case no.\n Include all plaintiffs. \n Include the court name.
                    \n Alias names should be included.\n Include case no. \n Include all defendants.\n If a place is mentioned then include it, otherwise don't include it.
                    \n Date format should be dd/mm/yyyy.\n If the case is settled for an amount then try to include the amount.
                    If the amount is not mentioned, don't mention anything about it. Only include this line if the case is
                    settled; otherwise include the status of the case.\n\n<<REMEMBER>>\n\n Please try to include all the details. Don't leave out any information.'''
                }
            ],
            temperature=0,
            max_tokens=1000,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0
        )
        return response.choices[0].message.content.strip()
    else:
        return 'error'

# Questions dictionary passed for qna. Several iterations were done to arrive at this final set of questions; change this dictionary based on the parameters that need to be extracted from the url.
info_detail = {'case_type': 'provide case type or court system like "Criminal", "Family Law", "labour law"',
               'name_of_court': 'provide name of court or jail or court record.',
               'case_number': 'provide case number or country case number or bankruptcy case number', 'date_filed': 'what is the date when the case was filed or the date when case first formally/officially submitted?',
               'plaintiff': 'Names of the Petitioner or plaintiff or applicant? ',
               'defendants': "Names of all defendants, respondent and alias. Name entity under 'Defendants'",
               'nature_of_action': 'Summarize the reason behind the case within 20 words in detail',
               'status': 'what is the status of case?'}

# langchain function for qna over the summary returned by gpt-4; a vector store index is built over the summary.
def lang(context):
    answer_dict = {}
    docs = Document(page_content=context)
    index2 = VectorstoreIndexCreator().from_documents([docs])
    for key in info_detail:
        ques = info_detail[key]
        answer_dict[key] = index2.query(llm=model, question=ques)
    index2.vectorstore.delete_collection()
    return answer_dict

def process(url):
    try:
        summary = summarize(url)
        if summary == 'error':
            return {"details": "", "status": False}
        else:
            answer_dict = lang(summary)
            return answer_dict
    except:
        return "Please try again"
if 'jury_records_dict' not in session_state:
    session_state['jury_records_dict'] = ""

def Jury(url):
    session_state['jury_records_dict'] = process(url)

st.title("Jury Records")

jury_url = st.text_area(label="Please enter the jury records link",
                        placeholder="Jury records Link")

st.text_area("result", value=session_state['jury_records_dict'])

st.button("Get answer dictionary", on_click=Jury, args=[jury_url])
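The page's flow is: extract() pulls the page text with WebBaseLoader, summarize() condenses it with GPT-4, and lang() runs the question dictionary against a vector index built over that summary. A minimal sketch of the lang() step on a hard-coded snippet (the case text is invented for illustration, and OPENAI_API_KEY must be set for both the chat model and the default embeddings):

# Sketch only: the summary -> QnA loop from lang(), run on an invented snippet.
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.indexes import VectorstoreIndexCreator

llm = ChatOpenAI(model='gpt-4', max_tokens=100, temperature=0)
summary = "Case No. 12-345, Smith v. Jones, filed 01/02/2020 in the Superior Court. Status: settled."  # invented example text
index = VectorstoreIndexCreator().from_documents([Document(page_content=summary)])
print(index.query(llm=llm, question='provide case number or country case number or bankruptcy case number'))
index.vectorstore.delete_collection()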
pages/text_clustering.py
ADDED
@@ -0,0 +1,14 @@
import streamlit as st

st.set_page_config(
    page_title="text_clustering.py",
    page_icon="👋",
)

st.write("# Text Clustering! 👋")


st.markdown(
    """
    **Work in progress!!!** """
)
pages/topic_classification.py
ADDED
@@ -0,0 +1,73 @@
# importing all the necessary packages here
import streamlit as st
from streamlit import session_state
import pandas as pd
import numpy as np
from scipy import spatial
from sentence_transformers import SentenceTransformer
import json

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')  # calling hugging face model for embeddings here
# cosine similarity between two embedding vectors
def cosine_similarity(x, y):
    return 1 - spatial.distance.cosine(x, y)

# reading topic file into dataframe
df = pd.read_excel(r'C:\Users\Meet\Downloads/topic_data.xlsx')
# df2 = pd.read_csv("BBC News Train.csv")  # sample news article file
# storing level1 and level2 segments into a dictionary first
result_dict = df.groupby('LEVEL 1')['new_level_2'].apply(list).to_dict()
# storing l1 segments
segments = list(result_dict.keys())
segments_encode = model.encode(segments)  # encoding l1 segments with the model
# creating embedding dictionaries of all l1 segments and l2 segments.
# embedding dictionary for l2 segments
embeddings_dict = {}
for key, val in result_dict.items():
    embed = model.encode(result_dict[key])
    embeddings_dict[key] = embed

# function for ranking l1 segments.
def segments_finder(text_encode):
    score_dict = {}
    for segment, name in zip(segments_encode, segments):
        similarity_score = cosine_similarity(segment, text_encode)
        score_dict[name] = similarity_score
    return sorted(score_dict.items(), key=lambda x: x[1], reverse=True)

def level2(article_summary):
    l1 = {}
    l2 = {}
    output = {}
    text_encode = model.encode(article_summary)
    l1_pred = segments_finder(text_encode)
    # iterating over the top l1 segments to find their l2 segments.
    for i in l1_pred[:2]:
        score_dict = {}
        l2_segments = result_dict[i[0]]
        l2_segments_encode = embeddings_dict[i[0]]
        for segment, name in zip(l2_segments_encode, l2_segments):
            similarity_score = cosine_similarity(segment, text_encode)
            score_dict[name] = similarity_score
        l2_pred = dict(list(sorted(score_dict.items(), key=lambda x: x[1], reverse=True))[:2])
        print(l2_pred)
        l2[i[0]] = l2_pred
    output['l1'] = dict(list(sorted(dict(l1_pred).items(), key=lambda x: x[1], reverse=True))[:2])
    output['l2'] = l2
    return output

st.set_page_config(page_title="topic_classification", page_icon="👋")

if 'topic_class' not in session_state:
    session_state['topic_class'] = ""

st.title("Topic Classifier")
text = st.text_area(label="Please write the text below",
                    placeholder="What does the tweet say?")
def classify(text):
    session_state['topic_class'] = level2(text)


st.text_area("result", value=session_state['topic_class'])

st.button("Classify", on_click=classify, args=[text])
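The classifier here is purely embedding-based: each level-1 segment from topic_data.xlsx is encoded once, the input text is encoded, and cosine similarity ranks the segments; the top two level-1 matches are then re-ranked against their level-2 children. Since topic_data.xlsx only exists on the author's machine, a minimal sketch with an invented two-level hierarchy shows the same ranking step:

# Sketch only: the encode -> cosine-rank step from segments_finder(), with made-up labels.
from scipy import spatial
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
hierarchy = {'Environment': ['Emissions', 'Waste'], 'Governance': ['Board', 'Audit']}  # stand-in for topic_data.xlsx

text_encode = model.encode('The company reported a sharp rise in factory emissions.')
ranked = sorted(
    ((name, 1 - spatial.distance.cosine(model.encode(name), text_encode)) for name in hierarchy),
    key=lambda x: x[1],
    reverse=True,
)
print(ranked)  # level-1 segments ordered by similarity; level2() repeats this over the children of the top two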
pages/tweet_classification.py
ADDED
@@ -0,0 +1,30 @@
import streamlit as st
from streamlit import session_state
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained("themeetjani/tweet-classification")
model = AutoModelForSequenceClassification.from_pretrained("themeetjani/tweet-classification")

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, truncation=True, max_length=512)

st.set_page_config(page_title="Classification", page_icon="👋")
if 'tweet_class' not in session_state:
    session_state['tweet_class'] = ""

def classify(tweet):
    predicted_classes = classifier(tweet, top_k=1)
    print(tweet)
    print(predicted_classes)
    session_state['tweet_class'] = predicted_classes[0]['label']

st.title("Tweet Classifier")

tweet = st.text_area(label="Please write the tweet below",
                     placeholder="What does the tweet say?")

st.text_area("result", value=session_state['tweet_class'])

st.button("Classify", on_click=classify, args=[tweet])
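classify() stores only predicted_classes[0]['label'] because the pipeline returns a ranked list of label/score dicts. A minimal usage sketch of that call and output shape (the example tweet and printed values are illustrative; the real label set comes from the themeetjani/tweet-classification checkpoint):

# Sketch only: what classify() receives from the pipeline for a single tweet.
from transformers import pipeline

clf = pipeline("text-classification", model="themeetjani/tweet-classification",
               truncation=True, max_length=512)
out = clf("Huge wildfire smoke over the city today", top_k=1)
print(out)               # e.g. [{'label': '<class name>', 'score': 0.93}]
print(out[0]['label'])   # the value stored in session_state['tweet_class']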
pages/untitled.txt
ADDED
File without changes