# linjunyao: updated citation textbox lines (commit a5bf394)
# File size at this commit: 12.9 kB
import abc
import gradio as gr
from loguru import logger
import pandas as pd
from collections import defaultdict
from judgerbench.preprocess.gen_table import (
format_timestamp,
generate_table,
build_l1_df,
# build_l2_df,
)
from judgerbench.meta_data import (
LEADERBORAD_INTRODUCTION,
LEADERBOARD_MD,
LEADERBOARD_FILE_MAPPING,
MAIN_FIELDS,
DEFAULT_BENCH,
STYLE_CLASS_MAPPING,
CITATION_BUTTON_TEXT,
CITATION_BUTTON_LABEL,
)
def refresh_dataframe(required_fields):
    """Return a fresh, read-only DataFrame component for the main table.

    The table is always regenerated from ``MAIN_FIELDS``; ``required_fields``
    is accepted (the commented-out checkbox wiring passes it as an input) but
    is not consulted here.
    """
    return gr.DataFrame(
        value=generate_table(MAIN_FIELDS),
        type='pandas',
        interactive=False,
        visible=True,
    )
with gr.Blocks() as demo:
    # Page layout: an introduction banner, one tab per entry in
    # LEADERBOARD_FILE_MAPPING (each showing a ranked, styled table), and a
    # collapsible citation box at the bottom.
    #
    # The header values below are hard-coded; the commented-out lines show how
    # they were previously derived from loaded results.
    # struct = load_results()
    # timestamp = struct['time']
    # EVAL_TIME = format_timestamp(timestamp)
    EVAL_TIME = '20241015'
    # results = struct['results']
    # N_MODEL = len(results)
    # N_DATA = len(results['LLaVA-v1.5-7B']) - 1
    N_MODEL = 10
    N_DATA = 100
    # DATASETS = list(results['LLaVA-v1.5-7B'])
    # DATASETS.remove('META')
    # print(DATASETS)
    gr.Markdown(LEADERBORAD_INTRODUCTION.format(
        # N_MODEL,
        # N_DATA,
        EVAL_TIME
    ))
    # structs = [abc.abstractproperty() for _ in range(N_DATA)]

    with gr.Tabs(elem_classes='tab-buttons') as tabs:
        for cur_id, (filename, filepath) in enumerate(LEADERBOARD_FILE_MAPPING.items()):
            tab_name = filename
            # if filename == "overall":
            #     tab_name = 'OVERALL'
            with gr.Tab(tab_name.upper(), elem_id=f'tab_{cur_id}', id=cur_id):
                # gr.Markdown(LEADERBOARD_MD['MAIN'])
                # _, check_box = build_l1_df(MAIN_FIELDS)
                table = generate_table(filename=filename)

                # Column name -> gradio datatype. Anything not listed is
                # treated as numeric.
                # type_map = check_box['type_map']
                type_map = defaultdict(lambda: 'number')
                type_map['Model'] = 'str'
                # The table's model-name column is called "Models" (see
                # starting_columns below); without this entry it would fall
                # through to the 'number' default.
                type_map['Models'] = 'str'
                type_map['Class'] = 'str'
                type_map['Rank'] = 'number'

                # required_fields = gr.State(
                #     check_box['essential']
                #     # + ["Average"]
                # )
                # checkbox_group = gr.CheckboxGroup(
                #     choices=[item for item in check_box['all'] if item not in required_fields.value],
                #     value=[item for item in check_box['default'] if item not in required_fields.value],
                #     label='Evaluation Metrics',
                #     interactive=True,
                # )
                # headers = (
                #     ['Rank'] +
                #     required_fields.value +
                #     [item for item in check_box['all'] if item not in required_fields.value]
                #     # checkbox_group.value
                # )

                # Rank is simply the 1-based row position of the generated table.
                table['Rank'] = list(range(1, len(table) + 1))

                # Rearrange columns so the identifying ones come first.
                if "Class" in table.columns:
                    starting_columns = ["Rank", "Models", "Class"]
                else:
                    starting_columns = ["Rank", "Models"]
                table = table[
                    starting_columns
                    + [col for col in table.columns if col not in starting_columns]
                ]
                headers = (
                    # ['Rank'] +
                    list(table.columns)
                )

                if "Class" in table.columns:
                    def cell_styler(v):
                        """Return a CSS frame that colours only the 'Class' column.

                        ``v`` is the whole table (Styler.apply is called with
                        ``axis=None``). The result has the same index/columns,
                        with an empty style everywhere except 'Class', whose
                        background colour is looked up in STYLE_CLASS_MAPPING.
                        """
                        # Build a fresh string frame instead of blanking a copy
                        # of the data, so no '' is written into numeric columns.
                        styles = pd.DataFrame('', index=v.index, columns=v.columns)
                        styles['Class'] = v['Class'].map(
                            lambda x: f"background-color: {STYLE_CLASS_MAPPING[x]}"
                        )
                        logger.info(styles['Class'])
                        return styles

                    table_styler = (
                        table.style.apply(cell_styler, axis=None)
                        .format(precision=3)
                    )
                else:
                    # Was `prevision=3`, which Styler.format rejects.
                    table_styler = table.style.format(precision=3)

                # with gr.Row():
                #     model_size = gr.CheckboxGroup(
                #         choices=MODEL_SIZE,
                #         value=MODEL_SIZE,
                #         label='Model Size',
                #         interactive=True
                #     )
                #     model_type = gr.CheckboxGroup(
                #         choices=MODEL_TYPE,
                #         value=MODEL_TYPE,
                #         label='Model Type',
                #         interactive=True
                #     )
                data_component = gr.DataFrame(
                    value=table_styler,
                    type='pandas',
                    datatype=[type_map[x] for x in headers],
                    interactive=False,
                    visible=True
                )

                def filter_df(
                    required_fields,
                    fields,
                    # model_size,
                    # model_type
                ):
                    """Build a DataFrame component restricted to the chosen columns.

                    The displayed columns are 'Rank', then ``required_fields``,
                    then ``fields`` (both lists of column names). Not currently
                    connected to any event: the checkbox wiring below is
                    commented out.
                    """
                    # filter_list = ['Avg Score', 'Avg Rank', 'OpenSource', 'Verified']
                    headers = ['Rank'] + required_fields + fields
                    # new_fields = [field for field in fields if field not in filter_list]
                    df = generate_table(fields)
                    logger.info(f"{df.columns=}")
                    # df['flag'] = [model_size_flag(x, model_size) for x in df['Param (B)']]
                    # df = df[df['flag']]
                    # df.pop('flag')
                    # if len(df):
                    #     df['flag'] = [model_type_flag(df.iloc[i], model_type) for i in range(len(df))]
                    #     df = df[df['flag']]
                    #     df.pop('flag')
                    df['Rank'] = list(range(1, len(df) + 1))
                    comp = gr.DataFrame(
                        value=df[headers],
                        type='pandas',
                        datatype=[type_map[x] for x in headers],
                        interactive=False,
                        visible=True
                    )
                    return comp

                # for cbox in [
                #     # checkbox_group,
                #     # model_size,
                #     # model_type
                # ]:
                #     cbox.change(
                #         fn=refresh_dataframe,
                #         inputs=[required_fields],
                #         outputs=data_component
                #     ).then(
                #         fn=filter_df,
                #         inputs=[
                #             required_fields,
                #             checkbox_group,
                #             # model_size,
                #             # model_type
                #         ],
                #         outputs=data_component
                #     )

        # with gr.Tab('🔍 About', elem_id='about', id=1):
        #     gr.Markdown(urlopen(VLMEVALKIT_README).read().decode())
        # for i, dataset in enumerate(DATASETS):
        #     with gr.Tab(f'📊 {dataset} Leaderboard', elem_id=dataset, id=i + 2):
        #         if dataset in LEADERBOARD_MD:
        #             gr.Markdown(LEADERBOARD_MD[dataset])
        #         s = structs[i]
        #         s.table, s.check_box = build_l2_df(results, dataset)
        #         s.type_map = s.check_box['type_map']
        #         s.type_map['Rank'] = 'number'
        #         s.checkbox_group = gr.CheckboxGroup(
        #             choices=s.check_box['all'],
        #             value=s.check_box['required'],
        #             label=f'{dataset} CheckBoxes',
        #             interactive=True,
        #         )
        #         s.headers = ['Rank'] + s.check_box['essential'] + s.checkbox_group.value
        #         s.table['Rank'] = list(range(1, len(s.table) + 1))
        #         with gr.Row():
        #             s.model_size = gr.CheckboxGroup(
        #                 choices=MODEL_SIZE,
        #                 value=MODEL_SIZE,
        #                 label='Model Size',
        #                 interactive=True
        #             )
        #             s.model_type = gr.CheckboxGroup(
        #                 choices=MODEL_TYPE,
        #                 value=MODEL_TYPE,
        #                 label='Model Type',
        #                 interactive=True
        #             )
        #         s.data_component = gr.components.DataFrame(
        #             value=s.table[s.headers],
        #             type='pandas',
        #             datatype=[s.type_map[x] for x in s.headers],
        #             interactive=False,
        #             visible=True)
        #         s.dataset = gr.Textbox(value=dataset, label=dataset, visible=False)
        #         def filter_df_l2(dataset_name, fields, model_size, model_type):
        #             s = structs[DATASETS.index(dataset_name)]
        #             headers = ['Rank'] + s.check_box['essential'] + fields
        #             df = cp.deepcopy(s.table)
        #             df['flag'] = [model_size_flag(x, model_size) for x in df['Param (B)']]
        #             df = df[df['flag']]
        #             df.pop('flag')
        #             if len(df):
        #                 df['flag'] = [model_type_flag(df.iloc[i], model_type) for i in range(len(df))]
        #                 df = df[df['flag']]
        #                 df.pop('flag')
        #             df['Rank'] = list(range(1, len(df) + 1))
        #             comp = gr.components.DataFrame(
        #                 value=df[headers],
        #                 type='pandas',
        #                 datatype=[s.type_map[x] for x in headers],
        #                 interactive=False,
        #                 visible=True)
        #             return comp
        #         for cbox in [s.checkbox_group, s.model_size, s.model_type]:
        #             cbox.change(
        #                 fn=filter_df_l2,
        #                 inputs=[s.dataset, s.checkbox_group, s.model_size, s.model_type],
        #                 outputs=s.data_component)

    with gr.Row():
        with gr.Accordion('Citation', open=False):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                elem_id='citation-button',
                lines=10,
            )
