Spaces:
Runtime error
Runtime error
import logging | |
import time | |
from typing import Callable, Optional, Tuple, Union | |
import pandas as pd | |
import streamlit as st | |
from bokeh.palettes import Turbo256 | |
from bokeh.plotting import Figure | |
from embedding_lenses.embedding import embed_text | |
from embedding_lenses.utils import encode_labels | |
from embedding_lenses.visualization import draw_interactive_scatter_plot | |
from sentence_transformers import SentenceTransformer | |
from perplexity_lenses import REGISTRY_DATASET | |
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)

# Embedding model identifiers offered to the user
# (NOTE(review): presumably loaded via SentenceTransformer — confirm in the app).
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "distiluse-base-multilingual-cased-v2",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Algorithms offered for projecting embeddings down to two dimensions.
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]

# Text granularity choices (presumably: embed each whole document, or each
# sentence separately — confirm against the caller).
DOCUMENT_TYPES = ["Whole document", "Sentence"]

# Default random seed.
SEED = 0

# Two-letter language codes, alphabetically sorted.
# NOTE(review): presumably the languages with a perplexity model available — confirm.
LANGUAGES = (
    "af ar az be bg bn ca cs da de el en es et fa fi fr gu he hi hr hu hy id "
    "is it ja ka kk km kn ko lt lv mk ml mn mr my ne nl no pl pt ro ru uk zh"
).split()

# Perplexity model choices (presumably named after their training corpus — confirm).
PERPLEXITY_MODELS = ["Wikipedia", "OSCAR"]
class ContextLogger:
    """Context manager that logs a message on entry and the elapsed time on exit.

    It has the same call shape as ``st.spinner``: it is constructed with the
    message text and used in a ``with`` statement. This lets ``generate_plot``
    accept either one as its ``context_logger`` factory.
    """

    def __init__(self, text: str = ""):
        self.text = text
        # Set here so the attribute always exists. It is reset in __enter__ so
        # the measured duration covers only the body of the ``with`` block.
        self.start_time = time.perf_counter()

    def __enter__(self) -> "ContextLogger":
        logger.info(self.text)
        # perf_counter is monotonic, unlike time.time, so durations cannot go
        # negative or jump if the wall clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None (falsy) lets any exception raised in the block propagate.
        logger.info("Took: %.4f seconds", time.perf_counter() - self.start_time)
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
    seed: int = 0,
    context_logger: Callable = ContextLogger,
    hub_dataset: str = "",
) -> Tuple[Figure, Optional[Figure]]:
    """Embed the texts of ``df``, reduce them to 2-D and build scatter plots.

    Args:
        df: Input data. It is never modified in place; if ``label_column`` is
            missing, it is added (filled with 0) to a new frame.
        text_column: Column holding the texts to embed.
        label_column: Column whose values colour the main plot. They are
            rounded to ints, so they are assumed numeric (the original code
            treats them as perplexity values).
        sample: If truthy, randomly keep at most this many rows (after NaNs
            in the text/label columns are dropped).
        dimensionality_reduction_function: Callable applied to the embeddings.
            Its result is indexed as ``[:, 0]`` / ``[:, 1]`` for x / y.
        model: Model passed through to ``embed_text``.
        seed: ``random_state`` used for sampling.
        context_logger: Factory called with a message that returns a context
            manager wrapping the slow steps, e.g. ``ContextLogger`` or
            ``st.spinner``.
        hub_dataset: When equal to ``REGISTRY_DATASET``, a second plot coloured
            by the ``label`` column is also built.

    Returns:
        ``(plot, plot_registry)``. ``plot_registry`` is ``None`` unless
        ``hub_dataset == REGISTRY_DATASET``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(
            f"The specified column name doesn't exist. Columns available: {df.columns.values}"
        )
    if label_column not in df.columns:
        # assign() returns a new frame, so the caller's DataFrame is not mutated
        # (and no SettingWithCopyWarning is raised when df is a slice).
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=seed)

    with context_logger(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with context_logger("Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)

    logger.info("Generating figure")
    xs = embeddings_2d[:, 0]
    ys = embeddings_2d[:, 1]
    hover_data = {
        text_column: df[text_column].values,
        label_column: encoded_labels.values,
    }
    # Round perplexity values
    values = df[label_column].values.round().astype(int)
    plot = draw_interactive_scatter_plot(hover_data, xs, ys, values)

    # Special case for the registry dataset: an extra plot coloured by "label".
    plot_registry = None
    if hub_dataset == REGISTRY_DATASET:
        registry_labels = encode_labels(df["label"])
        registry_hover_data = {
            text_column: df[text_column].values,
            "label": df["label"].values,
            label_column: df[label_column].values,
        }
        plot_registry = draw_interactive_scatter_plot(
            registry_hover_data,
            xs,
            ys,
            registry_labels.values,
            palette=Turbo256,
        )
    return plot, plot_registry