nfel commited on
Commit
5ee0089
1 Parent(s): c7d9a91

First commit.

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. download.py +3 -0
  3. logo.png +0 -0
  4. requirements.txt +1 -0
  5. run.py +313 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv/
2
+ .idea/
download.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import thermostat
2
+
3
+ data = thermostat.load("imdb-bert-lig", cache_dir="~/datasets")
logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ thermostat-datasets==1.0.1
run.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import glob
3
+ import json
4
+ import pandas as pd
5
+ import streamlit as st
6
+ import sys
7
+ import textwrap
8
+
9
+ from thermostat import load
10
+ from thermostat.data.thermostat_configs import builder_configs
11
+
12
+
13
+ nlp = datasets
14
+
15
+ HTML_WRAPPER = """<div>{}</div>"""
16
+ #HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
17
+ # margin-bottom: 2.5rem">{}</div>"""
18
+ MAX_SIZE = 40000000000
19
+ if len(sys.argv) > 1:
20
+ path_to_datasets = sys.argv[1]
21
+ else:
22
+ path_to_datasets = None
23
+
24
+
25
+ # Hack to extend the width of the main pane.
26
+ def _max_width_():
27
+ max_width_str = f"max-width: 1000px;"
28
+ st.markdown(
29
+ f"""
30
+ <style>
31
+ .reportview-container .main .block-container{{
32
+ {max_width_str}
33
+ }}
34
+ th {{
35
+ text-align: left;
36
+ font-size: 110%;
37
+
38
+
39
+ }}
40
+
41
+ tr:hover {{
42
+ background-color: #ffff99;
43
+ }}
44
+
45
+ </style>
46
+ """,
47
+ unsafe_allow_html=True,
48
+ )
49
+
50
+
51
+ _max_width_()
52
+
53
+
54
+ def render_features(features):
55
+ if isinstance(features, dict):
56
+ return {k: render_features(v) for k, v in features.items()}
57
+ if isinstance(features, nlp.features.ClassLabel):
58
+ return features.names
59
+
60
+ if isinstance(features, nlp.features.Value):
61
+ return features.dtype
62
+
63
+ if isinstance(features, nlp.features.Sequence):
64
+ return {"[]": render_features(features.feature)}
65
+ return features
66
+
67
+
68
+ app_state = st.experimental_get_query_params()
69
+ start = True
70
+ loaded = True
71
+ INITIAL_SELECTION = ""
72
+
73
+ app_state.setdefault("dataset", "glue")
74
+ if len(app_state.get("dataset", [])) == 1:
75
+ app_state["dataset"] = app_state["dataset"][0]
76
+ INITIAL_SELECTION = app_state["dataset"]
77
+ #print(INITIAL_SELECTION)
78
+
79
+ if start:
80
+ # Logo and sidebar decoration.
81
+ st.sidebar.markdown(
82
+ """<center>
83
+ <a href="https://github.com/DFKI-NLP/thermostat">
84
+ </a>
85
+ </center>""",
86
+ unsafe_allow_html=True,
87
+ )
88
+ st.sidebar.image("logo.png", width=300)
89
+ st.sidebar.markdown(
90
+ "<center><h2><a href='https://github.com/DFKI-NLP/thermostat'>github/DFKI-NLP/thermostat</h2></a></center>",
91
+ unsafe_allow_html=True,
92
+ )
93
+ st.sidebar.markdown(
94
+ """
95
+ <center>
96
+ <a target="_blank" href="https://huggingface.co/docs/datasets/">datasets Docs</a>
97
+ </center>""",
98
+ unsafe_allow_html=True,
99
+ )
100
+ st.sidebar.subheader("")
101
+
102
+ # Interaction with the nlp libary.
103
+ # @st.cache
104
+ def get_confs():
105
+ """ Get the list of confs for a dataset. """
106
+ confs = builder_configs
107
+ if confs and len(confs) > 1:
108
+ return confs
109
+ else:
110
+ return []
111
+
112
+ # @st.cache(allow_output_mutation=True)
113
+ def get(conf):
114
+ """ Get a dataset from name and conf """
115
+ ds = load(conf, cache_dir=path_to_datasets)
116
+ return ds, False
117
+
118
+ # Dataset select box.
119
+ datasets = []
120
+ selection = None
121
+
122
+ if path_to_datasets is None:
123
+ list_of_datasets = nlp.list_datasets(with_community_datasets=False)
124
+ else:
125
+ list_of_datasets = sorted(glob.glob(path_to_datasets + "*"))
126
+ for i, dataset in enumerate(list_of_datasets):
127
+ dataset = dataset.split("/")[-1]
128
+ if INITIAL_SELECTION and dataset == INITIAL_SELECTION:
129
+ selection = i
130
+ datasets.append(dataset)
131
+
132
+ st.experimental_set_query_params(**app_state)
133
+
134
+ # Side bar Configurations.
135
+ configs = get_confs()
136
+ conf_avail = len(configs) > 0
137
+ conf_option = None
138
+ if conf_avail:
139
+ start = 0
140
+ for i, conf in enumerate(configs):
141
+ if conf.name == app_state.get("config", None):
142
+ start = i
143
+ conf_option = st.sidebar.selectbox(
144
+ "Thermostat configuration", configs, index=start, format_func=lambda a: a.name
145
+ )
146
+ app_state["config"] = conf_option.name
147
+
148
+ else:
149
+ if "config" in app_state:
150
+ del app_state["config"]
151
+ st.experimental_set_query_params(**app_state)
152
+
153
+ dts, fail = get(str(conf_option.name) if conf_option else None)
154
+
155
+ # Main panel setup.
156
+ if fail:
157
+ st.markdown(
158
+ "Dataset is too large to browse or requires manual download. Check it out in the datasets library! \n\n "
159
+ "Size: "
160
+ + str(dts.info.size_in_bytes)
161
+ + "\n\n Instructions: "
162
+ + str(dts.manual_download_instructions)
163
+ )
164
+ else:
165
+ d = dts
166
+ keys = list(d[0].__dict__.keys())
167
+
168
+ st.header(
169
+ "Thermostat configuration: "
170
+ + (conf_option.name if conf_option else "")
171
+ )
172
+
173
+ st.markdown(
174
+ "*Homepage*: "
175
+ + d.info.homepage
176
+ )
177
+
178
+ md = """
179
+ %s
180
+ """ % (
181
+ d.info.description.replace("\\", " ")
182
+ )
183
+ st.markdown(md)
184
+
185
+ step = 50
186
+ offset = st.sidebar.number_input(
187
+ "Offset (Size: %d)" % len(d),
188
+ min_value=0,
189
+ max_value=int(len(d)) - step,
190
+ value=0,
191
+ step=step,
192
+ )
193
+
194
+ citation = None #st.sidebar.checkbox("Show Citations", False)
195
+ table = not st.sidebar.checkbox("Show List View", False)
196
+ show_features = st.sidebar.checkbox("Show Features", True)
197
+ show_atts = st.sidebar.checkbox("Show Attribution Scores", False)
198
+ md = """
199
+ ```
200
+ %s
201
+ ```
202
+ """ % (
203
+ d.info.citation.replace("\\", "").replace("}", " }").replace("{", "{ "),
204
+ )
205
+ if citation:
206
+ st.markdown(md)
207
+ # st.text("Features:")
208
+ #if show_features:
209
+ # on_keys = st.multiselect("Features", keys, keys)
210
+ # #st.write(render_features(d.features))
211
+ #else:
212
+ on_keys = keys
213
+
214
+ # Remove some keys
215
+ on_keys = [k for k in on_keys if k in ['predictions', 'true_label', 'predicted_label']]
216
+
217
+ if not table:
218
+ # Full view.
219
+ for item in range(offset, offset + step):
220
+ st.text(" ")
221
+ st.text(" ---- #" + str(item))
222
+ st.text(" ")
223
+ # Use st to write out.
224
+ for k in on_keys:
225
+ v = getattr(d[item], k)
226
+ st.subheader(k)
227
+ if isinstance(v, str):
228
+ out = v
229
+ st.text(textwrap.fill(out, width=120))
230
+ elif (
231
+ isinstance(v, bool)
232
+ or isinstance(v, int)
233
+ or isinstance(v, float)
234
+ ):
235
+ st.text(v)
236
+ else:
237
+ st.write(v)
238
+
239
+ else:
240
+ # Table view. Use Pandas.
241
+ df, heatmap_htmls = [], []
242
+ for item in range(offset, offset + step):
243
+ df_item = {}
244
+ df_item["_number"] = item
245
+ for k in on_keys:
246
+ v = getattr(d[item], k)
247
+
248
+ # Remove [PAD] tokens from attributions and input_ids
249
+ if k in ['attributions', 'input_ids']:
250
+ v = [vi for vi in v if vi != 0 or vi != 0.0]
251
+
252
+ if isinstance(v, str):
253
+ out = v
254
+ df_item[k] = textwrap.fill(out, width=50)
255
+ elif (
256
+ isinstance(v, bool)
257
+ or isinstance(v, int)
258
+ or isinstance(v, float)
259
+ ):
260
+ df_item[k] = v
261
+ else:
262
+ out = json.dumps(v, indent=2, sort_keys=True)
263
+ df_item[k] = out
264
+
265
+ # Add heatmap viz
266
+ html = getattr(d[item], 'heatmap').render(labels=show_atts)
267
+ html = html.replace("\n", " ")
268
+ heatmap_htmls.append(HTML_WRAPPER.format(html))
269
+
270
+ df.append(df_item)
271
+ df2 = df
272
+ df = pd.DataFrame(df).set_index("_number")
273
+
274
+ def hover(hover_color="#ffff99"):
275
+ return dict(
276
+ selector="tr:hover",
277
+ props=[("background-color", "%s" % hover_color)],
278
+ )
279
+
280
+ styles = [
281
+ hover(),
282
+ dict(
283
+ selector="th",
284
+ props=[("font-size", "150%"), ("text-align", "center")],
285
+ ),
286
+ dict(selector="caption", props=[("caption-side", "bottom")]),
287
+ ]
288
+ # Table view. Use pands styling.
289
+ style = df.style.set_properties(
290
+ **{"text-align": "left", "white-space": "pre"}
291
+ ).set_table_styles([dict(selector="th", props=[("text-align", "left")])])
292
+ style = style.set_table_styles(styles) # Setting the style appears to be broken for streamlit+pandas
293
+
294
+ for i, heatmap_html in enumerate(heatmap_htmls):
295
+ st.write(HTML_WRAPPER.format(heatmap_html), unsafe_allow_html=True)
296
+ st.table(df.iloc[[i]])
297
+ st.markdown(""" --- """)
298
+
299
+ # Additional dataset installation and sidebar properties.
300
+ md = """
301
+ ### Code
302
+
303
+ ```python
304
+ !pip install thermostat_datasets
305
+ from thermostat import load
306
+ dataset = load(
307
+ '%s)
308
+ ```
309
+
310
+ """ % (
311
+ (conf_option.name + "'") if conf_option else "",
312
+ )
313
+ st.sidebar.markdown(md)