if __name__ == '__main__':
    # Script entry point: parse command-line options, then serve `demo`
    # through a gradio queue.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    # Default is an int to match type=int (it was the string "7860").
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to generate a public, shareable link",
    )
    parser.add_argument(
        "--concurrency-count",
        type=int,
        default=10,
        help="The concurrency count of the gradio queue",
    )
    parser.add_argument(
        "--max-threads",
        type=int,
        default=200,
        help="The maximum number of threads available to process non-async functions.",
    )
    # parser.add_argument(
    #     "--gradio-auth-path",
    #     type=str,
    #     help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
    #     default=None,
    # )
    parser.add_argument(
        "--gradio-root-path",
        type=str,
        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
    )
    # NOTE(review): --ga-id and --use-remote-storage are parsed but never read
    # in this block; they are kept so existing launch commands still work.
    parser.add_argument(
        "--ga-id",
        type=str,
        help="the Google Analytics ID",
        default=None,
    )
    parser.add_argument(
        "--use-remote-storage",
        action="store_true",
        default=False,
        help="Uploads image files to google cloud storage if set to true",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Set authorization credentials
    # auth = None
    # if args.gradio_auth_path is not None:
    #     auth = parse_gradio_auth_creds(args.gradio_auth_path)

    demo.queue(
        default_concurrency_limit=args.concurrency_count,
        status_update_rate=10,
        api_open=False,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=args.max_threads,
        # auth=auth,
        root_path=args.gradio_root_path,
        # debug=True,
        show_error=True,
        # NOTE(review): this allows gradio to serve files from two directory
        # levels above the working directory - confirm that breadth is intended.
        allowed_paths=["../.."]
    )