# IDEFICS_Data_Measurement_Tool / test_text_len.py
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import gradio as gr
from os.path import isdir
from data_measurements.dataset_statistics import DatasetStatisticsCacheClass as dmt_cls
import utils
from utils import dataset_utils
from utils import gradio_utils as gr_utils
import widgets
from app import load_or_prepare_widgets
logs = utils.prepare_logging(__file__)
# Utility for sidebar description and selection of the dataset
DATASET_NAME_TO_DICT = dataset_utils.get_dataset_info_dicts()
def get_load_prepare_list(dstats):
"""
# Get load_or_prepare functions for the measurements we will display
"""
# Measurement calculation:
# Add any additional modules and their load-prepare function here.
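    # For example, a hypothetical duplicates measurement would be registered here as
    # ("duplicates", dstats.load_or_prepare_duplicates), assuming dstats exposes such a method.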
load_prepare_list = [
("text_lengths", dstats.load_or_prepare_text_lengths),
]
return load_prepare_list
def get_ui_widgets():
"""Get the widgets that will be displayed in the UI."""
    return [widgets.TextLengths()]
def get_widgets():
"""
# A measurement widget requires 2 things:
# - A load or prepare function
# - A display function
# We define these in two separate functions get_load_prepare_list and get_ui_widgets;
# any widget can be added by modifying both functions and the rest of the app logic will work.
# get_load_prepare_list is a function since it requires a DatasetStatisticsCacheClass which will
# not be created until dataset and config values are selected in the ui
"""
return get_load_prepare_list, get_ui_widgets()
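# A minimal sketch of what a new measurement widget might look like, assuming the
# conventions this script relies on (render / update / add_events / output_components,
# as used in create_demo below). All names in this sketch are hypothetical and are
# not part of the widgets module:
#
#     class MyMeasurementWidget:
#         def __init__(self):
#             self.results = None
#
#         def render(self):
#             # Create the (initially empty) Gradio components for this widget.
#             self.results = gr.Markdown()
#
#         def update(self, dstats):
#             # Return a Gradio output dict mapping each component to its new value.
#             return {self.results: "summary of the measurement for %s" % dstats.dset_name}
#
#         @property
#         def output_components(self):
#             return [self.results]
#
#         def add_events(self, state):
#             # No extra event handlers needed for a static display.
#             pass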
def get_title(dstats):
title_str = f"### Showing: {dstats.dset_name} - {dstats.dset_config} - {dstats.split_name} - {'-'.join(dstats.text_field)}"
logs.info("showing header")
return title_str
def display_initial_UI():
"""Displays the header in the UI"""
# Extract the selected arguments
dataset_args = gr_utils.sidebar_selection(DATASET_NAME_TO_DICT)
return dataset_args
def show_column(dstats, display_list, show_perplexities, column_id=""):
"""
Function for displaying the elements in the streamlit app.
Args:
dstats (class): The dataset_statistics.py DatasetStatisticsCacheClass
display_list (list): List of tuples for (widget_name, widget_display_function)
show_perplexities (Bool): Whether perplexities should be loaded and displayed for this dataset
column_id (str): Which column of the dataset the analysis is done on [DEPRECATED for v1]
"""
    # Display the dataset header, then each measurement widget in turn.
    gr_utils.expander_header(dstats, DATASET_NAME_TO_DICT)
    for widget_type, widget_fn in display_list:
        logs.info("showing %s." % widget_type)
        try:
            widget_fn(dstats, column_id)
        except Exception as e:
            logs.warning("There was an issue with %s:" % widget_type)
            logs.exception(e)
# TODO: Fix how this is a weird outlier.
if show_perplexities:
gr_utils.expander_text_perplexities(dstats, column_id)
logs.info("Have finished displaying the widgets.")
def create_demo(live: bool, pull_cache_from_hub: bool):
with gr.Blocks() as demo:
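        # Holds the DatasetStatisticsCacheClass instance across UI events (set in update_ui below).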
state = gr.State()
with gr.Row():
with gr.Column(scale=1):
dataset_args = display_initial_UI()
get_load_prepare_list_fn, widget_list = get_widgets()
                # TODO: Make this less of a weird outlier.
                # This checkbox doesn't do anything right now.
show_perplexities = gr.Checkbox(label="Show text perplexities")
with gr.Column(scale=4):
gr.Markdown("# Data Measurements Tool")
title = gr.Markdown()
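                # Render each measurement widget's (initially empty) display components.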
for widget in widget_list:
widget.render()
        # When the UI updates, parse the selected text field and recompute the displayed measurements.
def update_ui(dataset: str, config: str, split: str, feature: str):
feature = ast.literal_eval(feature)
label_field, label_names = gr_utils.get_label_names(dataset, config, DATASET_NAME_TO_DICT)
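            # Build (or load from cache) the statistics object for the selected dataset/config/split/feature.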
dstats = dmt_cls(dset_name=dataset, dset_config=config, split_name=split, text_field=feature,
label_field=label_field, label_names=label_names, use_cache=True)
load_prepare_list = get_load_prepare_list_fn(dstats)
dstats = load_or_prepare_widgets(dstats, load_prepare_list, show_perplexities=False,
live=live, pull_cache_from_hub=pull_cache_from_hub)
output = {title: get_title(dstats), state: dstats}
for widget in widget_list:
output.update(widget.update(dstats))
return output
def update_dataset(dataset: str):
new_values = gr_utils.update_dataset(dataset, DATASET_NAME_TO_DICT)
config = new_values[0][1]
feature = new_values[1][1]
split = new_values[2][1]
new_dropdown = {
dataset_args["text_field"]: gr.Dropdown.update(choices=new_values[1][0], value=feature),
dataset_args["split_name"]: gr.Dropdown.update(choices=new_values[2][0], value=split),
}
return new_dropdown
def update_config(dataset: str, config: str):
new_values = gr_utils.update_config(dataset, config, DATASET_NAME_TO_DICT)
feature = new_values[0][1]
split = new_values[1][1]
new_dropdown = {
dataset_args["text_field"]: gr.Dropdown.update(choices=new_values[0][0], value=feature),
dataset_args["split_name"]: gr.Dropdown.update(choices=new_values[1][0], value=split)
}
return new_dropdown
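        # Flatten each widget's output components so they can be registered as outputs for the Gradio events below.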
measurements = [comp for output in widget_list for comp in output.output_components]
demo.load(update_ui,
inputs=[dataset_args["dset_name"], dataset_args["dset_config"], dataset_args["split_name"], dataset_args["text_field"]],
outputs=[title, state] + measurements)
        logs.debug("Text field component: %s" % dataset_args["text_field"])
for widget in widget_list:
widget.add_events(state)
dataset_args["dset_name"].change(update_dataset,
inputs=[dataset_args["dset_name"]],
outputs=[dataset_args["dset_config"],
dataset_args["split_name"], dataset_args["text_field"],
title, state] + measurements)
dataset_args["dset_config"].change(update_config,
inputs=[dataset_args["dset_name"], dataset_args["dset_config"]],
outputs=[dataset_args["split_name"], dataset_args["text_field"],
title, state] + measurements)
dataset_args["calculate_btn"].click(update_ui,
inputs=[dataset_args["dset_name"], dataset_args["dset_config"],
dataset_args["split_name"], dataset_args["text_field"]],
outputs=[title, state] + measurements)
return demo
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--live", default=False, required=False, action="store_true", help="Flag to specify that this is not running live.")
parser.add_argument(
"--pull_cache_from_hub", default=False, required=False, action="store_true", help="Flag to specify whether to look in the hub for measurements caches. If you are using this option, you must have HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache> and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo.")
arguments = parser.parse_args()
live = arguments.live
pull_cache_from_hub = arguments.pull_cache_from_hub
    # Create and initialize the demo; the selection sidebar is built inside create_demo.
    demo = create_demo(live, pull_cache_from_hub)
demo.launch()
if __name__ == "__main__":
main()