|
import json |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.figure_factory as ff |
|
import plotly.graph_objects as go |
|
import streamlit as st |
|
from plotly.subplots import make_subplots |
|
|
|
from exp_utils import MODELS |
|
from visualize_utils import viridis_rgb |
|
|
|
# Page-level Streamlit settings (browser tab title/icon, wide layout, and an
# initially expanded sidebar).
_PAGE_CONFIG = {
    "page_title": "Results Viewer",
    "page_icon": "📊",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
|
|
|
def _strip_variant_suffix(model_name):
    """Return *model_name* with "_quantized" and "_watermarked" removed.

    The variant names built below have no entry of their own in MODELS; they
    are looked up through the name of the model they were derived from.
    """
    return model_name.replace("_quantized", "").replace("_watermarked", "")


def _family_then_size(model_name):
    """Sort key: (model_family, model_size) of the underlying MODELS entry."""
    entry = MODELS[_strip_variant_suffix(model_name)]
    return entry["model_family"], entry["model_size"]


def _size_then_family(model_name):
    """Sort key: (model_size, model_family) of the underlying MODELS entry."""
    entry = MODELS[_strip_variant_suffix(model_name)]
    return entry["model_size"], entry["model_family"]


# model name -> family, and the set of distinct families.
MODELS_FAMILY_MAPPING = {name: info["model_family"] for name, info in MODELS.items()}
MODEL_FAMILES = set(MODELS_FAMILY_MAPPING.values())

# Models for which quantized and watermarked variants are listed.
_Q_W_BASE_MODELS = (
    "llama-7b",
    "llama-2-7b",
    "llama-13b",
    "llama-2-13b",
    "llama-30b",
    "llama-65b",
    "llama-2-70b",
)
# All "_quantized" names first, then all "_watermarked" names.
Q_W_MODELS = [
    f"{base}{suffix}"
    for suffix in ("_quantized", "_watermarked")
    for base in _Q_W_BASE_MODELS
]

MODEL_NAMES = [*MODELS.keys(), *Q_W_MODELS]

# Despite its name, the first ordering is by (family, size), not by model name.
MODEL_NAMES_SORTED_BY_NAME_AND_SIZE = sorted(MODEL_NAMES, key=_family_then_size)
MODEL_NAMES_SORTED_BY_SIZE = sorted(MODEL_NAMES, key=_size_then_family)

# model name -> size, with entries ordered by (size, name).
MODELS_SIZE_MAPPING = dict(
    sorted(
        ((name, info["model_size"]) for name, info in MODELS.items()),
        key=lambda pair: (pair[1], pair[0]),
    )
)
MODELS_SIZE_MAPPING_LIST = list(MODELS_SIZE_MAPPING)

# Every known name (variants included) whose underlying model is flagged is_chat.
CHAT_MODELS = [
    name
    for name in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE
    if MODELS[_strip_variant_suffix(name)]["is_chat"]
]
|
|
|
|
|
def clean_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Aggregate raw per-run results into per-model mean and std tables.

    Steps, in order:

    1. Drop bookkeeping columns (any column whose name contains one of
       ``words_to_remove``, case-insensitively).
    2. Strip "_roc_auc" and "eval_" from the remaining column names.
    3. Group by "model_name" and take the mean / standard deviation of the
       numeric and boolean columns; "exp_seed" is then dropped.
    4. Attach "model_family" and "model_size", and keep only the rows whose
       model name also appears as a column, in column order.
    5. Add "is_quantized" / "is_watermarked" flags derived from the row name.

    Args:
        df: Raw results. Must contain "model_name" and "exp_seed" columns.

    Returns:
        ``(df_avg, df_std)``, both indexed by model name.

    Raises:
        KeyError: If a model name (with "_quantized" / "_watermarked"
            removed) is missing from the module-level mappings.
    """

    def base_name(name):
        # Variant rows share the metadata of the model they derive from.
        return name.replace("_quantized", "").replace("_watermarked", "")

    words_to_remove = [
        "epoch",
        "loss",
        "runtime",
        "samples_per_second",
        "steps_per_second",
        "samples",
        "results_dir",
    ]
    bookkeeping = df.columns.str.contains(
        "|".join(words_to_remove), case=False, regex=True
    )
    # Copy so the assignments below work on an independent frame instead of a
    # slice of the caller's frame (avoids SettingWithCopyWarning).
    df = df.loc[:, ~bookkeeping].copy()

    # Literal replacements; regex=False keeps this identical across pandas
    # versions, whose default for `regex` has changed.
    df.columns = df.columns.str.replace("_roc_auc", "", regex=False)
    df.columns = df.columns.str.replace("eval_", "", regex=False)

    df["model_family"] = df["model_name"].apply(
        lambda name: MODELS_FAMILY_MAPPING[base_name(name)]
    )
    model_family_dict = dict(zip(df["model_name"], df["model_family"]))

    # Aggregate only numeric/boolean columns. Since pandas 2.0, mean()/std()
    # raise on string columns such as "model_family" instead of silently
    # dropping them; the family is re-attached from the dict afterwards.
    value_columns = [
        column
        for column in df.select_dtypes(include=["number", "bool"]).columns
        if column != "model_name"
    ]
    grouped = df.groupby("model_name")[value_columns]

    def finalize(stats: pd.DataFrame) -> pd.DataFrame:
        # Shared post-processing for the mean and the std table.
        stats = stats.drop(columns=["exp_seed"])
        stats["model_family"] = stats.index.map(model_family_dict)
        stats["model_size"] = stats.index.map(
            lambda name: MODELS_SIZE_MAPPING[base_name(name)]
        )
        stats = stats.sort_values(
            by=["model_family", "model_size"], ascending=[True, True]
        )
        # Keep only rows that also exist as a column, in column order.
        # NOTE(review): this reindex fully determines the final row order, so
        # the sort above does not affect the result.
        stats = stats.reindex([c for c in stats.columns if c in stats.index])
        stats["is_quantized"] = stats.index.str.contains("quantized")
        stats["is_watermarked"] = stats.index.str.contains("watermarked")
        return stats

    return finalize(grouped.mean()), finalize(grouped.std())
|
|
|
|
|
def get_data(path) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the results CSV at *path* and return its cleaned (mean, std) tables.

    The first CSV column is used as the index; all further processing is
    delegated to ``clean_dataframe``.
    """
    raw_results = pd.read_csv(path, index_col=0)
    mean_table, std_table = clean_dataframe(raw_results)
    return mean_table, std_table
|
|
|
|
|
def filter_df(
    df: pd.DataFrame,
    model_family_train: list,
    model_family_test: list,
    model_size_train: tuple,
    model_size_test: tuple,
    is_chat_train: bool,
    is_chat_test: bool,
    is_quantized_train: bool,
    is_quantized_test: bool,
    is_watermarked_train: bool,
    is_watermarked_test: bool,
    sort_by_size: bool,
    split_chat_models: bool,
    split_quantized_models: bool,
    split_watermarked_models: bool,
    filter_empty_col_row: bool,
    is_debug: bool,
) -> pd.DataFrame:
    """Select and order the rows and columns of the results table.

    Rows are filtered with the ``*_train`` arguments through the metadata
    columns ("model_size", "model_family", "is_chat", "is_quantized",
    "is_watermarked"). Columns are filtered with the ``*_test`` arguments by
    looking the column name up in ``MODELS``; columns that are not model names
    (the metadata columns included) are dropped in the process.

    Args:
        df: Table indexed by model name with one column per model plus the
            metadata columns listed above.
        model_family_train: Families to keep among the rows.
        model_family_test: Families to keep among the columns.
        model_size_train: Inclusive ``(min, max)`` row size range, in billions.
        model_size_test: Inclusive ``(min, max)`` column size range, in billions.
        is_chat_train, is_chat_test: ``True``/``False`` to keep only
            (non-)chat models, or the string "Both" to skip the filter.
        is_quantized_train, is_quantized_test: Same convention, for
            quantized variants.
        is_watermarked_train, is_watermarked_test: Same convention, for
            watermarked variants.
        sort_by_size: Order by (size, family) if true, else by (family, size).
        split_chat_models: Move chat models to the end of both axes.
        split_quantized_models: Move quantized variants to the end of both axes.
        split_watermarked_models: Move watermarked variants to the end of both axes.
        filter_empty_col_row: Drop rows, then columns, that are entirely NaN.
        is_debug: Write the intermediate table to the page after each stage
            that ran.

    Returns:
        The filtered, numeric-only, re-ordered table.
    """
    _debug_show(is_debug, "No filters", df)

    # ---- Train side: rows, filtered through the metadata columns. ----
    train_low, train_high = model_size_train[0] * 1e9, model_size_train[1] * 1e9
    df = df.loc[(df["model_size"] >= train_low) & (df["model_size"] <= train_high)]
    _debug_show(is_debug, "Filter model size train", df)

    df = df.loc[df["model_family"].isin(model_family_train)]
    _debug_show(is_debug, "Filter model family train", df)

    if is_chat_train != "Both":
        df = df.loc[df["is_chat"] == is_chat_train]
        _debug_show(is_debug, "Filter is chat train", df)

    if is_quantized_train != "Both":
        df = df.loc[df["is_quantized"] == is_quantized_train]
        _debug_show(is_debug, "Filter is quantized train", df)

    if is_watermarked_train != "Both":
        df = df.loc[df["is_watermarked"] == is_watermarked_train]
        _debug_show(is_debug, "Filter is watermark train", df)

    # ---- Test side: columns, filtered through the MODELS metadata. ----
    _debug_show(is_debug, "Train filters applied, before test filters", df)

    test_low, test_high = model_size_test[0] * 1e9, model_size_test[1] * 1e9

    def size_in_test_range(column):
        # Also drops every column that is not a known model name.
        base = _base_model_name(column)
        return base in MODELS and test_low <= MODELS[base]["model_size"] <= test_high

    df = _keep_columns(df, size_in_test_range)
    _debug_show(is_debug, "Filter model size test", df)

    df = _keep_columns(
        df,
        lambda column: MODELS[_base_model_name(column)]["model_family"]
        in model_family_test,
    )
    _debug_show(is_debug, "Filter model family test", df)

    if is_chat_test != "Both":
        df = _keep_columns(
            df,
            lambda column: MODELS[_base_model_name(column)]["is_chat"] == is_chat_test,
        )
        _debug_show(is_debug, "Filter is chat test", df)

    if is_quantized_test != "Both":
        want_quantized = bool(is_quantized_test)
        df = _keep_columns(
            df, lambda column: ("quantized" in column) == want_quantized
        )
        _debug_show(is_debug, "Filter is quantized test", df)

    if is_watermarked_test != "Both":
        want_watermarked = bool(is_watermarked_test)
        df = _keep_columns(
            df, lambda column: ("watermark" in column) == want_watermarked
        )
        _debug_show(is_debug, "Filter is watermark test", df)

    df = df.select_dtypes(include="number")
    _debug_show(is_debug, "Select dtypes to be only numbers", df)

    # ---- Ordering: the same reference order is applied to both axes. ----
    ordering = (
        MODEL_NAMES_SORTED_BY_SIZE
        if sort_by_size
        else MODEL_NAMES_SORTED_BY_NAME_AND_SIZE
    )
    df = df[[name for name in ordering if name in df.columns]]
    _debug_show(is_debug, "Sort columns", df)

    df = df.reindex([name for name in ordering if name in df.index])
    _debug_show(is_debug, "Sort rows", df)

    if split_chat_models:
        df = _move_to_end(df, CHAT_MODELS)
        _debug_show(is_debug, "Split chat models", df)

    if split_quantized_models:
        df = _move_to_end(df, [name for name in Q_W_MODELS if "quantized" in name])
        _debug_show(is_debug, "Split quantized models", df)

    if split_watermarked_models:
        df = _move_to_end(df, [name for name in Q_W_MODELS if "watermarked" in name])
        _debug_show(is_debug, "Split watermarked models", df)

    if filter_empty_col_row:
        df = df.dropna(axis=0, how="all")
        df = df.dropna(axis=1, how="all")
    return df


def _base_model_name(name):
    """Return *name* with "_quantized" and "_watermarked" removed (its MODELS key)."""
    return name.replace("_quantized", "").replace("_watermarked", "")


def _debug_show(enabled, label, frame):
    """Write *label* and *frame* to the Streamlit page when *enabled* is true."""
    if enabled:
        st.write(label)
        st.write(frame)


def _keep_columns(frame, predicate):
    """Return *frame* restricted to the columns accepted by *predicate*, sorted by name."""
    return frame[sorted({column for column in frame.columns if predicate(column)})]


def _move_to_end(frame, candidates):
    """Move the *candidates* present in *frame* to the end of its columns and rows.

    The moved names are ordered by model size; ties keep their order in
    *candidates*. Everything else keeps its current relative order.
    """

    def size_of(name):
        return MODELS[_base_model_name(name)]["model_size"]

    trailing_columns = sorted(
        (name for name in candidates if name in frame.columns), key=size_of
    )
    moved = set(trailing_columns)
    frame = frame[
        [name for name in frame.columns if name not in moved] + trailing_columns
    ]

    trailing_rows = sorted(
        (name for name in candidates if name in frame.index), key=size_of
    )
    moved = set(trailing_rows)
    return frame.reindex(
        [name for name in frame.index if name not in moved] + trailing_rows
    )
|
|
|
|
|
# Main results, plus the separate results for quantized / watermarked variants.
df, df_std = get_data("./deberta_results.csv")
df_q_w, df_std_q_w = get_data("./results_qantized_watermarked.csv")


def _merge_variant_columns(base, variants):
    """Outer-join onto *base* the *variants* columns naming quantized/watermarked.

    Columns present on both sides get pandas' default "_x" / "_y" merge
    suffixes; the "_y" suffix (the *variants* side) is stripped so that those
    values keep the plain column name.
    """
    is_variant = variants.columns.str.contains(
        "quantized|watermarked", case=False, regex=True
    )
    merged = base.merge(
        variants[variants.columns[is_variant]],
        how="outer",
        left_index=True,
        right_index=True,
    )
    # Anchored at the end of the name: only the merge suffix is removed, not a
    # "_y" that happens to occur inside a column name.
    merged.columns = merged.columns.str.replace(r"_y$", "", regex=True)
    return merged


df = _merge_variant_columns(df, df_q_w)
df_std = _merge_variant_columns(df_std, df_std_q_w)

df = df.drop(columns=["is_quantized_x", "is_watermarked_x"])

# Overwrite overlapping cells with the non-NaN values from the variant tables.
df.update(df_q_w)
df_std.update(df_std_q_w)

# Rows introduced by the outer merge have no flags yet; treat missing as False.
# Assigning the result back (instead of fillna(..., inplace=True) on a column
# selection) is the form that keeps working under pandas Copy-on-Write.
for _frame in (df, df_std):
    for _flag_column in ("is_chat", "is_watermarked", "is_quantized"):
        _frame[_flag_column] = _frame[_flag_column].fillna(False)
|
|
|
# Out-of-distribution results, one record per run.
# Explicit encoding so the file is read the same way regardless of the
# platform's default locale encoding.
with open("./ood_results.json", "r", encoding="utf-8") as f:
    ood_results = json.load(f)

ood_results = (
    pd.DataFrame(ood_results)
    .set_index("model_name")
    .drop(columns=["exp_name", "accuracy", "f1", "precision", "recall"])
)
# Positional rename: exactly two columns must remain at this point, otherwise
# pandas raises a length-mismatch error.
ood_results.columns = ["seed", "Adversarial"]

# Per-model mean and std over runs. Both columns are aggregated, so the
# results also carry an aggregated "seed" column next to "Adversarial".
ood_results_avg = ood_results.groupby(["model_name"]).mean()
ood_results_std = ood_results.groupby(["model_name"]).std()
|
|
|
# Markdown shown at the top of the page: viewer title, paper title, authors,
# affiliation, and a link to the paper.
_PAGE_HEADER = """### Results Viewer 👇

## From Text to Source: Results in Detecting Large Language Model-Generated Content

### Wissam Antoun, Benoît Sagot, Djamé Seddah
##### ALMAnaCH, Inria

##### Paper: [https://arxiv.org/abs/2309.13322](https://arxiv.org/abs/2309.13322)
"""
st.write(_PAGE_HEADER)
|
|
|
|
|
# ---- Sidebar controls. Widgets appear in the sidebar in creation order. ----
show_diff = st.sidebar.checkbox("Show Diff", value=False)
sort_by_size = st.sidebar.checkbox("Sort by size", value=True)
split_chat_models = st.sidebar.checkbox("Split chat models", value=True)
split_quantized_models = st.sidebar.checkbox("Split quantized models", value=True)
split_watermarked_models = st.sidebar.checkbox("Split watermarked models", value=True)
add_mean = st.sidebar.checkbox("Add mean", value=False)
show_std = st.sidebar.checkbox("Show std", value=False)
filter_empty_col_row = st.sidebar.checkbox("Filter empty col/row", value=True)

model_size_train = st.sidebar.slider(
    "Train Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
)
model_size_test = st.sidebar.slider(
    "Test Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
)

# Tri-state filters: keep only True, only False, or skip the filter ("Both").
_TRI_STATE_OPTIONS = [True, False, "Both"]
is_chat_train = st.sidebar.selectbox("(Train) Is Chat?", _TRI_STATE_OPTIONS, index=2)
is_chat_test = st.sidebar.selectbox("(Test) Is Chat?", _TRI_STATE_OPTIONS, index=2)
is_quantized_train = st.sidebar.selectbox(
    "(Train) Is Quantized?", _TRI_STATE_OPTIONS, index=1
)
is_quantized_test = st.sidebar.selectbox(
    "(Test) Is Quantized?", _TRI_STATE_OPTIONS, index=1
)
is_watermarked_train = st.sidebar.selectbox(
    "(Train) Is Watermark?", _TRI_STATE_OPTIONS, index=1
)
is_watermarked_test = st.sidebar.selectbox(
    "(Test) Is Watermark?", _TRI_STATE_OPTIONS, index=1
)

# MODEL_FAMILES is a set, whose iteration order can differ from one Python
# process to the next; sort it so the options are always listed the same way.
_family_options = sorted(MODEL_FAMILES, key=str)
model_family_train = st.sidebar.multiselect(
    "Model Family Train",
    _family_options,
    default=_family_options,
)
model_family_test = st.sidebar.multiselect(
    "Model Family Test",
    _family_options + ["Adversarial"],
    default=_family_options,
)

show_values = st.sidebar.checkbox("Show Values", value=False)

# "Adversarial" is not a model family: it requests an extra column. Split it
# off into a new list rather than mutating the list returned by the widget.
add_adversarial = "Adversarial" in model_family_test
model_family_test = [
    family for family in model_family_test if family != "Adversarial"
]

# The sort option is only offered when the adversarial column is shown.
sort_by_adversarial = False
if add_adversarial:
    sort_by_adversarial = st.sidebar.checkbox("Sort by adversarial", value=False)

if st.sidebar.checkbox("Use default color scale", value=False):
    color_scale = "Viridis_r"
else:
    color_scale = viridis_rgb

is_debug = st.sidebar.checkbox("Debug", value=False)
|
|
|
# Work on a copy of either the std or the mean table, depending on "Show std".
selected_df = (df_std if show_std else df).copy()

# Apply every sidebar choice; arguments are passed by name to make the
# mapping onto filter_df's parameters explicit.
filtered_df = filter_df(
    df=selected_df,
    model_family_train=model_family_train,
    model_family_test=model_family_test,
    model_size_train=model_size_train,
    model_size_test=model_size_test,
    is_chat_train=is_chat_train,
    is_chat_test=is_chat_test,
    is_quantized_train=is_quantized_train,
    is_quantized_test=is_quantized_test,
    is_watermarked_train=is_watermarked_train,
    is_watermarked_test=is_watermarked_test,
    sort_by_size=sort_by_size,
    split_chat_models=split_chat_models,
    split_quantized_models=split_quantized_models,
    split_watermarked_models=split_watermarked_models,
    filter_empty_col_row=filter_empty_col_row,
    is_debug=is_debug,
)
|
|
|
|
|
if show_diff:
    # Self score of each model: the cell where the row and the column carry the
    # same model name. Looked up by label, so this also works when the table is
    # not square or when rows and columns are ordered differently; models that
    # have no such cell get NaN.
    _shared_names = [
        name for name in filtered_df.columns if name in filtered_df.index
    ]
    _self_scores = pd.Series(
        {name: filtered_df.at[name, name] for name in _shared_names}, dtype=float
    )
    # Each column is expressed relative to the self score of the column's model.
    filtered_df = filtered_df.sub(
        _self_scores.reindex(filtered_df.columns), axis=1
    )

if add_adversarial:
    # Only the "Adversarial" score is a result; the averaged "seed" column of
    # the OOD table must not end up in the heatmap. reindex (rather than .loc)
    # yields NaN instead of raising for models without OOD results.
    _adversarial = ood_results_avg[["Adversarial"]].reindex(filtered_df.index)
    if show_diff:
        # Each row is expressed relative to the self score of the row's model.
        _adversarial = _adversarial.sub(
            _self_scores.reindex(filtered_df.index), axis=0
        )
    filtered_df = filtered_df.join(_adversarial)

# Checked before the mean row/column are added: those would otherwise make an
# empty selection look non-empty.
if filtered_df.empty:
    st.write("No results found")
    st.stop()

if add_mean:
    # NOTE: the names follow the original code: col_mean holds one value per
    # row (mean across columns), row_mean one value per column.
    col_mean = filtered_df.mean(axis=1)
    row_mean = filtered_df.mean(axis=0)
    # Positional diagonal, taken before the "mean" row/column are appended.
    diag = filtered_df.values.diagonal()
    filtered_df["mean"] = col_mean
    filtered_df.loc["mean"] = row_mean

# Display as rounded percentages.
filtered_df = filtered_df * 100
filtered_df = filtered_df.round(0)

if sort_by_adversarial:
    filtered_df = filtered_df.sort_values(by=["Adversarial"], ascending=False)
|
|
|
# Heatmap of the filtered table: one cell per (row model, column model).
_heatmap_columns = list(filtered_df.columns)
_heatmap_rows = list(filtered_df.index)
fig = px.imshow(
    filtered_df.values,
    x=_heatmap_columns,
    y=_heatmap_rows,
    aspect="auto",
    text_auto=show_values,
    contrast_rescaling=None,
    color_continuous_scale=color_scale,
)

# Cell text and general font sizes; column labels on top, row labels on the left.
fig.update_traces(textfont={"size": 9})
fig.update_layout(xaxis_side="top", yaxis_side="left", font_size=10)

# Linear tick mode on both axes; column labels are tilted by 45 degrees.
fig.update_xaxes(tickangle=45, tickmode="linear")
fig.update_yaxes(tickmode="linear")

st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if add_mean and not show_diff:

    def _known_model_points(names, values):
        """Pair *names* with *values* and keep the entries that are known models.

        Variant names are resolved through the model they derive from; names
        without a MODELS entry (for example the "Adversarial" column) are
        skipped. Pairing stops at the shorter of the two inputs.

        Returns:
            Three parallel lists: model sizes, model families, values.
        """
        sizes, families, kept_values = [], [], []
        for name, value in zip(names, values):
            info = MODELS.get(
                name.replace("_quantized", "").replace("_watermarked", "")
            )
            if info is None:
                continue
            sizes.append(info["model_size"])
            families.append(info["model_family"])
            kept_values.append(value)
        return sizes, families, kept_values

    if any(
        name in filtered_df.columns or name in filtered_df.index
        for name in CHAT_MODELS
    ):
        st.warning(
            "Chat models are in the filtered df columns or index. "
            "This will cause the mean graph to be skewed."
        )

    _test_names = [name for name in filtered_df.columns if name != "mean"]
    _train_names = [name for name in filtered_df.index if name != "mean"]

    # Means are looked up by label so they stay attached to the right model
    # even if the table was re-sorted after they were computed. `diag` is
    # positional: entry i belongs to column i.
    _avg_sizes, _avg_families, _avg_scores = _known_model_points(
        _test_names, [row_mean[name] for name in _test_names]
    )
    _self_sizes, _self_families, _self_scores = _known_model_points(
        _test_names, diag
    )
    _train_sizes, _train_families, _train_scores = _known_model_points(
        _train_names, [col_mean[name] for name in _train_names]
    )

    if _avg_sizes and _self_sizes:
        fig3 = px.scatter(
            y=_avg_scores,
            x=_avg_sizes,
            color=_avg_families,
            color_discrete_sequence=px.colors.qualitative.Plotly,
            title="",
            labels={
                "x": "Target Model Size",
                "y": "Average ROC AUC",
                "color": "Model Family",
            },
            log_x=True,
            trendline="ols",
        )
        fig4 = px.scatter(
            y=_self_scores,
            x=_self_sizes,
            color=_self_families,
            color_discrete_sequence=px.colors.qualitative.Plotly,
            title="",
            labels={
                "x": "Target Model Size",
                "y": "Self ROC AUC",
                "color": "Model Family",
            },
            log_x=True,
            trendline="ols",
        )

        # Both scatter plots side by side; only the left one keeps its legend
        # entries so each family is listed once.
        fig_subplot = make_subplots(
            rows=1,
            cols=2,
            shared_yaxes=False,
            subplot_titles=("Self Detection ROC AUC", "Average Target ROC AUC"),
        )
        for _col, (_figure, _keep_legend) in enumerate(
            ((fig4, True), (fig3, False)), start=1
        ):
            for _trace in _figure.data:
                if not _keep_legend:
                    _trace.showlegend = False
                # add_trace is the supported replacement for append_trace.
                fig_subplot.add_trace(_trace, row=1, col=_col)

        fig_subplot.update_xaxes(type="log")
        fig_subplot.update_yaxes(range=[0.90, 1])
        fig_subplot.update_layout(
            height=500,
            width=1200,
            legend=dict(orientation="h", yanchor="bottom", y=-0.2, x=0.09),
        )
        st.plotly_chart(fig_subplot, use_container_width=True)

    if _train_sizes:
        fig2 = px.scatter(
            y=_train_scores,
            x=_train_sizes,
            color=_train_families,
            color_discrete_sequence=px.colors.qualitative.Plotly,
            title="Mean vs Train Model Size",
            log_x=True,
            trendline="ols",
        )
        fig2.update_layout(
            height=600,
            width=900,
        )
        st.plotly_chart(fig2, use_container_width=False)
|
|