webplip / viz_scripts /calc_img.py
huangzhii
new features added
9059a42
raw
history blame
2.06 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 10 21:13:04 2023
@author: zhihuang
"""
import pickle
import os
import pandas as pd
import numpy as np
import umap
import seaborn as sns
import matplotlib.pyplot as plt
opj=os.path.join
if __name__ == '__main__':
dd = '/home/zhihuang/Desktop/webplip/data'
with open(opj(dd, 'twitter.asset'),'rb') as f:
data = pickle.load(f)
n_neighbors = 15
random_state = 0
reducer = umap.UMAP(n_components=2,
n_neighbors=n_neighbors,
min_dist=0.1,
metric='euclidean',
random_state=random_state)
img_2d = reducer.fit(data['image_embedding'])
img_2d = reducer.transform(data['image_embedding'])
df_img = pd.DataFrame(np.c_[img_2d, data['meta'].values], columns = ['UMAP_1','UMAP_2'] + list(data['meta'].columns))
df_img.to_csv(opj(dd, 'img_2d_embedding.csv'))
# reducer = umap.UMAP(n_components=2,
# n_neighbors=n_neighbors,
# min_dist=0.1,
# metric='euclidean',
# random_state=random_state)
txt_2d = reducer.fit_transform(data['text_embedding'])
df_txt = pd.DataFrame(np.c_[txt_2d, data['meta'].values], columns = ['UMAP_1','UMAP_2'] + list(data['meta'].columns))
df_txt.to_csv(opj(dd, 'txt_2d_embedding.csv'))
fig, ax = plt.subplots(1,2, figsize=(20,10))
sns.scatterplot(data=df_img,
x='UMAP_1',
y='UMAP_2',
alpha=0.2,
ax=ax[0],
hue='tag'
)
sns.scatterplot(data=df_txt,
x='UMAP_1',
y='UMAP_2',
alpha=0.2,
ax=ax[1],
hue='tag'
)