ludusc commited on
Commit
d6f2aad
1 Parent(s): a426c90

small change

Browse files
pages/3_Oxford_Vases_Disentanglement.py CHANGED
@@ -56,7 +56,7 @@ with input_col_1:
56
  concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=['AMPHORA', 'CHALICE'])
57
 
58
  st.write('**Choose a latent space to disentangle**')
59
- space_id = st.selectbox('Space:', tuple(['Z', 'W']))
60
 
61
  choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
62
 
@@ -90,7 +90,7 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgr
90
 
91
  # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
92
  with output_col_1:
93
- separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id)
94
  # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
95
  st.write('Concept vector', separation_vector)
96
  header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
@@ -108,7 +108,7 @@ with input_col_2:
108
  random_id = st.form_submit_button('Generate a random image')
109
 
110
  if random_id:
111
- image_id = random.randint(0, 50000)
112
  st.session_state.image_id = image_id
113
  chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
114
 
@@ -143,13 +143,14 @@ else:
143
  original_image_vec = annotations['w_vectors'][st.session_state.image_id]
144
 
145
  img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
 
146
  # input_image = original_image_dict['image']
147
  # input_label = original_image_dict['label']
148
  # input_id = original_image_dict['id']
149
 
150
  with smoothgrad_col_3:
151
  st.image(img)
152
- smooth_head_3.write(f'Base image')
153
 
154
 
155
  images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
 
56
  concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=['AMPHORA', 'CHALICE'])
57
 
58
  st.write('**Choose a latent space to disentangle**')
59
+ space_id = st.selectbox('Space:', tuple(['Z', 'W']), default='W')
60
 
61
  choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
62
 
 
90
 
91
  # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
92
  with output_col_1:
93
+ separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id, samples=200)
94
  # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
95
  st.write('Concept vector', separation_vector)
96
  header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
 
108
  random_id = st.form_submit_button('Generate a random image')
109
 
110
  if random_id:
111
+ image_id = random.randint(0, 20000)
112
  st.session_state.image_id = image_id
113
  chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
114
 
 
143
  original_image_vec = annotations['w_vectors'][st.session_state.image_id]
144
 
145
  img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
146
+ top_pred = ann_df.iloc[st.session_state.image_id].idxmax()
147
  # input_image = original_image_dict['image']
148
  # input_label = original_image_dict['label']
149
  # input_id = original_image_dict['id']
150
 
151
  with smoothgrad_col_3:
152
  st.image(img)
153
+ smooth_head_3.write(f'Base image, predicted as {top_pred}')
154
 
155
 
156
  images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
pages/4_todo.py DELETED
@@ -1,56 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
-
4
- from backend.utils import load_dataset, use_container_width_percentage
5
-
6
- st.set_page_config(layout='wide')
7
-
8
- st.title('ImageNet-1k')
9
- st.markdown('This page shows the summary of 50,000 images in the validation set of [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)')
10
-
11
- # SCREEN_WIDTH, SCREEN_HEIGHT = 2560, 1664
12
-
13
- with st.spinner("Loading dataset..."):
14
- dataset_dict = {}
15
- for data_index in range(5):
16
- dataset_dict[data_index] = load_dataset(data_index)
17
-
18
- imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
19
-
20
- class_labels = imagenet_df.ClassLabel.unique().tolist()
21
- class_labels.sort()
22
- selected_classes = st.multiselect('Class filter: ', options=['All'] + class_labels)
23
- if not ('All' in selected_classes or len(selected_classes) == 0):
24
- imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
25
- # st.write(class_labels)
26
-
27
- col1, col2 = st.columns([2, 1])
28
- with col1:
29
- st.dataframe(imagenet_df)
30
- use_container_width_percentage(100)
31
-
32
- with col2:
33
- st.text_area('Type anything here to copy later :)')
34
- image = None
35
- with st.form("display image"):
36
- img_index = st.text_input('Image ID to display')
37
-
38
- submitted = st.form_submit_button('Display this image')
39
- error_container = st.empty()
40
-
41
- if submitted:
42
- try:
43
- img_index = int(img_index)
44
- if img_index > 50000-1 or img_index < 0:
45
- error_container.error('The Image ID must be in range from 0 to 49999', icon="🚫")
46
- else:
47
- image = dataset_dict[img_index//10_000][img_index%10_000]['image']
48
- class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
49
- class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
50
- except ValueError:
51
- error_container.error('Please enter an integer number for Image ID', icon = "🚫")
52
-
53
- if image != None:
54
- st.image(image)
55
- st.write('**Class label:** ', class_label)
56
- st.write('\n**Class id:** ', str(class_id))