app_added
Files changed:
- .ipynb_checkpoints/app-checkpoint.ipynb +6 -0
- Sentiment_Analysis_Bert.ipynb +606 -0
- app.ipynb +174 -0
- roberta_sentiment_analysis.py +137 -0
.ipynb_checkpoints/app-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
+{
+ "cells": [],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
Sentiment_Analysis_Bert.ipynb
ADDED
@@ -0,0 +1,606 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c415192a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install datasets transformers huggingface_hub\n",
+    "!apt-get install git-lfs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bd861c86",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import io\n",
+    "import os\n",
+    "\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "import matplotlib.ticker as mtick\n",
+    "\n",
+    "# Tokenizing/counting utilities used by the EDA cells below\n",
+    "from collections import Counter\n",
+    "import nltk\n",
+    "nltk.download('punkt', quiet=True)\n",
+    "from nltk.tokenize import word_tokenize\n",
+    "from wordcloud import WordCloud, STOPWORDS, ImageColorGenerator\n",
+    "\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from datasets import load_dataset, load_metric\n",
+    "from transformers import AutoTokenizer, TrainingArguments, Trainer\n",
+    "from transformers import AutoModelForSequenceClassification\n",
+    "from huggingface_hub import notebook_login\n",
+    "from google.colab import files\n",
+    "from google.colab import drive\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
+    "\n",
+    "# Convert training batches to padded PyTorch tensors to speed up training\n",
+    "from transformers import DataCollatorWithPadding\n",
+    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8b2e212f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Disable W&B logging\n",
+    "os.environ[\"WANDB_DISABLED\"] = \"true\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b5d8f134",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "drive.mount('/content/drive')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "705a32e0",
+   "metadata": {},
+   "source": [
+    "## Dataset import"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d7e95a22",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train = pd.read_csv(\"/content/drive/MyDrive/PostBAP_ASSESSMENT/hugging.csv\").dropna(axis=0)\n",
+    "test = pd.read_csv(\"/content/drive/MyDrive/PostBAP_ASSESSMENT/Testhugging.csv\").fillna(\"\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bb162c55",
+   "metadata": {},
+   "source": [
+    "## Data cleaning"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c676772",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "37512519",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.info()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7f922334",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8ec92e68",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.dtypes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c2179d7a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.isnull().sum()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eebc84d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Drop any remaining rows with missing values\n",
+    "train.dropna(inplace=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f2fe9dc5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.isnull().sum()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b98c837e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Drop the tweet ID column; it carries no signal\n",
+    "train = train.iloc[:, 1:]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3bb1a46b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train.head()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "47daf33f",
+   "metadata": {},
+   "source": [
+    "## Exploratory data analysis"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "57bc229c",
+   "metadata": {},
+   "source": [
+    "Class Distribution Analysis:\n",
+    "\n",
+    "Examine the distribution of the \"label\" column to understand the balance between classes (positive, negative, neutral). Visualize the distribution with a bar plot or pie chart.\n",
+    "\n",
+    "Agreement Analysis:\n",
+    "\n",
+    "Investigate the \"agreement\" column to understand the level of agreement between annotators. Analyze the distribution of agreement scores and look for discrepancies or inconsistencies. Visualize it with a histogram or box plot.\n",
+    "\n",
+    "Text Length Analysis:\n",
+    "\n",
+    "Explore the length of the \"safe_text\" column (number of characters or words). Calculate summary statistics such as mean, median, minimum, maximum, and standard deviation, and plot the distribution with a histogram or box plot.\n",
+    "\n",
+    "Word Frequency Analysis:\n",
+    "\n",
+    "Identify the most frequent words or terms in the \"safe_text\" column and visualize the top N with a word cloud or bar chart. This analysis shows the most commonly used language in the dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ff2f8361",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f3877ad2",
+   "metadata": {},
+   "source": [
+    "Class Distribution Analysis\n",
+    "\n",
+    "Examine the distribution of the \"label\" column to understand the balance between classes (positive, negative, neutral)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6f4ca78f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ax = (train['label'].value_counts() * 100.0 / len(train)) \\\n",
+    "    .plot.pie(autopct='%.1f%%', labels=['Neutral', 'Positive', 'Negative'], figsize=(5, 5), fontsize=12)\n",
+    "ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
+    "ax.set_ylabel('', fontsize=12)\n",
+    "ax.set_title('Distribution of tweet labels', fontsize=12)\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3cbe8f79",
+   "metadata": {},
+   "source": [
+    "Agreement Analysis:\n",
+    "\n",
+    "Investigate the \"agreement\" column to understand the level of agreement between annotators. Analyze the distribution of agreement scores and determine whether there are discrepancies or inconsistencies."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c76a978",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ax = train['agreement'].value_counts().plot(kind='bar', rot=45, width=0.3)\n",
+    "ax.set_ylabel('Count')\n",
+    "ax.set_title('Agreement Distribution')\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "49714fb9",
+   "metadata": {},
+   "source": [
+    "Text Length Analysis\n",
+    "\n",
+    "Explore the length of the \"safe_text\" column (number of characters or words)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e9020da9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Calculate the length of each text\n",
+    "train['text_length'] = train['safe_text'].apply(len)\n",
+    "\n",
+    "# Plot the text length distribution\n",
+    "plt.hist(train['text_length'], bins=20)\n",
+    "plt.xlabel('Text Length')\n",
+    "plt.ylabel('Count')\n",
+    "plt.title('Text Length Distribution')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "126e79c7",
+   "metadata": {},
+   "source": [
+    "Word Frequency Analysis\n",
+    "\n",
+    "Identify the most frequent words or terms in the \"safe_text\" column. This analysis shows the most commonly used language in the dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "933690da",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Combine all the texts into a single string\n",
+    "all_text = ' '.join(train['safe_text'])\n",
+    "\n",
+    "# Tokenize the text into individual words\n",
+    "tokens = word_tokenize(all_text)\n",
+    "\n",
+    "# Count the frequency of each word\n",
+    "word_freq = Counter(tokens)\n",
+    "\n",
+    "# Get the 10 most used words\n",
+    "top_10_words = word_freq.most_common(10)\n",
+    "\n",
+    "# Print the top 10 words and their frequencies\n",
+    "for word, freq in top_10_words:\n",
+    "    print(f\"Word: {word}\\tFrequency: {freq}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f64e638e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Combine all the texts into a single string\n",
+    "all_text = ' '.join(train['safe_text'])\n",
+    "\n",
+    "# Tokenize the text and count the frequency of each word\n",
+    "tokens = word_tokenize(all_text)\n",
+    "word_freq = Counter(tokens)\n",
+    "\n",
+    "# Get the 10 most used words and their frequencies\n",
+    "top_10_words = word_freq.most_common(10)\n",
+    "\n",
+    "# Extract the words and frequencies for plotting\n",
+    "words = [word for word, freq in top_10_words]\n",
+    "frequencies = [freq for word, freq in top_10_words]\n",
+    "\n",
+    "# Create a bar plot\n",
+    "plt.figure(figsize=(10, 6))\n",
+    "plt.bar(words, frequencies)\n",
+    "plt.xlabel('Words')\n",
+    "plt.ylabel('Frequency')\n",
+    "plt.title('Top 10 Most Used Words')\n",
+    "plt.xticks(rotation=45)\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "72d98f6b",
+   "metadata": {},
+   "source": [
+    "## Modeling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c955af2a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hold out 20% of the training data for evaluation (eval_df avoids shadowing the builtin eval)\n",
+    "train, eval_df = train_test_split(train, test_size=0.2, random_state=42, stratify=train['label'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "df8ba692",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(f\"new dataframe shapes: train is {train.shape}, eval is {eval_df.shape}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b7e04ed",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Save the split subsets\n",
+    "train.to_csv(\"/content/drive/MyDrive/PostBAP_ASSESSMENT/train_subset.csv\", index=False)\n",
+    "eval_df.to_csv(\"/content/drive/MyDrive/PostBAP_ASSESSMENT/eval_subset.csv\", index=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9590e1f2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_dataset('csv',\n",
+    "                       data_files={'train': '/content/drive/MyDrive/PostBAP_ASSESSMENT/train_subset.csv',\n",
+    "                                   'eval': '/content/drive/MyDrive/PostBAP_ASSESSMENT/eval_subset.csv'},\n",
+    "                       encoding=\"ISO-8859-1\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c27314c1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def transform_labels(label):\n",
+    "    # Map the raw sentiment codes to contiguous class ids\n",
+    "    label = label['label']\n",
+    "    num = 0\n",
+    "    if label == -1:    # Negative\n",
+    "        num = 0\n",
+    "    elif label == 0:   # Neutral\n",
+    "        num = 1\n",
+    "    elif label == 1:   # Positive\n",
+    "        num = 2\n",
+    "\n",
+    "    return {'labels': num}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5a39bf13",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def tokenize_data(example):\n",
+    "    return tokenizer(example['safe_text'], padding='max_length', truncation=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e6e85205",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Convert the tweets to tokens the model can use\n",
+    "dataset = dataset.map(tokenize_data, batched=True)\n",
+    "\n",
+    "# Transform the labels and remove the columns the model does not need\n",
+    "remove_columns = ['tweet_id', 'label', 'safe_text', 'agreement']\n",
+    "dataset = dataset.map(transform_labels, remove_columns=remove_columns)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "229f9b59",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_args = TrainingArguments(\"test_trainer\", num_train_epochs=10, load_best_model_at_end=True, evaluation_strategy=\"steps\")\n",
+    "\n",
+    "model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=3)\n",
+    "\n",
+    "train_dataset = dataset['train'].shuffle(seed=20)  # .select(range(40000)) to train on a subset\n",
+    "eval_dataset = dataset['eval'].shuffle(seed=20)\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset\n",
+    ")\n",
+    "\n",
+    "trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6824dd84",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define the evaluation metric\n",
+    "# (recent releases of datasets moved load_metric to the separate evaluate package)\n",
+    "metric = load_metric(\"accuracy\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "03b3c80f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_metrics(eval_pred):\n",
+    "    logits, labels = eval_pred\n",
+    "    predictions = np.argmax(logits, axis=-1)\n",
+    "    return metric.compute(predictions=predictions, references=labels)\n",
+    "\n",
+    "trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,\n",
+    "                  eval_dataset=eval_dataset, compute_metrics=compute_metrics)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7a1b2a94",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    args=training_args,\n",
+    "    train_dataset=train_dataset,\n",
+    "    eval_dataset=eval_dataset,\n",
+    "    tokenizer=tokenizer,\n",
+    "    data_collator=data_collator,\n",
+    "    compute_metrics=compute_metrics,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5a03352f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Launch the final evaluation\n",
+    "trainer.evaluate()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e1c0323d",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2f85b000",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
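The notebook above maps the raw -1/0/1 sentiment codes to class ids 0/1/2 and scores predictions by argmax accuracy. A minimal standalone sketch of that logic (illustrative only, not part of the committed files; the dummy logits and labels are invented) can be sanity-checked without loading any model:

import numpy as np

def transform_labels(example):
    # Mirrors the notebook's mapping: -1 (Negative) -> 0, 0 (Neutral) -> 1, 1 (Positive) -> 2
    mapping = {-1: 0, 0: 1, 1: 2}
    return {'labels': mapping[example['label']]}

logits = np.array([[2.1, 0.3, -1.0],   # argmax -> class 0
                   [0.2, 1.7, 0.9],    # argmax -> class 1
                   [-0.5, 0.1, 3.2]])  # argmax -> class 2
labels = np.array([0, 1, 1])

predictions = np.argmax(logits, axis=-1)
print(transform_labels({'label': -1}))  # {'labels': 0}
print((predictions == labels).mean())   # 0.666..., the accuracy compute_metrics would report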
app.ipynb
ADDED
@@ -0,0 +1,174 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2534f67e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install Gradio and transformers\n",
+    "!pip install gradio transformers -q\n",
+    "\n",
+    "# Import the required libraries\n",
+    "import gradio as gr\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from transformers import AutoTokenizer\n",
+    "from transformers import AutoConfig\n",
+    "from transformers import AutoModelForSequenceClassification\n",
+    "from scipy.special import softmax"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1dc8e034",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the fine-tuned model and tokenizer from the Hugging Face Hub\n",
+    "model_path = \"HOLYBOY/Sentiment_Analysis\"\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
+    "config = AutoConfig.from_pretrained(model_path)\n",
+    "model = AutoModelForSequenceClassification.from_pretrained(model_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "02f78081",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Preprocess text (username and link placeholders)\n",
+    "def preprocess(text):\n",
+    "    new_text = []\n",
+    "    for t in text.split(\" \"):\n",
+    "        t = \"@user\" if t.startswith(\"@\") and len(t) > 1 else t\n",
+    "        t = \"http\" if t.startswith(\"http\") else t\n",
+    "        new_text.append(t)\n",
+    "    return \" \".join(new_text)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "70126857",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ---- Function to process the input and return the prediction\n",
+    "def sentiment_analysis(text):\n",
+    "    text = preprocess(text)\n",
+    "\n",
+    "    encoded_input = tokenizer(text, return_tensors=\"pt\")  # for PyTorch-based models\n",
+    "    output = model(**encoded_input)\n",
+    "    scores_ = output[0][0].detach().numpy()\n",
+    "    scores_ = softmax(scores_)\n",
+    "\n",
+    "    # Format the output as a dict mapping each label to its probability\n",
+    "    labels = [\"Negative\", \"Neutral\", \"Positive\"]\n",
+    "    scores = {l: float(s) for (l, s) in zip(labels, scores_)}\n",
+    "\n",
+    "    return scores"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4901894b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ---- Gradio app interface\n",
+    "app = gr.Interface(fn=sentiment_analysis,\n",
+    "                   inputs=gr.Textbox(placeholder=\"Input your tweet to classify or use an example provided below...\"),\n",
+    "                   outputs=\"label\",\n",
+    "                   title=\"Public Perception of COVID-19 Vaccines\",\n",
+    "                   description=\"This app analyzes the perception of COVID-19 vaccines in tweets, using a fine-tuned DistilBERT model\",\n",
+    "                   interpretation=\"default\",\n",
+    "                   examples=[[\"The idea of introducing the vaccine is good\"],\n",
+    "                             [\"I am definitely not taking the jab\"],\n",
+    "                             [\"The vaccine is bad and can cause serious health implications\"],\n",
+    "                             [\"I don't have any opinion\"]])\n",
+    "\n",
+    "app.launch(share=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21753359",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6c96cb45",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
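The sentiment_analysis function in app.ipynb returns a dict mapping each class name to a softmax probability, which is the format Gradio's "label" output component expects. A usage sketch, assuming the cells above have run and the model weights have downloaded (the example tweet is invented):

scores = sentiment_analysis("@doctor I took the vaccine today http://t.co/xyz and feel fine")
# preprocess() has replaced the handle with "@user" and the link with "http"
print(scores)                       # e.g. {'Negative': 0.04, 'Neutral': 0.11, 'Positive': 0.85}
print(max(scores, key=scores.get))  # the most likely class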
roberta_sentiment_analysis.py
ADDED
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+"""RoBERTa sentiment analysis
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+    https://colab.research.google.com/drive/10L1VfVMZLa62qTFdUIOURELW194TjJ4e
+"""
+
+# Install required libraries
+!pip install datasets transformers huggingface_hub -q
+
+# Import key libraries and packages
+import numpy as np
+import os
+import pandas as pd
+
+from datasets import load_dataset, load_metric
+from huggingface_hub import notebook_login
+from sklearn.model_selection import train_test_split
+from transformers import AutoTokenizer, TrainingArguments, Trainer
+from google.colab import files
+from google.colab import drive
+
+# Disable Weights & Biases
+os.environ["WANDB_DISABLED"] = "true"
+
+drive.mount('/content/drive')
+
+# Load the datasets
+train_df = pd.read_csv("/content/drive/MyDrive/PostBAP_ASSESSMENT/hugging.csv").dropna(axis=0)
+test_df = pd.read_csv("/content/drive/MyDrive/PostBAP_ASSESSMENT/Testhugging.csv").fillna("")
+
+# Inspect the data
+train_df.head()
+
+test_df.head()
+
+train_df.isnull().sum()
+
+test_df.isnull().sum()
+
+"""Fine-tuning the RoBERTa model"""
+
+# Hold out 20% of the training data for evaluation (eval_df avoids shadowing the builtin eval)
+train_df, eval_df = train_test_split(train_df, test_size=0.2, random_state=42, stratify=train_df['label'])
+
+print(f"new dataframe shapes: train is {train_df.shape}, eval is {eval_df.shape}")
+
+# Save the split subsets
+train_df.to_csv("/content/drive/MyDrive/PostBAP_ASSESSMENT/train_subset.csv", index=False)
+eval_df.to_csv("/content/drive/MyDrive/PostBAP_ASSESSMENT/eval_subset.csv", index=False)
+
+dataset = load_dataset('csv',
+                       data_files={'train': '/content/drive/MyDrive/PostBAP_ASSESSMENT/train_subset.csv',
+                                   'eval': '/content/drive/MyDrive/PostBAP_ASSESSMENT/eval_subset.csv'},
+                       encoding="ISO-8859-1")
+
+# Instantiate the tokenizer
+tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
+
+# Define helper functions
+## Function to transform labels
+def transform_labels(label):
+    # Map the raw sentiment codes to contiguous class ids
+    label = label['label']
+    num = 0
+    if label == -1:    # Negative
+        num = 0
+    elif label == 0:   # Neutral
+        num = 1
+    elif label == 1:   # Positive
+        num = 2
+
+    return {'labels': num}
+
+## Function to tokenize data
+def tokenize_data(example):
+    return tokenizer(example['safe_text'], padding='max_length', truncation=True, max_length=256)
+
+# Tokenize the tweets
+dataset = dataset.map(tokenize_data, batched=True)
+
+# Transform the labels and remove the columns the model does not need
+remove_columns = ['tweet_id', 'label', 'safe_text', 'agreement']
+dataset = dataset.map(transform_labels, remove_columns=remove_columns)
+
+# Define training arguments
+training_args = TrainingArguments(
+    "covid_tweets_sentiment_analysis_model",
+    num_train_epochs=4,
+    load_best_model_at_end=True,
+    evaluation_strategy="epoch",
+    save_strategy="epoch"
+)
+
+# Load the pretrained model
+from transformers import AutoModelForSequenceClassification
+model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment", num_labels=3)
+
+# Define the evaluation metric
+metric = load_metric("accuracy")
+
+def compute_metrics(eval_pred):
+    logits, labels = eval_pred
+    predictions = np.argmax(logits, axis=-1)
+    return metric.compute(predictions=predictions, references=labels)
+
+# Instantiate the training and evaluation sets
+train_dataset = dataset["train"].shuffle(seed=24)
+eval_dataset = dataset["eval"].shuffle(seed=24)
+
+# Convert training batches to padded PyTorch tensors to speed up training
+from transformers import DataCollatorWithPadding
+data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+# Instantiate the trainer
+trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
+                  eval_dataset=eval_dataset, compute_metrics=compute_metrics)
+trainer.train()
+
+# Reinstantiate the trainer for evaluation
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    tokenizer=tokenizer,
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
+)
+
+# Launch the final evaluation
+trainer.evaluate()
+
+# Log in to the HF Hub
+notebook_login()
+
+# Push the model and tokenizer to the HF Hub
+model.push_to_hub("MavisAJ/Sentiment_analysis_roberta_model")
+tokenizer.push_to_hub("MavisAJ/Sentiment_analysis_roberta_model")
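Once the push_to_hub calls above succeed, the fine-tuned model can be loaded back for inference. A minimal sketch (illustrative, not part of the commit), assuming the MavisAJ/Sentiment_analysis_roberta_model repo is accessible and you are logged in if it is private:

from transformers import pipeline

clf = pipeline("text-classification", model="MavisAJ/Sentiment_analysis_roberta_model")
print(clf("The vaccine rollout went better than expected"))
# Because the script does not set an id2label mapping, outputs use generic names:
# LABEL_0 = Negative, LABEL_1 = Neutral, LABEL_2 = Positive (per transform_labels above)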