from datasets import load_dataset import streamlit as st from data_utils import get_embedding from bokeh.plotting import figure,show from bokeh.io import push_notebook, output_notebook # output_notebook() from bokeh.palettes import d3 from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter from bokeh.transform import factor_cmap, factor_mark import base64 from io import BytesIO label_columns=["gender","subCategory","masterCategory"] model_interest=['facebook/deit-tiny-patch16-224', # very small model 5M param model 'microsoft/beit-base-patch16-224', # big model "facebook/dino-vits8", "facebook/levit-128S"] def convert_base64(img): buffered = BytesIO() img.save(buffered, format="JPEG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return "data:image/jpeg;base64,"+img_str @st.experimental_singleton def cache_embedding(model_name): dataset=load_dataset("ceyda/fashion-products-small", split="train") dataset=dataset.shuffle(seed=100) #pick a random seed viz_dat=dataset.train_test_split(0.1,shuffle=False) #일부를 visualization위해서 뽑시단 viz_dat=viz_dat["test"] embedding = get_embedding(model_name,viz_dat) embedding["image"]=embedding["image"].apply(convert_base64) labels = {label:viz_dat.unique(label) for label in label_columns} return embedding,labels @st.experimental_singleton def cache_graph(model_name,color_column): embedding,labels=cache_embedding(model_name) color_palette = (d3['Category20'][20]+d3['Category20b'][20]+d3['Category20c'][20])[:len(labels[color_column])] source = ColumnDataSource(data=embedding) # colors = factor_cmap('gender', palette=["purple","navy","green","blue","pink"], factors=embedding["gender"].unique()) TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,box_select,lasso_select," TOOLTIPS = """