ludusc committed on
Commit
70dfa79
1 Parent(s): 83d8189

view performance, cleaning up

.gitignore CHANGED
@@ -181,4 +181,6 @@ dmypy.json
 .pytype/
 
 # Cython debug symbols
-cython_debug/
+cython_debug/
+
+data/images/
backend/disentangle_concepts.py CHANGED
@@ -7,6 +7,21 @@ from umap import UMAP
 import PIL
 
 def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
+    """
+    The get_separation_space function takes in a type_bin, annotations, and df.
+    It samples the `samples` most representative and the `samples` least representative abstracts for that type_bin (100 each by default).
+    It then trains an SVM or logistic regression classifier on these samples to find a separating hyperplane between the two groups.
+    The function returns this separation vector along with how many nodes are important to the separation.
+
+    :param type_bin: Select the type of abstracts to be used for training
+    :param annotations: Access the z_vectors
+    :param df: Get the abstracts that are used for training
+    :param samples: Determine how many samples to take from the top and bottom of the distribution
+    :param method: Specify the classifier to use ('SVM' or 'LR')
+    :param C: Control the regularization strength
+    :return: The weights of the linear classifier, the number and indices of important nodes, and the rounded validation accuracy
+    :doc-author: Trelent
+    """
     abstracts = np.array([float(ann) for ann in df[type_bin]])
     abstract_idxs = list(np.argsort(abstracts))[:samples]
     repr_idxs = list(np.argsort(abstracts))[-samples:]
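Taken together, the docstring and the surrounding hunks describe the whole routine: sort by annotation score, take the two extremes, fit a linear classifier, and threshold its coefficients. A self-contained sketch of that pipeline on random stand-in data (the scores, latents, and train/validation split below are illustrative; only the thresholds and classifier settings come from the diff):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Toy stand-ins for df[type_bin] scores and 512-d latent vectors.
rng = np.random.default_rng(0)
scores = rng.random(1000)
latents = rng.normal(size=(1000, 512))

samples = 100
order = np.argsort(scores)
idxs = list(order)[:samples] + list(order)[-samples:]  # least + most representative
X = latents[idxs]
y = np.array([0] * samples + [1] * samples)            # 0 = least, 1 = most

x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)
clf = LogisticRegression(random_state=0, C=0.1)
clf.fit(x_train, y_train)

imp_nodes = np.where(np.abs(clf.coef_) > 0.15)[1]      # latent dims with large weights
vec = clf.coef_ / np.linalg.norm(clf.coef_)            # unit-norm separation vector
print(len(imp_nodes), np.round(clf.score(x_val, y_val), 2))
```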
 
@@ -20,17 +35,32 @@ def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
         print('Val performance SVM', svc.score(x_val, y_val))
         imp_features = (np.abs(svc.coef_) > 0.2).sum()
         imp_nodes = np.where(np.abs(svc.coef_) > 0.2)[1]
-        return svc.coef_, imp_features, imp_nodes
+        return svc.coef_, imp_features, imp_nodes, np.round(svc.score(x_val, y_val), 2)
     elif method == 'LR':
         clf = LogisticRegression(random_state=0, C=C)
         clf.fit(x_train, y_train)
         print('Val performance logistic regression', clf.score(x_val, y_val))
         imp_features = (np.abs(clf.coef_) > 0.15).sum()
         imp_nodes = np.where(np.abs(clf.coef_) > 0.15)[1]
-        return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes
+        return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val), 2)
 
 
 def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
+    """
+    The regenerate_images function takes a model, z, and decision_boundary as input. It
+    constructs an inverse rotation/translation matrix and passes it to the generator. The generator
+    expects this matrix as an inverse to avoid potentially failing numerical operations in the network.
+    The function then generates images using G(z_0, label), where z_0 is a linear combination of z and the decision boundary.
+
+    :param model: Pass in the model to be used for image generation
+    :param z: Set the starting point of the traversal line
+    :param decision_boundary: Give the direction along which images are generated
+    :param min_epsilon: Set the minimum value of lambda
+    :param max_epsilon: Set the maximum distance from the original image to generate
+    :param count: Determine the number of images that are generated
+    :return: A list of images and a list of lambdas
+    :doc-author: Trelent
+    """
     device = torch.device('cpu')
     G = model.to(device)  # type: ignore
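The "linear combination of z and the decision boundary" in this docstring is a standard latent traversal: sweep a scalar lambda over [min_epsilon, max_epsilon] and offset z along the separating direction. A minimal sketch, assuming `z` and `decision_boundary` are (1, 512) arrays (the names and shapes are assumptions, not shown in the hunk):

```python
import numpy as np

def traversal_points(z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
    """Return the latent codes z_0 = z + lambda * direction for each lambda."""
    lambdas = np.linspace(min_epsilon, max_epsilon, count)
    # Each z_0 moves the latent code along the separating direction.
    return [z + lam * decision_boundary for lam in lambdas], list(lambdas)

z = np.random.randn(1, 512)
direction = np.random.randn(1, 512)
points, lambdas = traversal_points(z, direction)
print(lambdas)  # five evenly spaced lambdas from -3 to 3
```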
 
 
@@ -62,6 +92,16 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
     return images, lambdas
 
 def generate_original_image(z, model):
+    """
+    The generate_original_image function takes in a latent vector and the model,
+    and returns an image generated from that latent vector.
+
+
+    :param z: The latent vector to generate from
+    :param model: The generator network to use
+    :return: A PIL image
+    :doc-author: Trelent
+    """
     device = torch.device('cpu')
     G = model.to(device)  # type: ignore
     # Labels.
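The ":return: A PIL image" line glosses over the tensor-to-image conversion. A sketch of the usual conversion, assuming a StyleGAN-style generator whose output is an NCHW float tensor in [-1, 1] (an assumption about the model, not code from the commit):

```python
import PIL.Image
import torch

def to_pil(img: torch.Tensor) -> PIL.Image.Image:
    """Map a (1, C, H, W) float tensor in [-1, 1] to a PIL RGB image."""
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

# Example with a dummy tensor standing in for G(z, label).
dummy = torch.rand(1, 3, 64, 64) * 2 - 1
print(to_pil(dummy).size)  # (64, 64)
```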
 
@@ -73,11 +113,28 @@ def generate_original_image(z, model):
 
 
 def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
+    """
+    The get_concepts_vectors function takes in a list of concepts, a dictionary of annotations, and the dataframe containing all the images.
+    It returns three things:
+    1) A numpy array with shape (len(concepts), 512), where each row is the separation vector for one concept.
+    2) A set containing the nodes that are important in every concept's separation space, and 3) a list with the validation performance of each concept classifier.
+
+    :param concepts: Specify the concepts to be used in the analysis
+    :param annotations: Get the annotations for each concept
+    :param df: Get the image annotations used for training
+    :param samples: Determine the number of samples to use in training the classifier
+    :param method: Choose the method used to train the model
+    :param C: Control the regularization of the logistic regression
+    :return: The vectors of the concepts, the nodes in common for all concepts, and the per-concept validation performances
+    :doc-author: Trelent
+    """
     important_nodes = []
+    performances = []
     vectors = np.zeros((len(concepts), 512))
     for i, conc in enumerate(concepts):
-        vec, _, imp_nodes = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C)
+        vec, _, imp_nodes, performance = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C)
         vectors[i,:] = vec
+        performances.append(performance)
         important_nodes.append(set(imp_nodes))
 
     # reducer = UMAP(n_neighbors=3, # default 15, The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation.
 
@@ -89,5 +146,5 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
 
     # projection = reducer.fit_transform(vectors)
     nodes_in_common = set.intersection(*important_nodes)
-    return vectors, nodes_in_common
+    return vectors, nodes_in_common, performances
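These concept vectors feed the similarity graph in pages/2_Concepts_comparison.py, whose edge tooltips (see the deleted tmp/nx.html below) are labelled "similarity". One plausible reading of those edge weights is cosine similarity between rows of `vectors`; the sketch below illustrates that assumption and is not code from the repository:

```python
import numpy as np

def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine of the angle between two concept vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

vectors = np.random.randn(3, 512)  # stand-in for the (len(concepts), 512) array
for i in range(len(vectors)):
    for j in range(i + 1, len(vectors)):
        print(i, j, round(cosine_similarity(vectors[i], vectors[j]), 3))
```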
 
pages/1_Disentanglement.py CHANGED
@@ -101,10 +101,10 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5
 
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    separation_vector, number_important_features, imp_nodes = get_separation_space(concept_id, annotations, ann_df)
+    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_id, annotations, ann_df)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
     st.write('Concept vector', separation_vector)
-    header_col_1.write(f'Concept {concept_id} - Number of relevant nodes: {number_important_features}')# - Nodes {",".join(list(imp_nodes))}')
+    header_col_1.write(f'Concept {concept_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
 
 # ----------------------------- INPUT column 2 & 3 ----------------------------
 with input_col_2:
pages/2_Concepts_comparison.py CHANGED
@@ -91,10 +91,10 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5
 
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    vectors, nodes_in_common = get_concepts_vectors(concept_ids, annotations, ann_df)
+    vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
     #st.write('Concept vector', separation_vector)
-    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Relevant nodes in common: {nodes_in_common}')# - Nodes {",".join(list(imp_nodes))}')
+    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
 
 edges = []
 for i in range(len(concept_ids)):
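The `edges` loop this hunk ends on pairs every two concepts, and the tmp/nx.html file deleted below has exactly the structure pyvis emits when rendering such a graph. A hedged sketch of that rendering step, using node names and similarity values taken from the deleted file (the pyvis usage itself is an inference, not shown in this diff):

```python
from pyvis.network import Network

concept_ids = ['Op Art', 'Minimalism', 'Surrealism']            # illustrative subset
similarities = {(0, 1): 0.432, (0, 2): -0.086, (1, 2): -0.042}  # values from the deleted file

net = Network(height='750px', width='100%')
for c in concept_ids:
    net.add_node(c, label=c, title=c)
for (i, j), sim in similarities.items():
    a, b = concept_ids[i], concept_ids[j]
    # Edge tooltips mirror the "A to B similarity v" titles in tmp/nx.html.
    net.add_edge(a, b, value=sim, title=f'{a} to {b} similarity {sim}')
net.save_graph('tmp/nx.html')  # the tmp/ directory must already exist
```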
tmp/nx.html DELETED
@@ -1,155 +0,0 @@
-<html>
-    <head>
-        <meta charset="utf-8">
-
-        <script src="lib/bindings/utils.js"></script>
-        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/dist/vis-network.min.css" integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA==" crossorigin="anonymous" referrerpolicy="no-referrer" />
-        <script src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/vis-network.min.js" integrity="sha512-LnvoEWDFrqGHlHmDD2101OrLcbsfkrzoSpvtSQtxK3RMnRV0eOkhhBN2dXHKRrUU8p2DGRTk35n4O8nWSVe1mQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
-
-
-        <center>
-            <h1></h1>
-        </center>
-
-        <!-- <link rel="stylesheet" href="../node_modules/vis/dist/vis.min.css" type="text/css" />
-        <script type="text/javascript" src="../node_modules/vis/dist/vis.js"> </script>-->
-        <link
-            href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css"
-            rel="stylesheet"
-            integrity="sha384-eOJMYsd53ii+scO/bJGFsiCZc+5NDVN2yr8+0RDqr0Ql0h+rP48ckxlpbzKgwra6"
-            crossorigin="anonymous"
-        />
-        <script
-            src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
-            integrity="sha384-JEW9xMcG8R+pH31jmWH6WWP0WintQrMb4s7ZOdauHnUtxwoG2vI5DkLtS3qm9Ekf"
-            crossorigin="anonymous"
-        ></script>
-
-
-        <center>
-            <h1></h1>
-        </center>
-        <style type="text/css">
-
-            #mynetwork {
-                width: 100%;
-                height: 750px;
-                background-color: #ffffff;
-                border: 1px solid lightgray;
-                position: relative;
-                float: left;
-            }
-
-
-
-
-
-
-        </style>
-    </head>
-
-
-    <body>
-        <div class="card" style="width: 100%">
-
-
-            <div id="mynetwork" class="card-body"></div>
-        </div>
-
-
-
-
-        <script type="text/javascript">
-
-            // initialize global variables.
-            var edges;
-            var nodes;
-            var allNodes;
-            var allEdges;
-            var nodeColors;
-            var originalNodes;
-            var network;
-            var container;
-            var options, data;
-            var filter = {
-                item : '',
-                property : '',
-                value : []
-            };
-
-
-
-
-
-            // This method is responsible for drawing the graph, returns the drawn network
-            function drawGraph() {
-                var container = document.getElementById('mynetwork');
-
-
-
-                // parsing and collecting nodes and edges from the python
-                nodes = new vis.DataSet([{"color": "#97c2fc", "id": "Op Art", "label": "Op Art", "shape": "dot", "title": "Op Art"}, {"color": "#97c2fc", "id": "Minimalism", "label": "Minimalism", "shape": "dot", "title": "Minimalism"}, {"color": "#97c2fc", "id": "Surrealism", "label": "Surrealism", "shape": "dot", "title": "Surrealism"}, {"color": "#97c2fc", "id": "Baroque", "label": "Baroque", "shape": "dot", "title": "Baroque"}, {"color": "#97c2fc", "id": "Lithography", "label": "Lithography", "shape": "dot", "title": "Lithography"}, {"color": "#97c2fc", "id": "Woodcut", "label": "Woodcut", "shape": "dot", "title": "Woodcut"}, {"color": "#97c2fc", "id": "etching", "label": "etching", "shape": "dot", "title": "etching"}, {"color": "#97c2fc", "id": "Intaglio", "label": "Intaglio", "shape": "dot", "title": "Intaglio"}]);
-                edges = new vis.DataSet([{"from": "Op Art", "title": "Op Art to Minimalism similarity 0.432", "to": "Minimalism", "value": 0.432}, {"from": "Op Art", "title": "Op Art to Surrealism similarity -0.086", "to": "Surrealism", "value": -0.086}, {"from": "Op Art", "title": "Op Art to Baroque similarity -0.047", "to": "Baroque", "value": -0.047}, {"from": "Op Art", "title": "Op Art to Lithography similarity 0.054", "to": "Lithography", "value": 0.054}, {"from": "Op Art", "title": "Op Art to Woodcut similarity 0.125", "to": "Woodcut", "value": 0.125}, {"from": "Op Art", "title": "Op Art to etching similarity 0.117", "to": "etching", "value": 0.117}, {"from": "Op Art", "title": "Op Art to Intaglio similarity 0.094", "to": "Intaglio", "value": 0.094}, {"from": "Minimalism", "title": "Minimalism to Surrealism similarity -0.042", "to": "Surrealism", "value": -0.042}, {"from": "Minimalism", "title": "Minimalism to Baroque similarity -0.052", "to": "Baroque", "value": -0.052}, {"from": "Minimalism", "title": "Minimalism to Lithography similarity 0.046", "to": "Lithography", "value": 0.046}, {"from": "Minimalism", "title": "Minimalism to Woodcut similarity 0.069", "to": "Woodcut", "value": 0.069}, {"from": "Minimalism", "title": "Minimalism to etching similarity 0.1", "to": "etching", "value": 0.1}, {"from": "Minimalism", "title": "Minimalism to Intaglio similarity 0.03", "to": "Intaglio", "value": 0.03}, {"from": "Surrealism", "title": "Surrealism to Baroque similarity 0.067", "to": "Baroque", "value": 0.067}, {"from": "Surrealism", "title": "Surrealism to Lithography similarity -0.235", "to": "Lithography", "value": -0.235}, {"from": "Surrealism", "title": "Surrealism to Woodcut similarity -0.16", "to": "Woodcut", "value": -0.16}, {"from": "Surrealism", "title": "Surrealism to etching similarity -0.171", "to": "etching", "value": -0.171}, {"from": "Surrealism", "title": "Surrealism to Intaglio similarity -0.076", "to": "Intaglio", "value": -0.076}, {"from": "Baroque", "title": "Baroque to Lithography similarity -0.125", "to": "Lithography", "value": -0.125}, {"from": "Baroque", "title": "Baroque to Woodcut similarity -0.022", "to": "Woodcut", "value": -0.022}, {"from": "Baroque", "title": "Baroque to etching similarity -0.102", "to": "etching", "value": -0.102}, {"from": "Baroque", "title": "Baroque to Intaglio similarity -0.046", "to": "Intaglio", "value": -0.046}, {"from": "Lithography", "title": "Lithography to Woodcut similarity 0.258", "to": "Woodcut", "value": 0.258}, {"from": "Lithography", "title": "Lithography to etching similarity 0.268", "to": "etching", "value": 0.268}, {"from": "Lithography", "title": "Lithography to Intaglio similarity 0.123", "to": "Intaglio", "value": 0.123}, {"from": "Woodcut", "title": "Woodcut to etching similarity 0.21", "to": "etching", "value": 0.21}, {"from": "Woodcut", "title": "Woodcut to Intaglio similarity 0.209", "to": "Intaglio", "value": 0.209}, {"from": "etching", "title": "etching to Intaglio similarity 0.178", "to": "Intaglio", "value": 0.178}]);
-
-                nodeColors = {};
-                allNodes = nodes.get({ returnType: "Object" });
-                for (nodeId in allNodes) {
-                    nodeColors[nodeId] = allNodes[nodeId].color;
-                }
-                allEdges = edges.get({ returnType: "Object" });
-                // adding nodes and edges to the graph
-                data = {nodes: nodes, edges: edges};
-
-                var options = {
-                    "configure": {
-                        "enabled": false
-                    },
-                    "edges": {
-                        "color": {
-                            "inherit": true
-                        },
-                        "smooth": {
-                            "enabled": true,
-                            "type": "dynamic"
-                        }
-                    },
-                    "interaction": {
-                        "dragNodes": true,
-                        "hideEdgesOnDrag": false,
-                        "hideNodesOnDrag": false
-                    },
-                    "physics": {
-                        "enabled": true,
-                        "stabilization": {
-                            "enabled": true,
-                            "fit": true,
-                            "iterations": 1000,
-                            "onlyDynamicEdges": false,
-                            "updateInterval": 50
-                        }
-                    }
-                };
-
-
-
-
-
-
-                network = new vis.Network(container, data, options);
-
-
-
-
-
-
-
-
-
-
-                return network;
-
-            }
-            drawGraph();
-        </script>
-    </body>
-</html>
view_predictions.ipynb CHANGED
The diff for this file is too large to render. See raw diff