import csv
import logging
import os.path
from datetime import datetime
from pathlib import Path

import pandas as pd

from utils.dedup import Dedup

class DatasetBase:
    """
    This class stores and manages all the dataset records (including annotations and predictions).
    """
    def __init__(self, config):
        if config.records_path is None:
            self.records = pd.DataFrame(columns=['id', 'text', 'prediction',
                                                 'annotation', 'metadata', 'score', 'batch_id'])
        else:
            self.records = pd.read_csv(config.records_path)
        dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
        self.name = config.name + '__' + dt_string
        self.label_schema = config.label_schema
        self.dedup = Dedup(config)
        self.sample_size = config.get("sample_size", 3)
        self.semantic_sampling = config.get("semantic_sampling", False)
        # When deduplication of new samples is disabled, replace remove_duplicates with a no-op
        if not config.get('dedup_new_samples', False):
            self.remove_duplicates = self._null_remove

    def __len__(self):
        """
        Return the number of samples in the dataset.
        """
        return len(self.records)

    def __getitem__(self, batch_idx):
        """
        Return all the records with the given batch_idx.
        """
        extract_records = self.records[self.records['batch_id'] == batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def get_leq(self, batch_idx):
        """
        Return all the records up to batch_idx (inclusive).
        """
        extract_records = self.records[self.records['batch_id'] <= batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def add(self, sample_list: list = None, batch_id: int = None, records: pd.DataFrame = None):
        """
        Add records to the dataset.
        :param sample_list: The samples to add, as a list of texts (only used when records is None)
        :param batch_id: The batch_id for the uploaded records (only used when records is None)
        :param records: A dataframe of records to append directly
        """
        if records is None:
            records = pd.DataFrame([{'id': len(self.records) + i, 'text': sample, 'batch_id': batch_id}
                                    for i, sample in enumerate(sample_list)])
        self.records = pd.concat([self.records, records], ignore_index=True)
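
    # Usage sketch (hypothetical values, assuming a configured instance named `dataset`):
    # either pass raw texts with a batch id, or append a prepared dataframe directly.
    #   dataset.add(sample_list=['What is the refund policy?'], batch_id=0)
    #   dataset.add(records=pd.DataFrame([{'id': 1, 'text': 'Hi there', 'batch_id': 0}]))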

    def update(self, records: pd.DataFrame):
        """
        Update existing records in the dataset, matching rows by their 'id'.
        :param records: A dataframe with an 'id' column and the columns to update
        """
        # Ignore empty updates
        if len(records) == 0:
            return
        # Set 'id' as the index for both DataFrames (without mutating the caller's frame)
        records = records.set_index('id')
        self.records.set_index('id', inplace=True)
        # Update using 'id' as the key
        self.records.update(records)
        # Drop records whose annotation was marked 'Discarded'
        if len(self.records.loc[self.records['annotation'] == 'Discarded']) > 0:
            discarded_annotation_records = self.records.loc[self.records['annotation'] == 'Discarded']
            # TODO: direct `discarded_annotation_records` to another dataset to be used later for corner-cases
            self.records = self.records.loc[self.records['annotation'] != 'Discarded']
        # Restore 'id' as a regular column
        self.records.reset_index(inplace=True)
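
    # Usage sketch (hypothetical ids and labels): updates are keyed on the 'id'
    # column, so only the listed rows and columns change.
    #   dataset.update(pd.DataFrame({'id': [0], 'annotation': ['Yes'], 'score': [1.0]}))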

    def modify(self, index: int, record: dict):
        """
        Modify a record in the dataset.
        :param index: The row label of the record to modify
        :param record: The new record
        """
        self.records.loc[index] = record

    def apply(self, function, column_name: str):
        """
        Apply a function to each record and store the result in the given column.
        :param function: A function that receives a row and returns a value
        :param column_name: The column to write the results to
        """
        self.records[column_name] = self.records.apply(function, axis=1)

    def save_dataset(self, path: Path):
        """
        Save the records to a csv file
        :param path: path for the csv
        """
        self.records.to_csv(path, index=False, quoting=csv.QUOTE_NONNUMERIC)

    def load_dataset(self, path: Path):
        """
        Load the records from a csv file
        :param path: path for the csv
        """
        if os.path.isfile(path):
            self.records = pd.read_csv(path, dtype={'annotation': str, 'prediction': str, 'batch_id': int})
        else:
            logging.warning('Dataset dump not found, initializing from zero')
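
    # Round-trip sketch (hypothetical path): save_dataset and load_dataset share
    # the same csv layout, so a dump can be reloaded as-is.
    #   dataset.save_dataset(Path('dump/records.csv'))
    #   dataset.load_dataset(Path('dump/records.csv'))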

    def remove_duplicates(self, samples: list) -> list:
        """
        Remove (soft) duplicates from the given samples
        :param samples: The samples
        :return: The samples without duplicates
        """
        dd = self.dedup.copy()
        df = pd.DataFrame(samples, columns=['text'])
        df_dedup = dd.sample(df, operation_function=min)
        return df_dedup['text'].tolist()

    def _null_remove(self, samples: list) -> list:
        # Identity function that returns the input unmodified
        return samples

    def sample_records(self, n: int = None) -> pd.DataFrame:
        """
        Return a sample of the records, cluster-based when semantic sampling is enabled
        :param n: The number of samples to return (defaults to the configured sample_size)
        :return: A sample of the records
        """
        n = n or self.sample_size
        if self.semantic_sampling:
            dd = self.dedup.copy()
            df_samples = dd.sample(self.records).head(n)
            # Fall back to the first n records if clustering yielded too few samples
            if len(df_samples) < n:
                df_samples = self.records.head(n)
        else:
            df_samples = self.records.sample(min(n, len(self.records)))
        return df_samples
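
    # Usage sketch: `dataset.sample_records(5)` returns up to five rows,
    # cluster-based when semantic_sampling is enabled and uniform otherwise.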

    @staticmethod
    def samples_to_text(records: pd.DataFrame) -> str:
        """
        Return a string that organizes the samples for a meta-prompt
        :param records: The samples for the step
        :return: A string that contains the organized samples
        """
        txt_res = '##\n'
        for _, row in records.iterrows():
            txt_res += f"Sample:\n {row.text}\n#\n"
        return txt_res
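

# Minimal end-to-end sketch, not part of the class. Assumptions: `config` behaves
# like an attribute dict with a `.get` method (e.g. easydict's EasyDict, matching
# how DatasetBase reads it above), and utils.dedup.Dedup only stores the config
# at construction time; all field values below are hypothetical.
if __name__ == '__main__':
    from easydict import EasyDict

    config = EasyDict({'records_path': None, 'name': 'demo',
                       'label_schema': ['Yes', 'No'],
                       'dedup_new_samples': False})
    dataset = DatasetBase(config)
    dataset.add(sample_list=['First sample', 'Second sample'], batch_id=0)
    print(len(dataset))                              # -> 2
    print(DatasetBase.samples_to_text(dataset[0]))   # meta-prompt text for batch 0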