import multiprocessing import textwrap import time import traceback from typing import TypedDict import numpy as np import pandas as pd from rdflib import term import streamlit as st from components.safe_button import button_with_confirmation from core.constants import NAMES_INFO from core.data_types import MLC_DATA_TYPES from core.data_types import mlc_to_str_data_type from core.data_types import STR_DATA_TYPES from core.data_types import str_to_mlc_data_type from core.query_params import expand_record_set from core.query_params import is_record_set_expanded from core.state import Field from core.state import Metadata from core.state import RecordSet from core.state import SelectedRecordSet from events.record_sets import handle_record_set_change from events.record_sets import RecordSetEvent import mlcroissant as mlc from utils import needed_field from views.source import FieldEvent from views.source import handle_field_change from views.source import render_references from views.source import render_source _NUM_RECORDS = 3 _TIMEOUT_SECONDS = 1 _INFO = """RecordSets describe sets of structured records obtained from resources or other RecordSets. You can think of RecordSets as tables with typed fields.""" class _Result(TypedDict): df: pd.DataFrame | None exception: Exception | None @st.cache_data( show_spinner="Generating the dataset...", hash_funcs={ "mlcroissant.Metadata": hash, "mlcroissant.Field": hash, "mlcroissant.FileObject": hash, "mlcroissant.FileSet": hash, "mlcroissant.RecordSet": hash, }, ) def _generate_data_with_timeout(record_set: RecordSet) -> _Result: """Generates the data and waits at most _TIMEOUT_SECONDS.""" with multiprocessing.Manager() as manager: result: _Result = manager.dict(df=None, exception=None) args = (record_set, result) process = multiprocessing.Process(target=_generate_data, args=args) process.start() if not process.is_alive(): return _Result(**result) time.sleep(_TIMEOUT_SECONDS) if process.is_alive(): process.kill() result["exception"] = TimeoutError( "The generation took too long and was killed. Please, use the CLI as" " described in" " https://github.com/mlcommons/croissant/tree/main/python/mlcroissant#verifyload-a-croissant-dataset." ) return _Result(**result) def _generate_data(record_set: RecordSet, result: _Result) -> pd.DataFrame | None: """Generates the first _NUM_RECORDS records.""" try: metadata: Metadata = st.session_state[Metadata] if metadata is None: raise ValueError( "The dataset is still incomplete. Please, go to the overview to see" " errors." ) croissant = metadata.to_canonical() if croissant: dataset = mlc.Dataset.from_metadata(croissant) records = iter(dataset.records(record_set=record_set.name)) df = [] for i, record in enumerate(iter(records)): if i >= _NUM_RECORDS: break # Decode bytes as str: for key, value in record.items(): if isinstance(value, bytes): try: record[key] = value.decode("utf-8") except: pass df.append(record) result["df"] = pd.DataFrame(df) except Exception: result["exception"] = traceback.format_exc() def _handle_close_fields(): st.session_state[SelectedRecordSet] = None def _handle_on_click_field( record_set_key: int, record_set: RecordSet, ): st.session_state[SelectedRecordSet] = SelectedRecordSet( record_set_key=record_set_key, record_set=record_set, ) def _data_editor_key(record_set_key: int, record_set: RecordSet) -> str: return f"{record_set_key}-{record_set.name}-dataframe" def _get_possible_sources(metadata: Metadata) -> list[str]: possible_sources: list[str] = [] for resource in metadata.distribution: possible_sources.append(resource.name) for record_set in metadata.record_sets: for field in record_set.fields: possible_sources.append(f"{record_set.name}/{field.name}") return possible_sources LeftOrRight = tuple[str, str] Join = tuple[LeftOrRight, LeftOrRight] def _find_left_or_right(source: mlc.Source) -> LeftOrRight: uid = source.uid if "/" in uid: parts = uid.split("/") return (parts[0], parts[1]) elif source.extract.column: return (uid, source.extract.column) elif source.extract.json_path: return (uid, source.extract.json_path) elif source.extract.file_property: return (uid, source.extract.file_property) else: return (uid, None) def _find_joins(fields: list[Field]) -> set[Join]: """Finds the existing joins in the fields.""" joins: set[Join] = set() for field in fields: if field.source and field.references: left = _find_left_or_right(field.source) right = _find_left_or_right(field.references) joins.add((left, right)) return joins def _handle_create_record_set(): metadata: Metadata = st.session_state[Metadata] metadata.add_record_set(RecordSet(name="new-record-set", description="")) def _handle_remove_record_set(record_set_key: int): del st.session_state[Metadata].record_sets[record_set_key] def _handle_fields_change(record_set_key: int, record_set: RecordSet): expand_record_set(record_set=record_set) data_editor_key = _data_editor_key(record_set_key, record_set) result = st.session_state[data_editor_key] # `result` has the following structure: # ``` # {'edited_rows': {1: {}}, 'added_rows': [], 'deleted_rows': []} # ``` fields = record_set.fields for field_key in result["edited_rows"]: field = fields[field_key] new_fields = result["edited_rows"][field_key] for new_field, new_value in new_fields.items(): if new_field == FieldDataFrame.NAME: field.name = new_value elif new_field == FieldDataFrame.DESCRIPTION: field.description = new_value elif new_field == FieldDataFrame.DATA_TYPE: field.data_types = [str_to_mlc_data_type(new_value)] for added_row in result["added_rows"]: data_type = str_to_mlc_data_type(added_row.get(FieldDataFrame.DATA_TYPE)) field = Field( name=added_row.get(FieldDataFrame.NAME), description=added_row.get(FieldDataFrame.DESCRIPTION), data_types=[data_type], source=mlc.Source(), references=mlc.Source(), ) st.session_state[Metadata].add_field(record_set_key, field) for field_key in result["deleted_rows"]: st.session_state[Metadata].remove_field(record_set_key, field_key) # Reset the in-line data if it exists. if record_set.data: record_set.data = [] class FieldDataFrame: """Names of the columns in the pd.DataFrame for `fields`.""" NAME = "Field name" DESCRIPTION = "Field description" DATA_TYPE = "Data type" SOURCE_UID = "Source" SOURCE_EXTRACT = "Source extract" SOURCE_TRANSFORM = "Source transform" REFERENCE_UID = "Reference" REFERENCE_EXTRACT = "Reference extract" def render_record_sets(): st.info(_INFO, icon="💡") col1, col2 = st.columns([1, 1]) with col1: with st.spinner("Generating the dataset..."): _render_left_panel() with col2: _render_right_panel() def _render_left_panel(): """Left panel: visualization of all RecordSets as expandable forms.""" record_sets = st.session_state[Metadata].record_sets record_set: RecordSet for record_set_key, record_set in enumerate(record_sets): title = f"**{record_set.name or '-'}** ({len(record_set.fields)} fields)" prefix = f"record-set-{record_set_key}" with st.expander(title, expanded=is_record_set_expanded(record_set)): col1, col2 = st.columns([1, 3]) key = f"{prefix}-name" col1.text_input( needed_field("Name"), placeholder="Name without special character.", key=key, help=f"The name of the RecordSet. {NAMES_INFO}", value=record_set.name, on_change=handle_record_set_change, args=(RecordSetEvent.NAME, record_set, key), ) key = f"{prefix}-description" col2.text_input( "Description", placeholder="Provide a description of the RecordSet.", key=key, value=record_set.description, on_change=handle_record_set_change, args=(RecordSetEvent.DESCRIPTION, record_set, key), ) key = f"{prefix}-is-enumeration" st.checkbox( "The RecordSet is an enumeration", key=key, help=( "Enumerations indicate that the RecordSet takes its values from a" " finite set. Similar to `ClassLabel` in" " [TFDS](https://www.tensorflow.org/datasets/api_docs/python/tfds/features/ClassLabel)" " or [Hugging" " Face](https://huggingface.co/docs/datasets/v2.15.0/en/package_reference/main_classes#datasets.ClassLabel)." ), value=record_set.is_enumeration, on_change=handle_record_set_change, args=(RecordSetEvent.IS_ENUMERATION, record_set, key), ) key = f"{prefix}-has-data" st.checkbox( "The RecordSet has in-line data", key=key, help=( "In-line data allows to embed data directly within the JSON-LD" " without referencing another data source." ), value=bool(record_set.data), on_change=handle_record_set_change, args=(RecordSetEvent.HAS_DATA, record_set, key), ) joins = _find_joins(record_set.fields) has_join = st.checkbox( "The RecordSet contains joins. To add a new join, add a field" " with a source in `RecordSet`/`FileSet`/`FileObject` and a reference" " to another `RecordSet`/`FileSet`/`FileObject`.", key=f"{prefix}-has-joins", value=bool(joins), disabled=True, ) if has_join: for left, right in joins: col1, col2, _, col4, col5 = st.columns([2, 2, 1, 2, 2]) col1.text_input( "Left join", disabled=True, value=left[0], key=f"{prefix}-left-join-{left[0]}-{left[1]}", ) col2.text_input( "Left key", disabled=True, value=left[1], key=f"{prefix}-left-key-{left[0]}-{left[1]}", ) col4.text_input( "Right join", disabled=True, value=right[0], key=f"{prefix}-right-join-{right[0]}-{right[1]}", ) col5.text_input( "Right key", disabled=True, value=right[1], key=f"{prefix}-right-key-{right[0]}-{right[1]}", ) names = [field.name for field in record_set.fields] descriptions = [field.description for field in record_set.fields] # TODO(https://github.com/mlcommons/croissant/issues/350): Allow to display # several data types, not only the first. data_types = [ mlc_to_str_data_type(field.data_types[0]) if field.data_types else None for field in record_set.fields ] fields = pd.DataFrame( { FieldDataFrame.NAME: names, FieldDataFrame.DESCRIPTION: descriptions, FieldDataFrame.DATA_TYPE: data_types, }, dtype=np.str_, ) data_editor_key = _data_editor_key(record_set_key, record_set) st.markdown( needed_field("Fields"), help=( "Add/delete fields by directly editing the table. **Warning**: the" " table contains information about the fields--not the data" " directly. If you wish to embed data, tick the `The RecordSet is" " an enumeration` box. To edit fields details, click the" " button `Edit fields details` below." ), ) st.data_editor( fields, use_container_width=True, num_rows="dynamic", key=data_editor_key, column_config={ FieldDataFrame.NAME: st.column_config.TextColumn( FieldDataFrame.NAME, help="Name of the field", required=True, ), FieldDataFrame.DESCRIPTION: st.column_config.TextColumn( FieldDataFrame.DESCRIPTION, help="Description of the field", required=False, ), FieldDataFrame.DATA_TYPE: st.column_config.SelectboxColumn( FieldDataFrame.DATA_TYPE, help="The Croissant type", options=STR_DATA_TYPES, required=True, ), }, on_change=_handle_fields_change, args=(record_set_key, record_set), ) result: _Result = _generate_data_with_timeout(record_set) df, exception = result.get("df"), result.get("exception") if exception is None and df is not None and not df.empty: st.markdown("Preview the data:") st.dataframe(df, use_container_width=True) # The generation is not triggered if record_set has in-line `data`. elif not record_set.data: left, right = st.columns([1, 10]) if exception: left.button( "⚠️", key=f"idea-{prefix}", on_click=lambda: _generate_data_with_timeout.clear(), help=textwrap.dedent(f"""**Error**: ``` {exception} ``` """), ) right.markdown("No preview is possible.") st.button( "Edit fields details", key=f"{prefix}-show-fields", on_click=_handle_on_click_field, args=(record_set_key, record_set), ) key = f"{prefix}-delete-record-set" button_with_confirmation( "Delete RecordSet", key=key, on_click=_handle_remove_record_set, args=(record_set_key,), ) st.button( "Create a new RecordSet", key=f"create-new-record-set", type="primary", on_click=_handle_create_record_set, ) def _render_right_panel(): """Right panel: visualization of the clicked Field.""" metadata: Metadata = st.session_state.get(Metadata) selected: SelectedRecordSet = st.session_state.get(SelectedRecordSet) if not selected: return record_set = selected.record_set record_set_key = selected.record_set_key with st.expander("**Fields**", expanded=True): if isinstance(record_set.data, list): st.markdown( f"{needed_field('Data')}. This RecordSet is marked as having in-line" " data. Please, list the data below:" ) key = f"{record_set_key}-fields-data" columns = [field.name for field in record_set.fields] st.data_editor( pd.DataFrame(record_set.data, columns=columns), use_container_width=True, num_rows="dynamic", key=key, column_config={ field.name: st.column_config.TextColumn( field.name, help=field.description, required=True, ) for field in record_set.fields }, on_change=handle_record_set_change, args=(RecordSetEvent.CHANGE_DATA, record_set, key), ) else: for field_key, field in enumerate(record_set.fields): prefix = f"{record_set_key}-{field.name}-{field_key}" col1, col2, col3 = st.columns([1, 1, 1]) key = f"{prefix}-name" col1.text_input( needed_field("Name"), placeholder="Name without special character.", key=key, help=f"The name of the field. {NAMES_INFO}", value=field.name, on_change=handle_field_change, args=(FieldEvent.NAME, field, key), ) key = f"{prefix}-description" col2.text_input( "Description", placeholder="Provide a description of the RecordSet.", key=key, on_change=handle_field_change, value=field.description, args=(FieldEvent.DESCRIPTION, field, key), ) data_type_index = None if field.data_types: data_type = field.data_types[0] if isinstance(data_type, str): data_type = term.URIRef(data_type) if data_type in MLC_DATA_TYPES: data_type_index = MLC_DATA_TYPES.index(data_type) key = f"{prefix}-datatypes" col3.selectbox( needed_field("Data type"), index=data_type_index, options=STR_DATA_TYPES, key=key, help=( "The type of the data. `Text` corresponds to" " https://schema.org/Text, etc." ), on_change=handle_field_change, args=(FieldEvent.DATA_TYPE, field, key), ) possible_sources = _get_possible_sources(metadata) render_source(record_set, field, possible_sources) render_references(record_set, field, possible_sources) st.divider() st.button( "Close", key=f"{record_set.name}-{record_set_key}-close-fields", type="primary", on_click=_handle_close_fields, )