"""Streamlit page: compare concept vectors and visualise their pairwise similarity.

The user selects a set of concepts (from ./data/concepts.txt). The page then:
  * reports the relevant nodes the selected concepts have in common, as returned
    by ``get_concepts_vectors``;
  * draws an undirected graph whose nodes are the selected concepts and whose
    edge weights are the cosine similarities between their concept vectors.
"""
import itertools
import os
import pickle
import tempfile

import networkx as nx
import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from matplotlib.backends.backend_agg import RendererAgg
from pyvis.network import Network
from sklearn.metrics.pairwise import cosine_similarity

from backend.disentangle_concepts import *  # get_concepts_vectors is provided by this module

_lock = RendererAgg.lock

HIGHTLIGHT_COLOR = '#e7bcc5'
st.set_page_config(layout='wide')

st.title('Comparison among concept vectors')

st.write('> **How do the concept vectors relate to each other?**')
st.write('> **What is their joint impact on the image?**')

st.write("""Description to write""")

annotations_file = './data/annotated_files/seeds0000-100000.pkl'
# NOTE: pickle.load can execute arbitrary code; only ever load this trusted,
# locally generated file here, never user-supplied data.
with open(annotations_file, 'rb') as f:
    annotations = pickle.load(f)

ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-100000.csv')

# Path to the list of concept labels. Deliberately not called `concepts`: that
# name was previously confused with the selected concepts in the graph loop.
concepts_file = './data/concepts.txt'
with open(concepts_file, encoding='utf-8') as f:
    # One label per line; blank lines would otherwise become empty options.
    labels = [line.strip() for line in f if line.strip()]

if 'image_id' not in st.session_state:
    st.session_state.image_id = 0
if 'concept_ids' not in st.session_state:
    st.session_state.concept_ids = ['Abstract', 'Representational']


def _similarity_edges(names, vectors):
    """Return ``(name_a, name_b, similarity)`` for every unordered pair of concepts.

    Row ``i`` of ``vectors`` must be the vector of ``names[i]`` (the same
    row-wise indexing ``vectors[i, :]`` the page has always used). Each pair
    appears once because the graph built from these edges is undirected.
    Similarities are rounded to 3 decimals and converted to plain ``float``:
    numpy scalar types such as float32 are not JSON-serialisable, and pyvis
    serialises edge attributes to JSON.
    """
    similarities = cosine_similarity(vectors)
    return [
        (names[i], names[j], float(np.round(similarities[i, j], 3)))
        for i, j in itertools.combinations(range(len(names)), 2)
    ]


def _render_similarity_graph(nodes, edges):
    """Draw the concept-similarity graph with pyvis and embed it in the page.

    ``nodes`` are added explicitly so a concept without edges (e.g. a single
    selection) is still shown.
    """
    graph = nx.Graph()
    graph.add_nodes_from(nodes)
    for node_a, node_b, weight in edges:
        graph.add_edge(node_a, node_b, weight=weight)

    net = Network(height='400px', width='100%', bgcolor='#222222', font_color='white')
    # Translate the networkx graph into the pyvis format.
    net.from_nx(graph)
    # Layout settings for the physics simulation.
    net.repulsion(
        node_distance=420,
        central_gravity=0.33,
        spring_length=110,
        spring_strength=0.10,
        damping=0.95,
    )

    # pyvis writes HTML to disk. Use a private temporary directory per run: it
    # works on any OS, is cleaned up automatically, and avoids concurrent
    # sessions overwriting one shared /tmp/pyvis_graph.html.
    with tempfile.TemporaryDirectory() as tmp_dir:
        graph_path = os.path.join(tmp_dir, 'pyvis_graph.html')
        net.save_graph(graph_path)
        with open(graph_path, 'r', encoding='utf-8') as html_file:
            graph_html = html_file.read()

    components.html(graph_html, height=435)


# ----------------------------- INPUT ----------------------------------
st.header('Input')
input_col_1, input_col_2, input_col_3 = st.columns(3)

# --------------------------- INPUT column 1 ---------------------------
with input_col_1:
    with st.form('text_form'):
        st.write('**Choose a series of concepts to compare**')
        concept_ids = st.multiselect('Concept:', tuple(labels))
        choose_text_button = st.form_submit_button('Choose the defined concepts')

    if choose_text_button:
        st.session_state.concept_ids = list(concept_ids)

# ---------------------------- SET UP OUTPUT ------------------------------
epsilon_container = st.empty()
st.header('Output')
st.subheader('Concept vector')

header_col_1, header_col_2 = st.columns([5, 1])
output_col_1, output_col_2 = st.columns([5, 1])

st.subheader('Derivations along the concept vector')

# prediction error container
error_container = st.empty()
smoothgrad_header_container = st.empty()

# smoothgrad container
smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1, 1, 1, 1, 1])
smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1, 1, 1, 1, 1])

# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
# Use the last submitted selection kept in session state. The multiselect
# widget starts empty, so its local value would ignore the default concepts
# on first load.
selected_concepts = list(st.session_state.concept_ids)

with output_col_1:
    if not selected_concepts:
        header_col_1.write('Select at least one concept to compare.')
    else:
        vectors, nodes_in_common = get_concepts_vectors(selected_concepts, annotations, ann_df)
        header_col_1.write(f'Concepts {", ".join(selected_concepts)} - Relevant nodes in common: {nodes_in_common}')
        _render_similarity_graph(selected_concepts, _similarity_edges(selected_concepts, vectors))

# ----------------------------- INPUT column 2 & 3 ----------------------------
# NOTE: image selection and epsilon controls are currently disabled (commented out).
# with input_col_2:
#     with st.form('image_form'):
#         st.write('**Choose or generate a random image to test the disentanglement**')
#         chosen_image_id_input = st.empty()
#         image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
#         choose_image_button = st.form_submit_button('Choose the defined image')
#         random_id = st.form_submit_button('Generate a random image')
#
#     if random_id:
#         image_id = random.randint(0, 100000)
#         st.session_state.image_id = image_id
#         chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
#
#     if choose_image_button:
#         image_id = int(image_id)
#         st.session_state.image_id = int(image_id)
#
# with input_col_3:
#     with st.form('Variate along the disentangled concept'):
#         st.write('**Set range of change**')
#         chosen_epsilon_input = st.empty()
#         epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=1, step=1)
#         epsilon_button = st.form_submit_button('Choose the defined epsilon')

# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
# NOTE: image generation is currently disabled (commented out). Re-enabling it
# needs a `separation_vector`, which this page no longer defines.
# with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
#     model = legacy.load_network_pkl(f)['G_ema'].to('cpu')  # type: ignore
#
# original_image_vec = annotations['z_vectors'][st.session_state.image_id]
# img = generate_original_image(original_image_vec, model)
#
# with smoothgrad_col_3:
#     st.image(img)
#     smooth_head_3.write(f'Base image')
#
# images, lambdas = regenerate_images(model, original_image_vec, separation_vector,
#                                     min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon))
#
# with smoothgrad_col_1:
#     st.image(images[0])
#     smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
# with smoothgrad_col_2:
#     st.image(images[1])
#     smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
# with smoothgrad_col_4:
#     st.image(images[3])
#     smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
# with smoothgrad_col_5:
#     st.image(images[4])
#     smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')