faizhalas committed
Commit 19b9cbc • 1 Parent(s): 33a4df1

Update pages/4 Sunburst.py

Files changed (1)
  1. pages/4 Sunburst.py +83 -364

pages/4 Sunburst.py CHANGED
@@ -1,19 +1,8 @@
+#===import module===
 import streamlit as st
 import pandas as pd
-from sklearn.feature_extraction.text import CountVectorizer
-from nltk.tokenize import word_tokenize
-from nltk.corpus import stopwords
-import nltk
-import spacy
-from burst_detection import burst_detection, enumerate_bursts, burst_weights
-import matplotlib.pyplot as plt
-import os
-import io
-import math
+import plotly.express as px
 import numpy as np
-import plotly.graph_objects as go
-from plotly.subplots import make_subplots
-import plotly.io as pio
 import sys
 
 #===config===
@@ -42,384 +31,114 @@ with st.popover("🔗 Menu"):
     st.page_link("pages/4 Sunburst.py", label="Sunburst", icon="4️⃣")
     st.page_link("pages/5 Burst Detection.py", label="Burst Detection", icon="5️⃣")
     st.page_link("pages/6 Keywords Stem.py", label="Keywords Stem", icon="6️⃣")
-
-st.header("Burst Detection", anchor=False)
+
+st.header("Sunburst Visualization", anchor=False)
 st.subheader('Put your file here...', anchor=False)
 
 #===clear cache===
 def reset_all():
     st.cache_data.clear()
 
-# Initialize NLP model
-nlp = spacy.load("en_core_web_md")
-
+#===check type===
 @st.cache_data(ttl=3600)
-def upload(extype):
-    df = pd.read_csv(uploaded_file)
-    #lens.org
-    if 'Publication Year' in df.columns:
-        df.rename(columns={'Publication Year': 'Year', 'Citing Works Count': 'Cited by',
-                           'Publication Type': 'Document Type', 'Source Title': 'Source title'}, inplace=True)
-    return df
-
-@st.cache_data(ttl=3600)
-def get_ext(uploaded_file):
+def get_ext(extype):
    extype = uploaded_file.name
    return extype
 
 @st.cache_data(ttl=3600)
-def get_minmax(df):
-    MIN = int(df['Year'].min())
-    MAX = int(df['Year'].max())
-    GAP = MAX - MIN
-    return MIN, MAX, GAP
+def upload(extype):
+    papers = pd.read_csv(uploaded_file)
+    #lens.org
+    if 'Publication Year' in papers.columns:
+        papers.rename(columns={'Publication Year': 'Year', 'Citing Works Count': 'Cited by',
+                               'Publication Type': 'Document Type', 'Source Title': 'Source title'}, inplace=True)
+    return papers
 
 @st.cache_data(ttl=3600)
 def conv_txt(extype):
    col_dict = {'TI': 'Title',
                'SO': 'Source title',
                'DT': 'Document Type',
+               'DE': 'Author Keywords',
+               'ID': 'Keywords Plus',
                'AB': 'Abstract',
-               'PY': 'Year'}
-    df = pd.read_csv(uploaded_file, sep='\t', lineterminator='\r')
-    df.rename(columns=col_dict, inplace=True)
-    return df
-
-# Helper Functions
-@st.cache_data(ttl=3600)
-def get_column_name(df, possible_names):
-    """Find and return existing column names from a list of possible names."""
-    for name in possible_names:
-        if name in df.columns:
-            return name
-    raise ValueError(f"None of the possible names {possible_names} found in DataFrame columns.")
-
-@st.cache_data(ttl=3600)
-def preprocess_text(text):
-    """Lemmatize and remove stopwords from text."""
-    return ' '.join([token.lemma_.lower() for token in nlp(text) if token.is_alpha and not token.is_stop])
-
-@st.cache_data(ttl=3600)
-def load_data(uploaded_file):
-    """Load data from the uploaded file."""
-    extype = get_ext(uploaded_file)
-    if extype.endswith('.csv'):
-        df = upload(extype)
-    elif extype.endswith('.txt'):
-        df = conv_txt(extype)
-
-    df['Year'] = pd.to_numeric(df['Year'], errors='coerce')
-    df = df.dropna(subset=['Year'])
-    df['Year'] = df['Year'].astype(int)
-
-    if 'Title' in df.columns and 'Abstract' in df.columns:
-        coldf = ['Abstract', 'Title']
-    elif 'Title' in df.columns:
-        coldf = ['Title']
-    elif 'Abstract' in df.columns:
-        coldf = ['Abstract']
-    else:
-        coldf = sorted(df.select_dtypes(include=['object']).columns.tolist())
-
-    MIN, MAX, GAP = get_minmax(df)
+               'TC': 'Cited by',
+               'PY': 'Year',}
+    papers = pd.read_csv(uploaded_file, sep='\t', lineterminator='\r')
+    papers.rename(columns=col_dict, inplace=True)
+    return papers
 
-    return df, coldf, MIN, MAX, GAP
-
-@st.cache_data(ttl=3600)
-def clean_data(df):
-
-    years = list(range(YEAR[0],YEAR[1]+1))
-    df = df.loc[df['Year'].isin(years)]
-
-    # Preprocess text
-    df['processed'] = df.apply(lambda row: preprocess_text(f"{row.get(col_name, '')}"), axis=1)
-
-    # Vectorize processed text
-    vectorizer = CountVectorizer(lowercase=False, tokenizer=lambda x: x.split())
-    X = vectorizer.fit_transform(df['processed'].tolist())
-
-    # Create DataFrame from the Document-Term Matrix (DTM)
-    dtm = pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out(), index=df['Year'].values)
-    yearly_term_frequency = dtm.groupby(dtm.index).sum()
-
-    # User inputs for top words analysis and exclusions
-    excluded_words = [word.strip() for word in excluded_words_input.split(',')]
-
-    # Identify top words, excluding specified words
-    filtered_words = [word for word in yearly_term_frequency.columns if word not in excluded_words]
-    top_words = yearly_term_frequency[filtered_words].sum().nlargest(top_n).index.tolist()
-
-    return yearly_term_frequency, top_words
-
-@st.cache_data(ttl=3600)
-def apply_burst_detection(top_words, data):
-    all_bursts_list = []
-
-    start_year = int(data.index.min())
-    end_year = int(data.index.max())
-    all_years = range(start_year, end_year + 1)
-
-    continuous_years = pd.Series(index=all_years, data=0) # Start with a series of zeros for all years
-
-    years = continuous_years.index.tolist()
-
-    all_freq_data = pd.DataFrame(index=years)
-
-    for i, word in enumerate(top_words, start=1):
-        # Update with actual counts where available
-        word_counts = data[word].reindex(continuous_years.index, fill_value=0)
-
-        # Convert years and counts to lists for burst detection
-        r = continuous_years.index.tolist() # List of all years
-        r = np.array(r, dtype=int)
-        d = word_counts.values.tolist() # non-zero counts
-        d = np.array(d, dtype=float)
-        y = r.copy()
-
-        if len(r) > 0 and len(d) > 0:
-            n = len(r)
-            q, d, r, p = burst_detection(d, r, n, s=2.0, gamma=1.0, smooth_win=1)
-            bursts = enumerate_bursts(q, word)
-            bursts = burst_weights(bursts, r, d, p)
-            all_bursts_list.append(bursts)
-
-        freq_data = yearly_term_frequency[word].reindex(years, fill_value=0)
-        all_freq_data[word] = freq_data
-
-    all_bursts = pd.concat(all_bursts_list, ignore_index=True)
-
-    num_unique_labels = len(all_bursts['label'].unique())
-
-    num_rows = math.ceil(top_n / 2)
-
-    if running_total == "Running total":
-        all_freq_data = all_freq_data.cumsum()
-
-    return all_bursts, all_freq_data, num_unique_labels, num_rows
-
-@st.cache_data(ttl=3600)
-def convert_df(df):
-    return df.to_csv().encode("utf-8")
-
-@st.cache_data(ttl=3600)
-def scattervis(bursts, freq_data):
-    freq_data.reset_index(inplace=True)
-    freq_data.rename(columns={"index": "Year"}, inplace=True)
-
-    freq_data_melted = freq_data.melt(id_vars=["Year"], var_name="Category", value_name="Value")
-    freq_data_melted = freq_data_melted[freq_data_melted["Value"] > 0]
-    wordlist = freq_data_melted["Category"].unique()
-
-    years = freq_data["Year"].tolist()
-    bursts["begin"] = bursts["begin"].apply(lambda x: years[min(x, len(years) - 1)] if x < len(years) else None)
-    bursts["end"] = bursts["end"].apply(lambda x: years[min(x, len(years) - 1)] if x < len(years) else None)
-    burst_points = []
-
-    for _, row in bursts.iterrows():
-        for year in range(row["begin"], row["end"] + 1):
-            burst_points.append((year, row["label"], row["weight"]))
-
-    burst_points_df = pd.DataFrame(burst_points, columns=["Year", "Category", "Weight"])
-
-    fig = go.Figure()
-
-    # scatter trace for burst points
-    fig.add_trace(go.Scatter(
-        x=burst_points_df["Year"],
-        y=burst_points_df["Category"],
-        mode='markers',
-        marker=dict(
-            symbol='square',
-            size=40,
-            color='red',
-            opacity=0.5),
-        hoverinfo='text',
-        text=burst_points_df["Weight"],
-        showlegend=False
-    ))
-
-    # scatter trace for freq_data
-    fig.add_trace(go.Scatter(
-        x=freq_data_melted["Year"],
-        y=freq_data_melted["Category"],
-        mode='markers+text',
-        marker=dict(
-            symbol='square',
-            size=30,
-            color=freq_data_melted["Value"],
-            colorscale='Blues',
-            showscale=False),
-        text=freq_data_melted["Value"],
-        textposition="middle center",
-        textfont=dict(
-            size=16,
-            color=['white' if value > freq_data_melted["Value"].max()/2 else 'black' for value in freq_data_melted["Value"]])
-    ))
-
-    min_year = min(years)
-    max_year = max(years)
-
-    fig.update_layout(
-        xaxis=dict(tickmode='linear', dtick=1, range=[(min_year-1), (max_year+1)], tickfont = dict(size=16), automargin=True, showgrid=False, zeroline=False),
-        yaxis=dict(tickvals=wordlist, ticktext=wordlist, tickmode='array', tickfont = dict(size=16), automargin=True, showgrid=False, zeroline=False),
-        plot_bgcolor='white',
-        paper_bgcolor='white',
-        showlegend=False,
-        margin=dict(l=1, r=1, t=1, b=1),
-        height=top_n*50+2,
-        width=(max_year-min_year)*52+100,
-        autosize=False
-    )
-
-    fig.write_image("scatter_plot.png")
-    st.image("scatter_plot.png")
-    pio.write_image(fig, 'result.png', scale=4)
-
-@st.cache_data(ttl=3600)
-def linegraph(bursts, freq_data):
-    fig = make_subplots(rows=num_rows, cols=2, subplot_titles=freq_data.columns[:top_n])
-
-    row, col = 1, 1
-    for i, column in enumerate(freq_data.columns[:top_n]):
-        fig.add_trace(go.Scatter(
-            x=freq_data.index, y=freq_data[column], mode='lines+markers+text', name=column,
-            line_shape='linear',
-            hoverinfo='text',
-            hovertext=[f"Year: {index}<br>Frequency: {freq}" for index, freq in zip(freq_data.index, freq_data[column])],
-            text=freq_data[column],
-            textposition='top center'
-        ), row=row, col=col)
-
-        # Add area charts
-        for _, row_data in bursts[bursts['label'] == column].iterrows():
-            x_values = freq_data.index[row_data['begin']:row_data['end']+1]
-            y_values = freq_data[column][row_data['begin']:row_data['end']+1]
-
-            #middle_y = sum(y_values) / len(y_values)
-            y_post = min(freq_data[column]) + 1 if running_total == "Running total" else sum(y_values) / len(y_values)
-            x_offset = 0.1
-
-            # Add area chart
-            fig.add_trace(go.Scatter(
-                x=x_values,
-                y=y_values,
-                fill='tozeroy', mode='lines', fillcolor='rgba(0,100,80,0.2)',
-            ), row=row, col=col)
-
-            align_value = "left" if running_total == "Running total" else "center"
-            valign_value = "bottom" if running_total == "Running total" else "middle"
-
-            # Add annotation for weight at the bottom
-            fig.add_annotation(
-                x=x_values[0] + x_offset,
-                y=y_post,
-                text=f"Weight: {row_data['weight']:.2f}",
-                showarrow=False,
-                font=dict(
-                    color="black",
-                    size=12),
-                align=align_value,
-                valign=valign_value,
-                textangle=270,
-                row=row, col=col
-            )
-
-        col += 1
-        if col > 2:
-            col = 1
-            row += 1
-
-    fig.update_layout(
-        showlegend=False,
-        margin=dict(l=20, r=20, t=100, b=20),
-        height=num_rows * 500,
-        width=1500
-    )
-
-    fig.write_image("line_graph.png")
-    st.image("line_graph.png")
-    pio.write_image(fig, 'result.png', scale=4)
-
-@st.cache_data(ttl=3600)
-def download_result(freq_data, bursts):
-    csv1 = convert_df(freq_data)
-    csv2 = convert_df(bursts)
-    return csv1, csv2
-
+#===Read data===
 uploaded_file = st.file_uploader('', type=['csv', 'txt'], on_change=reset_all)
 
 if uploaded_file is not None:
     try:
-        c1, c2, c3 = st.columns([3,3.5,3.5])
-        top_n = c1.number_input("Number of top words to analyze", min_value=5, value=10, step=1, on_change=reset_all)
-        viz_selected = c2.selectbox("Option for visualization",
-            ("Line graph", "Scatter plot"), on_change=reset_all)
-        running_total = c3.selectbox("Option for counting words",
-            ("Running total", "By occurrences each year"), on_change=reset_all)
-
-        d1, d2 = st.columns([3,7])
-        df, coldf, MIN, MAX, GAP = load_data(uploaded_file)
-        col_name = d1.selectbox("Select column to analyze",
-            (coldf), on_change=reset_all)
-        excluded_words_input = d2.text_input("Words to exclude (comma-separated)", on_change=reset_all)
-
-        if (GAP != 0):
-            YEAR = st.slider('Year', min_value=MIN, max_value=MAX, value=(MIN, MAX), on_change=reset_all)
-        else:
-            e1.write('You only have data in ', (MAX))
-            sys.exit(1)
-
-        yearly_term_frequency, top_words = clean_data(df)
+        extype = get_ext(uploaded_file)
+        if extype.endswith('.csv'):
+            papers = upload(extype)
+
+        elif extype.endswith('.txt'):
+            papers = conv_txt(extype)
 
-        bursts, freq_data, num_unique_labels, num_rows = apply_burst_detection(top_words, yearly_term_frequency)
-
-        tab1, tab2, tab3 = st.tabs(["📈 Generate visualization", "📃 Reference", "📓 Recommended Reading"])
-
-        with tab1:
-            if bursts.empty:
-                st.warning('We cannot detect any bursts', icon='⚠️')
-
+        @st.cache_data(ttl=3600)
+        def get_minmax(extype):
+            extype = extype
+            MIN = int(papers['Year'].min())
+            MAX = int(papers['Year'].max())
+            GAP = MAX - MIN
+            return papers, MIN, MAX, GAP
+
+        tab1, tab2 = st.tabs(["📈 Generate visualization", "📓 Recommended Reading"])
+
+        with tab1:
+            #===sunburst===
+            try:
+                papers, MIN, MAX, GAP = get_minmax(extype)
+            except KeyError:
+                st.error('Error: Please check again your columns.')
+                sys.exit(1)
+
+            if (GAP != 0):
+                YEAR = st.slider('Year', min_value=MIN, max_value=MAX, value=(MIN, MAX), on_change=reset_all)
             else:
-                if num_unique_labels == top_n:
-                    st.info(f'We detect a burst on {num_unique_labels} word(s)', icon="ℹ️")
-                elif num_unique_labels < top_n:
-                    st.info(f'We only detect a burst on {num_unique_labels} word(s), which is {top_n - num_unique_labels} fewer than the top word(s)', icon="ℹ️")
-
-                if viz_selected == "Line graph":
-                    linegraph(bursts, freq_data)
-
-                elif viz_selected =="Scatter plot":
-                    scattervis(bursts, freq_data)
-
-                csv1, csv2 = download_result(freq_data, bursts)
-                e1, e2, e3 = st.columns(3)
-                with open('result.png', "rb") as file:
-                    btn = e1.download_button(
-                        label="📊 Download high resolution image",
-                        data=file,
-                        file_name="burst.png",
-                        mime="image/png")
-
-                e2.download_button(
-                    "👉 Press to download list of top words",
-                    csv1,
-                    "top-keywords.csv",
-                    "text/csv")
+                st.write('You only have data in ', (MAX))
+                YEAR = (MIN, MAX)
+
+            @st.cache_data(ttl=3600)
+            def listyear(extype):
+                global papers
+                years = list(range(YEAR[0],YEAR[1]+1))
+                papers = papers.loc[papers['Year'].isin(years)]
+                return years, papers
+
+            @st.cache_data(ttl=3600)
+            def vis_sunbrust(extype):
+                papers['Cited by'] = papers['Cited by'].fillna(0)
+                vis = pd.DataFrame()
+                vis[['doctype','source','citby','year']] = papers[['Document Type','Source title','Cited by','Year']]
+                viz=vis.groupby(['doctype', 'source', 'year'])['citby'].agg(['sum','count']).reset_index()
+                viz.rename(columns={'sum': 'cited by', 'count': 'total docs'}, inplace=True)
+
+                fig = px.sunburst(viz, path=['doctype', 'source', 'year'], values='total docs',
+                                  color='cited by',
+                                  color_continuous_scale='RdBu',
+                                  color_continuous_midpoint=np.average(viz['cited by'], weights=viz['total docs']))
+                fig.update_layout(height=800, width=1200)
+                return fig
+
+            years, papers = listyear(extype)
 
-                e3.download_button(
-                    "👉 Press to download the list of detected bursts",
-                    csv2,
-                    "burst.csv",
-                    "text/csv")
-
+            if {'Document Type','Source title','Cited by','Year'}.issubset(papers.columns):
+                fig = vis_sunbrust(extype)
+                st.plotly_chart(fig, height=800, width=1200) #use_container_width=True)
+
+            else:
+                st.error('We require these columns: Document Type, Source title, Cited by, Year', icon="🚨")
+
         with tab2:
-            st.markdown('**Kleinberg, J. (2002). Bursty and hierarchical structure in streams. Knowledge Discovery and Data Mining.** https://doi.org/10.1145/775047.775061')
-
-        with tab3:
-            st.markdown('**Li, M., Zheng, Z., & Yi, Q. (2024). The landscape of hot topics and research frontiers in Kawasaki disease: scientometric analysis. Heliyon, 10(8), e29680–e29680.** https://doi.org/10.1016/j.heliyon.2024.e29680')
-            st.markdown('**Domicián Máté, Ni Made Estiyanti and Novotny, A. (2024) 'How to support innovative small firms? Bibliometric analysis and visualization of start-up incubation', Journal of Innovation and Entrepreneurship, 13(1).** https://doi.org/10.1186/s13731-024-00361-z')
-            st.markdown('**Lamba, M., Madhusudhan, M. (2022). Burst Detection. In: Text Mining for Information Professionals. Springer, Cham.** https://doi.org/10.1007/978-3-030-85085-2_6')
-
+            st.markdown('**numpy.average — NumPy v1.24 Manual. (n.d.). Numpy.Average — NumPy v1.24 Manual.** https://numpy.org/doc/stable/reference/generated/numpy.average.html')
+            st.markdown('**Sunburst. (n.d.). Sunburst Charts in Python.** https://plotly.com/python/sunburst-charts/')
+
     except:
        st.error("Please ensure that your file is correct. Please contact us if you find that this is an error.", icon="🚨")
        st.stop()
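The heart of the new page is the `vis_sunbrust()` helper: it aggregates the uploaded records by document type, source title, and year, summing citations and counting documents per group, then hands the result to `px.sunburst`, sizing each segment by document count and coloring it by citations around a document-weighted average midpoint. Below is a minimal, self-contained sketch of that pipeline; the sample DataFrame is hypothetical, not part of the repo, and the chain mirrors the committed function rather than reproducing it verbatim.

```python
import numpy as np
import pandas as pd
import plotly.express as px

# Hypothetical sample rows standing in for an uploaded .csv/.txt export;
# column names match what upload()/conv_txt() produce after renaming.
papers = pd.DataFrame({
    'Document Type': ['Article', 'Article', 'Review', 'Article'],
    'Source title':  ['Journal A', 'Journal A', 'Journal B', 'Journal C'],
    'Cited by':      [10, 3, None, 7],
    'Year':          [2020, 2021, 2021, 2022],
})

# As in vis_sunbrust(): treat missing citation counts as zero, then
# aggregate citations (sum) and document counts (count) per group.
papers['Cited by'] = papers['Cited by'].fillna(0)
viz = (papers
       .rename(columns={'Document Type': 'doctype', 'Source title': 'source',
                        'Cited by': 'citby', 'Year': 'year'})
       .groupby(['doctype', 'source', 'year'])['citby']
       .agg(['sum', 'count'])
       .reset_index()
       .rename(columns={'sum': 'cited by', 'count': 'total docs'}))

# Segment size encodes document counts; color encodes citations, with the
# diverging RdBu scale centered on the document-weighted average.
fig = px.sunburst(viz, path=['doctype', 'source', 'year'], values='total docs',
                  color='cited by',
                  color_continuous_scale='RdBu',
                  color_continuous_midpoint=np.average(viz['cited by'],
                                                       weights=viz['total docs']))
fig.show()
```

Centering the color scale on `np.average(..., weights=...)` rather than a plain mean appears deliberate: heavily represented groups pull the midpoint toward the "typical" citation level, so segments that are over- or under-cited relative to the bulk of the corpus stand out on either side of the diverging scale.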