ludusc committed
Commit
944e312
1 Parent(s): 5b6e083

textile disentanglement

backend/disentangle_concepts.py CHANGED
@@ -12,67 +12,7 @@ from PIL import Image, ImageColor
12
  from .color_annotations import extract_color
13
 
14
 
15
-
16
- def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=0.1, latent_space='Z'):
17
- """
18
- The get_separation_space function takes in a type_bin, annotations, and df.
19
- It then selects the samples highest-scoring and the samples lowest-scoring images for that type_bin.
20
- It then trains an SVM or logistic regression classifier on these 2*samples examples to find a separating direction between them.
21
- The function returns this separation space as well as how many nodes are important in this separation space.
22
-
23
- :param type_bin: Concept column (or pair of columns) of df used to define the two classes
24
- :param annotations: Dictionary holding the z_vectors / w_vectors of the images
25
- :param df: DataFrame with the per-image concept scores used for training
26
- :param samples: Determine how many samples to take from the top and bottom of the distribution
27
- :param method: Specify the classifier to use
28
- :param C: Control the regularization strength
29
- :return: The weights of the linear classifier
30
- :doc-author: Trelent
31
- """
32
-
33
- if latent_space == 'Z':
34
- col = 'z_vectors'
35
- else:
36
- col = 'w_vectors'
37
-
38
- if len(type_bin) == 1:
39
- type_bin = type_bin[0]
40
- if type(type_bin) == str:
41
- abstracts = np.array([float(ann) for ann in df[type_bin]])
42
- abstract_idxs = list(np.argsort(abstracts))[:samples]
43
- repr_idxs = list(np.argsort(abstracts))[-samples:]
44
- X = np.array([annotations[col][i] for i in abstract_idxs+repr_idxs])
45
- elif len(type_bin) == 2:
46
- print('Using two concepts for separation space')
47
- first_concept = np.array([float(ann) for ann in df[type_bin[0]]])
48
- second_concept = np.array([float(ann) for ann in df[type_bin[1]]])
49
- first_idxs = list(np.argsort(first_concept))[:samples]
50
- second_idxs = list(np.argsort(second_concept))[:samples]
51
- X = np.array([annotations[col][i] for i in first_idxs+second_idxs])
52
- else:
53
- print('Error: type_bin must be either a string or a list of strings of len 2')
54
- return
55
-
56
- X = X.reshape((2*samples, 512))
57
- y = np.array([1]*samples + [0]*samples)
58
- x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
59
- if method == 'SVM':
60
- svc = SVC(gamma='auto', kernel='linear', random_state=0, C=C)
61
- svc.fit(x_train, y_train)
62
- print('Val performance SVM', svc.score(x_val, y_val))
63
- imp_features = (np.abs(svc.coef_) > 0.2).sum()
64
- imp_nodes = np.where(np.abs(svc.coef_) > 0.2)[1]
65
- return svc.coef_ / np.linalg.norm(svc.coef_), imp_features, imp_nodes, np.round(svc.score(x_val, y_val),2)
66
- elif method == 'LR':
67
- clf = LogisticRegression(random_state=0, C=C)
68
- clf.fit(x_train, y_train)
69
- print('Val performance logistic regression', clf.score(x_val, y_val))
70
- imp_features = (np.abs(clf.coef_) > 0.15).sum()
71
- imp_nodes = np.where(np.abs(clf.coef_) > 0.15)[1]
72
- return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val),2)
73
-
74
-
75
- def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z', layers=None, number=3):
76
  """
77
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
78
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -91,46 +31,39 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
91
  device = torch.device('cpu')
92
  G = model.to(device) # type: ignore
93
 
94
- if False:
95
- decision_boundary = z - (np.dot(z, decision_boundary.T) / np.dot(decision_boundary, decision_boundary.T)) * decision_boundary
96
  # Labels.
97
  label = torch.zeros([1, G.c_dim], device=device)
98
 
99
  z = torch.from_numpy(z.copy()).to(device)
100
- decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
101
 
102
- repetitions = 16 if number == 3 else 14
103
- lambdas = np.linspace(min_epsilon, max_epsilon, count)
104
- images = []
105
- # Generate images.
106
- for _, lambda_ in enumerate(tqdm(lambdas)):
107
- z_0 = z + lambda_ * decision_boundary
108
- if latent_space == 'Z':
109
- W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
110
- W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
111
- else:
112
- W_0 = z_0.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
113
- W = z.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
114
 
115
- if layers:
116
- W_f = torch.empty_like(W).copy_(W).to(torch.float32)
117
- W_f[:, layers, :] = W_0[:, layers, :]
118
- img = G.synthesis(W_f, noise_mode='const')
119
- else:
120
- img = G.synthesis(W_0, noise_mode='const')
121
 
122
- img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
123
- images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
124
 
125
- return images, lambdas
126
-
127
-
128
- def generate_joint_effect(model, z, decision_boundaries, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z'):
129
- decision_boundary_joint = np.sum(decision_boundaries, axis=0)
130
- print(decision_boundary_joint.shape)
131
- return regenerate_images(model, z, decision_boundary_joint, min_epsilon=min_epsilon, max_epsilon=max_epsilon, count=count, latent_space=latent_space)
132
 
133
- def generate_original_image(z, model, latent_space='Z', number=3):
 
134
  """
135
  The generate_original_image function takes in a latent vector and the model,
136
  and returns an image generated from that latent vector.
@@ -141,7 +74,7 @@ def generate_original_image(z, model, latent_space='Z', number=3):
141
  :return: A pil image
142
  :doc-author: Trelent
143
  """
144
- repetitions = 16 if number == 3 else 14
145
 
146
  device = torch.device('cpu')
147
  G = model.to(device) # type: ignore
@@ -152,304 +85,8 @@ def generate_original_image(z, model, latent_space='Z', number=3):
152
  img = G(z, label, truncation_psi=1, noise_mode='const')
153
  else:
154
  W = torch.from_numpy(np.repeat(z, repetitions, axis=0).reshape(1, repetitions, z.shape[1]).copy()).to(device)
155
- print(W.shape)
156
  img = G.synthesis(W, noise_mode='const')
157
 
158
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
159
  return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
160
 
161
-
162
- def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1, latent_space='Z'):
163
- """
164
- The get_concepts_vectors function takes in a list of concepts, a dictionary of annotations, and the dataframe containing all the images.
165
- It returns two things:
166
- 1) A numpy array with shape (len(concepts), 512) where each row is an embedding vector for one concept.
167
- 2) A set containing all nodes that are important in this separation space.
168
-
169
- :param concepts: Specify the concepts to be used in the analysis
170
- :param annotations: Get the annotations for each concept
171
- :param df: Get the annotations for each concept
172
- :param samples: Determine the number of samples to use in training the logistic regression model
173
- :param method: Choose the method used to train the model
174
- :param C: Control the regularization of the logistic regression
175
- :return: The vectors of the concepts and the nodes that are in common for all concepts
176
- :doc-author: Trelent
177
- """
178
- important_nodes = []
179
- performances = []
180
- vectors = np.zeros((len(concepts), 512))
181
- for i, conc in enumerate(concepts):
182
- vec, _, imp_nodes, performance = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C, latent_space=latent_space)
183
- vectors[i,:] = vec
184
- performances.append(performance)
185
- important_nodes.append(set(imp_nodes))
186
-
187
- # reducer = UMAP(n_neighbors=3, # default 15, The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation.
188
- # n_components=3, # default 2, The dimension of the space to embed into.
189
- # min_dist=0.1, # default 0.1, The effective minimum distance between embedded points.
190
- # spread=2.0, # default 1.0, The effective scale of embedded points. In combination with ``min_dist`` this determines how clustered/clumped the embedded points are.
191
- # random_state=0, # default: None, If int, random_state is the seed used by the random number generator;
192
- # )
193
-
194
- # projection = reducer.fit_transform(vectors)
195
- nodes_in_common = set.intersection(*important_nodes)
196
- return vectors, nodes_in_common, performances
197
-
198
-
199
- def get_verification_score(color_id, decision_boundary, model, annotations, samples=100, latent_space='W'):
200
- listlen = len(annotations['fname'])
201
- items = random.sample(range(listlen), samples)
202
- hue_low = color_id * 256 / 12
203
- hue_high = (color_id + 1) * 256 / 12
204
- hue_mean = (hue_low + hue_high) / 2
205
- print(int(hue_low), int(hue_high), int(hue_mean))
206
- distances = []
207
- distances_orig = []
208
- for iterator in tqdm(items):
209
- if latent_space == 'Z':
210
- z = annotations['z_vectors'][iterator]
211
- else:
212
- z = annotations['w_vectors'][iterator]
213
-
214
- images, lambdas = regenerate_images(model, z, decision_boundary, min_epsilon=0, max_epsilon=1, count=2, latent_space=latent_space)
215
- colors_orig = extract_color(images[0], 5, 1, None)
216
- h_old, s_old, v_old = ImageColor.getcolor(colors_orig[0], 'HSV')
217
- colors_new = extract_color(images[1], 5, 1, None)
218
- h_new, s_new, v_new = ImageColor.getcolor(colors_new[0], 'HSV')
219
- print(h_old, h_new)
220
- distance = np.abs(hue_mean - h_new)
221
- distances.append(distance)
222
- distance_orig = np.abs(hue_mean - h_old)
223
- distances_orig.append(distance_orig)
224
-
225
- return np.round(np.mean(np.array(distances)), 4), np.round(np.mean(np.array(distances_orig)), 4)
226
-
227
-
228
- def get_verification_score_clip(concept, decision_boundary, model, annotations, samples=100, latent_space='Z'):
229
- import open_clip
230
- import os
231
- import random
232
- from tqdm import tqdm
233
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
234
-
235
-
236
- model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
237
- tokenizer = open_clip.get_tokenizer('ViT-L-14')
238
-
239
- # Prepare the text queries
240
- #@markdown _in the form pre_prompt {label}_:
241
- pre_prompt = "Artwork, " #@param {type:"string"}
242
- text_descriptions = [f"{pre_prompt}{label}" for label in [concept]]
243
- text_tokens = tokenizer(text_descriptions)
244
-
245
-
246
- listlen = len(annotations['fname'])
247
- items = random.sample(range(listlen), samples)
248
- changes = []
249
- for iterator in tqdm(items):
250
- chunk_imgs = []
251
- chunk_ids = []
252
-
253
- if latent_space == 'Z':
254
- z = annotations['z_vectors'][iterator]
255
- else:
256
- z = annotations['w_vectors'][iterator]
257
- images, lambdas = regenerate_images(model, z, decision_boundary, min_epsilon=0, max_epsilon=1, count=2, latent_space=latent_space)
258
- for im,l in zip(images, lambdas):
259
-
260
- chunk_imgs.append(preprocess(im.convert("RGB")))
261
- chunk_ids.append(l)
262
-
263
- image_input = torch.tensor(np.stack(chunk_imgs))
264
-
265
- with torch.no_grad(), torch.cuda.amp.autocast():
266
- text_features = model_clip.encode_text(text_tokens).float()
267
- image_features = model_clip.encode_image(image_input).float()
268
-
269
- # Rescale features
270
- image_features /= image_features.norm(dim=-1, keepdim=True)
271
- text_features /= text_features.norm(dim=-1, keepdim=True)
272
-
273
- # Analyze features
274
- text_probs = (100.0 * image_features.cpu().numpy() @ text_features.cpu().numpy().T)#.softmax(dim=-1)
275
-
276
- change = max(text_probs[1][0].item() - text_probs[0][0].item(), 0)
277
- changes.append(change)
278
-
279
- return np.round(np.mean(np.array(changes)), 4)
280
-
281
-
282
-
283
- def tohsv(df):
284
- df['H1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
285
- df['H2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
286
- df['H3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
287
-
288
- df['S1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
289
- df['S2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
290
- df['S3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
291
-
292
- df['V1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
293
- df['V2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
294
- df['V3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
295
- return df
296
-
297
-
298
- def rest_from_style(x, styles, layer):
299
- dtype = torch.float16 if (getattr(model.synthesis, layer).use_fp16 and device=='cuda') else torch.float32
300
- if getattr(model.synthesis, layer).is_torgb:
301
- print(layer, getattr(model.synthesis, layer).is_torgb)
302
- weight_gain = 1 / np.sqrt(getattr(model.synthesis, layer).in_channels * (getattr(model.synthesis, layer).conv_kernel ** 2))
303
- styles = styles * weight_gain
304
- input_gain = getattr(model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)
305
- # Execute modulated conv2d.
306
- x = modulated_conv2d(x=x.to(dtype), w=getattr(model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),
307
- padding=getattr(model.synthesis, layer).conv_kernel-1, demodulate=(not getattr(model.synthesis, layer).is_torgb), input_gain=input_gain.to(dtype))
308
- # Execute bias, filtered leaky ReLU, and clamping.
309
- gain = 1 if getattr(model.synthesis, layer).is_torgb else np.sqrt(2)
310
- slope = 1 if getattr(model.synthesis, layer).is_torgb else 0.2
311
- x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(model.synthesis, layer).up_filter, fd=getattr(model.synthesis, layer).down_filter,
312
- b=getattr(model.synthesis, layer).bias.to(x.dtype),
313
- up=getattr(model.synthesis, layer).up_factor, down=getattr(model.synthesis, layer).down_factor,
314
- padding=getattr(model.synthesis, layer).padding,
315
- gain=gain, slope=slope, clamp=getattr(model.synthesis, layer).conv_clamp)
316
- return x
317
-
318
-
319
- def getS(w):
320
- w_torch = torch.from_numpy(w).to('cpu')
321
- W = w_torch.expand((16, -1)).unsqueeze(0)
322
- s = []
323
- s.append(model.synthesis.input.affine(W[0, 0].unsqueeze(0)).numpy())
324
- s.append(model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)).numpy())
325
- s.append(model.synthesis.L1_36_512.affine(W[0, 2].unsqueeze(0)).numpy())
326
- s.append(model.synthesis.L2_36_512.affine(W[0, 3].unsqueeze(0)).numpy())
327
- s.append(model.synthesis.L3_52_512.affine(W[0, 4].unsqueeze(0)).numpy())
328
- s.append(model.synthesis.L4_52_512.affine(W[0, 5].unsqueeze(0)).numpy())
329
- s.append(model.synthesis.L5_84_512.affine(W[0, 6].unsqueeze(0)).numpy())
330
- s.append(model.synthesis.L6_84_512.affine(W[0, 7].unsqueeze(0)).numpy())
331
- s.append(model.synthesis.L7_148_512.affine(W[0, 8].unsqueeze(0)).numpy())
332
- s.append(model.synthesis.L8_148_512.affine(W[0, 9].unsqueeze(0)).numpy())
333
- s.append(model.synthesis.L9_148_362.affine(W[0, 10].unsqueeze(0)).numpy())
334
- s.append(model.synthesis.L10_276_256.affine(W[0, 11].unsqueeze(0)).numpy())
335
- s.append(model.synthesis.L11_276_181.affine(W[0, 12].unsqueeze(0)).numpy())
336
- s.append(model.synthesis.L12_276_128.affine(W[0, 13].unsqueeze(0)).numpy())
337
- s.append(model.synthesis.L13_256_128.affine(W[0, 14].unsqueeze(0)).numpy())
338
- s.append(model.synthesis.L14_256_3.affine(W[0, 15].unsqueeze(0)).numpy())
339
- return s
340
-
341
- def detect_attribute_specific_channels(positives, all, sign=False):
342
- """ Formula from StyleSpace Analysis """
343
- mp = np.mean(all, axis=0)
344
- sp = np.std(all, axis=0)
345
- de = (positives - mp) / sp
346
- meu = np.mean(de, axis=0)
347
- seu = np.std(de, axis=0)
348
- if sign:
349
- thetau = meu / seu
350
- else:
351
- thetau = np.abs(meu) / seu
352
- return thetau
353
-
354
- def all_variance_based_disentanglements(labels, x, y, k=10, sign=False, cutout=0.28):
355
- seps = []
356
- sorted_vals = []
357
- for lbl in labels:
358
- positives = x[np.where(y == lbl)]
359
- variations = detect_attribute_specific_channels(positives, x, sign=sign)
360
- if sign:
361
- argsorted_vars_pos = np.argsort(variations)[-k//2:]
362
- # print(argsorted_vars_pos)
363
- argsorted_vars_neg = np.argsort(variations)[:k//2]
364
- if cutout:
365
- beyond_cutout = np.where(np.abs(variations) > cutout)
366
- # print(beyond_cutout)
367
- argsorted_vars_pos_int = np.intersect1d(argsorted_vars_pos, beyond_cutout)
368
- argsorted_vars_neg_int = np.intersect1d(argsorted_vars_neg, beyond_cutout)
369
- # print(argsorted_vars_pos)
370
- if len(argsorted_vars_neg_int) > 0:
371
- argsorted_vars_neg = np.array(argsorted_vars_neg_int)
372
- if len(argsorted_vars_pos_int) > 0:
373
- argsorted_vars_pos = np.array(argsorted_vars_pos_int)
374
-
375
-
376
- else:
377
- argsorted_vars = np.argsort(variations)[-k:]
378
-
379
-
380
- sorted_vals.append(np.sort(variations))
381
- separation_vector_onehot /= np.linalg.norm(separation_vector_onehot)
382
- seps.append(separation_vector_onehot)
383
- return seps, sorted_vals
384
-
385
- def generate_flexible_images(w, change_vectors, lambdas=1, device='cpu'):
386
- w_torch = torch.from_numpy(w).to('cpu')
387
- if len(change_vectors) != 17:
388
- w_torch = w_torch + lambdas * change_vectors[0]
389
- W = w_torch.expand((16, -1)).unsqueeze(0)
390
-
391
- x = model.synthesis.input(W[0,0].unsqueeze(0))
392
- for i, layer in enumerate(layers):
393
- if i < 2:
394
- continue
395
- style = getattr(model.synthesis, layer).affine(W[0, i-1].unsqueeze(0))
396
- if len(change_vectors) != 17:
397
- change = torch.from_numpy(change_vectors[i].copy()).unsqueeze(0).to(device)
398
- style = torch.add(style, change, alpha=lambdas)
399
- x = rest_from_style(x, style, layer)
400
-
401
- if model.synthesis.output_scale != 1:
402
- x = x * model.synthesis.output_scale
403
-
404
- img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
405
- img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
406
-
407
- return img
408
-
409
- def get_original_pos(top_positions, bottom_positions=None, space='s', sign=True,
410
- shapes=[[512, 4, 512, 512, 512, 512, 512, 512, 512,
411
- 512, 512, 512, 362, 256, 181, 128, 128]],
412
- layers=['w', 'input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512',
413
- 'L4_52_512', 'L5_84_512', 'L6_84_512', 'L7_148_512', 'L8_148_512',
414
- 'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128',
415
- 'L13_256_128', 'L14_256_3'], ):
416
- if space == 's':
417
- current_idx = 0
418
- vectors = []
419
- for i, (leng, layer) in enumerate(zip(shapes, layers)):
420
- arr = np.zeros(leng)
421
- for top_position in top_positions:
422
- if top_position >= current_idx and top_position < current_idx + leng:
423
- arr[top_position - current_idx] = 1
424
- for bottom_position in bottom_positions:
425
- if sign:
426
- if bottom_position >= current_idx and bottom_position < current_idx + leng:
427
- arr[bottom_position - current_idx] = 1
428
- arr = arr / (np.linalg.norm(arr) + 0.000001)
429
- vectors.append(arr)
430
- current_idx += leng
431
- else:
432
- if sign:
433
- vectors = np.zeros(512)
434
- vectors[top_positions] = 1
435
- vectors[bottom_positions] = -1
436
- else:
437
- vectors = np.zeros(512)
438
- vectors[top_positions] = 1
439
- return vectors
440
-
441
- def getX(annotations, space='s'):
442
- if space == 'x':
443
- X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))
444
- elif space == 's':
445
- concat_v = []
446
- for i in range(len(annotations['w_vectors'])):
447
- concat_v.append(np.concatenate([annotations['w_vectors'][i]] + annotations['s_vectors'][i], axis=1))
448
-
449
- X = np.array(concat_v)
450
- X = X[:, 0, :]
451
- print(X.shape)
452
-
453
- return X
454
-
455
-
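
The removed get_separation_space helper fitted a linear classifier (SVM or logistic regression) on latent vectors and used its normalized weight vector as the concept direction. A minimal sketch of that idea with synthetic data (the real code used 512-dim StyleGAN latents taken from the annotations pickle; the random X below is only a stand-in):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    samples, dim = 150, 512
    X = np.random.randn(2 * samples, dim)          # stand-in for sampled latent vectors
    y = np.array([1] * samples + [0] * samples)    # top-scoring vs. bottom-scoring images

    x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
    clf = LogisticRegression(random_state=0, C=0.1).fit(x_train, y_train)

    # Normalized weights = separation ("concept") direction; large weights = important nodes.
    separation_vector = clf.coef_ / np.linalg.norm(clf.coef_)
    imp_nodes = np.where(np.abs(clf.coef_) > 0.15)[1]
    print(separation_vector.shape, len(imp_nodes), np.round(clf.score(x_val, y_val), 2))
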
 
12
  from .color_annotations import extract_color
13
 
14
 
15
+ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W'):
16
  """
17
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
18
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
 
31
  device = torch.device('cpu')
32
  G = model.to(device) # type: ignore
33
 
 
 
34
  # Labels.
35
  label = torch.zeros([1, G.c_dim], device=device)
36
 
37
  z = torch.from_numpy(z.copy()).to(device)
38
+ repetitions = 16
39
+ z_0 = z.clone()
40
+
41
+ for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
42
+ decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
43
+ z_0 = z_0 + int(lmbd) * decision_boundary
44
 
45
 
46
+ if latent_space == 'Z':
47
+ W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
48
+ # W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
49
+ else:
50
+ W_0 = z_0.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
51
+ # W = z.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
52
+
53
+ # if layers:
54
+ # W_f = torch.empty_like(W).copy_(W).to(torch.float32)
55
+ # W_f[:, layers, :] = W_0[:, layers, :]
56
+ # img = G.synthesis(W_f, noise_mode='const')
57
+ # else:
58
+ img = G.synthesis(W_0, noise_mode='const')
59
 
60
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
61
+ img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
62
 
63
+ return img
64
 
65
+
66
+ def generate_original_image(z, model, latent_space='W'):
67
  """
68
  The generate_original_image function takes in a latent vector and the model,
69
  and returns an image generated from that latent vector.
 
74
  :return: A pil image
75
  :doc-author: Trelent
76
  """
77
+ repetitions = 16
78
 
79
  device = torch.device('cpu')
80
  G = model.to(device) # type: ignore
 
85
  img = G(z, label, truncation_psi=1, noise_mode='const')
86
  else:
87
  W = torch.from_numpy(np.repeat(z, repetitions, axis=0).reshape(1, repetitions, z.shape[1]).copy()).to(device)
 
88
  img = G.synthesis(W, noise_mode='const')
89
 
90
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
91
  return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
pages/1_Textiles_Disentanglement.py CHANGED
@@ -31,11 +31,21 @@ annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
31
  with open(annotations_file, 'rb') as f:
32
  annotations = pickle.load(f)
33
 
34
- COLORS_LIST = []
35
  if 'image_id' not in st.session_state:
36
  st.session_state.image_id = 0
37
  if 'color_ids' not in st.session_state:
38
- st.session_state.concept_ids =['AMPHORA']
39
  if 'space_id' not in st.session_state:
40
  st.session_state.space_id = 'W'
41
 
@@ -47,61 +57,6 @@ st.header('Input')
47
  input_col_1, input_col_2, input_col_3 = st.columns(3)
48
  # --------------------------- INPUT column 1 ---------------------------
49
  with input_col_1:
50
- with st.form('text_form'):
51
-
52
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
53
- st.write('**Choose two options to disentangle**')
54
- type_col = st.selectbox('Concept category:', tuple(['Provenance', 'Shape Name', 'Fabric', 'Technique']))
55
-
56
- ann_df = pd.read_csv(f'./data/vase_annotated_files/sim_{type_col}_seeds0000-20000.csv')
57
- labels = list(ann_df.columns)
58
- labels.remove('ID')
59
- labels.remove('Unnamed: 0')
60
-
61
- concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=[labels[2], labels[3]])
62
-
63
- st.write('**Choose a latent space to disentangle**')
64
- space_id = st.selectbox('Space:', tuple(['W', 'Z']))
65
-
66
- choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
67
-
68
- if choose_text_button:
69
- concept_ids = list(concept_ids)
70
- st.session_state.concept_ids = concept_ids
71
- space_id = str(space_id)
72
- st.session_state.space_id = space_id
73
- # st.write(image_id, st.session_state.image_id)
74
-
75
- # ---------------------------- SET UP OUTPUT ------------------------------
76
- epsilon_container = st.empty()
77
- st.header('Output')
78
- st.subheader('Concept vector')
79
-
80
- # perform attack container
81
- # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
82
- # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
83
- header_col_1, header_col_2 = st.columns([5,1])
84
- output_col_1, output_col_2 = st.columns([5,1])
85
-
86
- st.subheader('Derivations along the concept vector')
87
-
88
- # prediction error container
89
- error_container = st.empty()
90
- smoothgrad_header_container = st.empty()
91
-
92
- # smoothgrad container
93
- smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
94
- smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
95
-
96
- # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
97
- with output_col_1:
98
- separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id, samples=150)
99
- # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
100
- st.write('Concept vector', separation_vector)
101
- header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
102
-
103
- # ----------------------------- INPUT column 2 & 3 ----------------------------
104
- with input_col_2:
105
  with st.form('image_form'):
106
 
107
  # image_id = st.number_input('Image ID: ', format='%d', step=1)
@@ -113,34 +68,83 @@ with input_col_2:
113
  random_id = st.form_submit_button('Generate a random image')
114
 
115
  if random_id:
116
- image_id = random.randint(0, 20000)
117
  st.session_state.image_id = image_id
118
  chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
119
 
120
  if choose_image_button:
121
  image_id = int(image_id)
122
  st.session_state.image_id = int(image_id)
123
- # st.write(image_id, st.session_state.image_id)
124
 
125
- with input_col_3:
126
- with st.form('Variate along the disentangled concept'):
 
127
  st.write('**Set range of change**')
128
- chosen_epsilon_input = st.empty()
129
- epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
130
- epsilon_button = st.form_submit_button('Choose the defined lambda')
131
- st.write('**Select hierarchical levels to manipulate**')
132
- layers = st.multiselect('Layers:', tuple(range(14)))
133
- if len(layers) == 0:
134
- layers = None
135
- print(layers)
136
- layers_button = st.form_submit_button('Choose the defined layers')
137
 
138
 
139
- # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
140
 
141
- #model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
142
- with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
143
- model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
144
 
145
  if st.session_state.space_id == 'Z':
146
  original_image_vec = annotations['z_vectors'][st.session_state.image_id]
@@ -149,30 +153,9 @@ else:
149
 
150
  img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
151
 
152
- top_pred = ann_df.loc[st.session_state.image_id, labels].astype(float).idxmax()
153
- # input_image = original_image_dict['image']
154
- # input_label = original_image_dict['label']
155
- # input_id = original_image_dict['id']
156
-
157
- with smoothgrad_col_3:
158
  st.image(img)
159
- smooth_head_3.write(f'Base image, predicted as {top_pred}')
160
-
161
-
162
- images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
163
-
164
- with smoothgrad_col_1:
165
- st.image(images[0])
166
- smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
167
-
168
- with smoothgrad_col_2:
169
- st.image(images[1])
170
- smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
171
-
172
- with smoothgrad_col_4:
173
- st.image(images[3])
174
- smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
175
 
176
- with smoothgrad_col_5:
177
- st.image(images[4])
178
- smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
 
31
  with open(annotations_file, 'rb') as f:
32
  annotations = pickle.load(f)
33
 
34
+ concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
35
+ concept_vectors['vector'] = [np.array([float(v) for v in x.split(', ')]) for x in concept_vectors['vector']]
36
+ concept_vectors['score'] = concept_vectors['score'].astype(float)
37
+ concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
38
+ print(concept_vectors[['vector', 'score']])
39
+
40
+ with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
41
+ model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
42
+
43
+ COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink']
44
+
45
  if 'image_id' not in st.session_state:
46
  st.session_state.image_id = 0
47
  if 'color_ids' not in st.session_state:
48
+ st.session_state.concept_ids = COLORS_LIST[-1]
49
  if 'space_id' not in st.session_state:
50
  st.session_state.space_id = 'W'
51
 
 
57
  input_col_1, input_col_2, input_col_3 = st.columns(3)
58
  # --------------------------- INPUT column 1 ---------------------------
59
  with input_col_1:
60
  with st.form('image_form'):
61
 
62
  # image_id = st.number_input('Image ID: ', format='%d', step=1)
 
68
  random_id = st.form_submit_button('Generate a random image')
69
 
70
  if random_id:
71
+ image_id = random.randint(0, 100000)
72
  st.session_state.image_id = image_id
73
  chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
74
 
75
  if choose_image_button:
76
  image_id = int(image_id)
77
  st.session_state.image_id = int(image_id)
 
78
 
79
+ with input_col_2:
80
+ with st.form('text_form'):
81
+
82
+ st.write('**Choose color to vary**')
83
+ type_col = st.selectbox('Color:', tuple(COLORS_LIST), index=COLORS_LIST.index(st.session_state.concept_ids))
84
+
85
  st.write('**Set range of change**')
86
+ chosen_color_lambda_input = st.empty()
87
+ color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
88
+ color_lambda_button = st.form_submit_button('Choose the defined lambda')
89
+
90
+ if color_lambda_button:
91
+ st.session_state.concept_ids = type_col
92
+ st.session_state.space_id = 'W'
93
+
94
+ with input_col_3:
95
+ with st.form('saturation_value_form'):
96
+
97
+ st.write('**Saturation variation**')
98
+ chosen_saturation_lambda_input = st.empty()
99
+ saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=0, step=1, key='saturation_lambda')
100
+ saturation_lambda_button = st.form_submit_button('Choose the defined lambda')
101
+
102
+ st.write('**Value variation**')
103
+ chosen_value_lambda_input = st.empty()
104
+ value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key='value_lambda')
105
+ value_lambda_button = st.form_submit_button('Choose the defined lambda')
106
+
107
+ # with input_col_4:
108
+ # with st.form('Network specifics:'):
109
+ # st.write('**Choose a latent space to use**')
110
+ # space_id = st.selectbox('Space:', tuple(['W']))
111
+ # choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
112
+
113
+ # st.write('**Select hierarchical levels to manipulate**')
114
+ # layers = st.multiselect('Layers:', tuple(range(14)))
115
+ # if len(layers) == 0:
116
+ # layers = None
117
+ # print(layers)
118
+ # layers_button = st.form_submit_button('Choose the defined layers')
119
 
120
 
121
+ # ---------------------------- SET UP OUTPUT ------------------------------
122
+ epsilon_container = st.empty()
123
+ st.header('Image Manipulation')
124
+ st.subheader('Using selected directions')
125
 
126
+ header_col_1, header_col_2 = st.columns([1,1])
127
+ output_col_1, output_col_2 = st.columns([1,1])
128
+
129
+ # # prediction error container
130
+ # error_container = st.empty()
131
+ # smoothgrad_header_container = st.empty()
132
+
133
+ # # smoothgrad container
134
+ # smooth_head_1, smooth_head_2, = st.columns([1,1,])
135
+ # smoothgrad_col_1, smoothgrad_col_2 = st.columns([1,1])
136
+
137
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
138
+ with header_col_1:
139
+ st.write(f'Original image')
140
+
141
+ with header_col_2:
142
+ color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids][['vector', 'score']].iloc[0]
143
+ saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'][['vector', 'score']].iloc[0]
144
+ value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'][['vector', 'score']].iloc[0]
145
+ st.write(f'Change in {st.session_state.concept_ids} of {np.round(color_lambda, 2)}, in saturation of {np.round(saturation_lambda, 2)}, in value of {np.round(value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation}, value vector: {performance_value}')
146
+
147
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
148
 
149
  if st.session_state.space_id == 'Z':
150
  original_image_vec = annotations['z_vectors'][st.session_state.image_id]
 
153
 
154
  img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
155
 
156
+ with output_col_1:
157
  st.image(img)
158
 
159
+ with output_col_2:
160
+ image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[color_lambda, saturation_lambda, value_lambda])
161
+ st.image(image_updated)
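
A note on the stored directions loaded at the top of this page: scores_colors_hsv.csv keeps each separation vector as one comma-separated string, so it has to be parsed element-wise before it can be added to a latent. A minimal sketch of that loading step, assuming the 'color', 'vector' and 'score' columns implied by the code above:

    import numpy as np
    import pandas as pd

    concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
    concept_vectors['vector'] = concept_vectors['vector'].map(
        lambda s: np.array([float(v) for v in s.split(', ')]))   # string -> (512,) float array
    concept_vectors['score'] = concept_vectors['score'].astype(float)
    concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index(drop=True)

    # e.g. the direction and validation score for one color
    blue = concept_vectors[concept_vectors['color'] == 'Blue'].iloc[0]
    blue_direction, blue_score = blue['vector'], blue['score']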