Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,666 Bytes
e1aa577 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os.path
import logging
import pandas as pd
from pathlib import Path
from datetime import datetime
import csv
from utils.dedup import Dedup
class DatasetBase:
    """
    This class stores and manages all the dataset records (including the annotations and predictions).
    """

    def __init__(self, config):
        """
        Initialize the dataset.
        :param config: Configuration object. `records_path`, `name` and `label_schema` are read as attributes;
            `sample_size`, `semantic_sampling` and `dedup_new_samples` are read via `config.get`, so it is
            presumably a dict with attribute access (e.g. EasyDict) -- TODO confirm. It is also passed to Dedup.
        """
        if config.records_path is None:
            self.records = pd.DataFrame(columns=['id', 'text', 'prediction',
                                                 'annotation', 'metadata', 'score', 'batch_id'])
        else:
            self.records = pd.read_csv(config.records_path)
        dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
        self.name = config.name + '__' + dt_string
        self.label_schema = config.label_schema
        self.dedup = Dedup(config)
        self.sample_size = config.get("sample_size", 3)
        self.semantic_sampling = config.get("semantic_sampling", False)
        if not config.get('dedup_new_samples', False):
            # Shadow the remove_duplicates method on this instance with the identity function,
            # so deduplication of new samples is skipped.
            self.remove_duplicates = self._null_remove

    def __len__(self):
        """
        Return the number of samples in the dataset.
        """
        return len(self.records)

    def __getitem__(self, batch_idx):
        """
        Return the records whose batch_id equals batch_idx (with a fresh 0..k-1 index).
        """
        extract_records = self.records[self.records['batch_id'] == batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def get_leq(self, batch_idx):
        """
        Return all the records up to batch_idx (inclusive), with a fresh 0..k-1 index.
        """
        extract_records = self.records[self.records['batch_id'] <= batch_idx]
        extract_records = extract_records.reset_index(drop=True)
        return extract_records

    def _next_id(self) -> int:
        """
        Return an id strictly greater than every existing record id (0 for an empty dataset).
        """
        if len(self.records) == 0:
            return 0
        ids = pd.to_numeric(self.records['id'], errors='coerce')
        if ids.isna().all():
            # No usable numeric ids: fall back to the historical len()-based numbering.
            return len(self.records)
        # len(self.records) is NOT safe here: update() drops "Discarded" rows, so the row
        # count can be smaller than the largest id and would produce duplicate ids.
        return int(ids.max()) + 1

    def add(self, sample_list: dict = None, batch_id: int = None, records: pd.DataFrame = None):
        """
        Add records to the dataset.
        :param sample_list: Iterable of sample texts; each element becomes the 'text' of one new record
            (only used in case records=None). None or empty adds nothing.
        :param batch_id: The batch_id for the uploaded records (only used in case records=None)
        :param records: dataframe of records appended as-is (ids are not re-assigned)
        """
        if records is None:
            first_id = self._next_id()
            new_rows = [{'id': first_id + i, 'text': sample, 'batch_id': batch_id}
                        for i, sample in enumerate(sample_list if sample_list is not None else [])]
            if not new_rows:
                return
            records = pd.DataFrame(new_rows)
        self.records = pd.concat([self.records, records], ignore_index=True)

    def update(self, records: pd.DataFrame):
        """
        Update existing records in the dataset, matched by their 'id' column.
        Records whose annotation ends up as "Discarded" are removed from the dataset.
        :param records: dataframe with an 'id' column plus the columns to overwrite. It is not modified.
        """
        # Ignore if records is empty
        if len(records) == 0:
            return
        # Work on re-indexed copies: the caller's frame keeps its 'id' column, and self.records
        # is only replaced once the whole update has succeeded.
        updates = records.set_index('id')
        merged = self.records.set_index('id')
        # DataFrame.update aligns on index/columns and writes only the non-NA values of `updates`.
        merged.update(updates)
        # Remove discarded annotations (NaN != "Discarded" is True, so unannotated rows are kept).
        # TODO: direct the discarded records to another dataset to be used later for corner-cases
        merged = merged.loc[merged['annotation'] != 'Discarded']
        # Move 'id' back to a column and restore a 0..n-1 index.
        self.records = merged.reset_index()

    def modify(self, index: int, record: dict):
        """
        Modify a record in the dataset.
        :param index: The row label of the record (self.records normally carries a 0..n-1 index)
        :param record: mapping from column name to the new value of that column
        """
        # Assign cell by cell; `self.records[index] = record` would instead create/overwrite
        # a *column* named `index`.
        for column, value in record.items():
            self.records.loc[index, column] = value

    def apply(self, function, column_name: str):
        """
        Apply function on each record (row) and store the results in column_name.
        :param function: callable receiving one row (a pd.Series) and returning the new cell value
        :param column_name: name of the (new or existing) column receiving the results
        """
        if len(self.records) == 0:
            # On an empty frame DataFrame.apply may return a DataFrame instead of a Series,
            # which cannot be assigned to a single column; just create an empty column.
            self.records[column_name] = pd.Series(dtype=object)
            return
        self.records[column_name] = self.records.apply(function, axis=1)

    def save_dataset(self, path: Path):
        """
        Save the records as a csv file (no index column, non-numeric fields quoted).
        :param path: destination path for the csv
        """
        self.records.to_csv(path, index=False, quoting=csv.QUOTE_NONNUMERIC)

    def load_dataset(self, path: Path):
        """
        Load the dataset from a csv file; if the file does not exist, log a warning and keep the current records.
        :param path: path for the csv
        """
        if os.path.isfile(path):
            self.records = pd.read_csv(path, dtype={'annotation': str, 'prediction': str, 'batch_id': int})
        else:
            logging.warning('Dataset dump not found, initializing from zero')

    def remove_duplicates(self, samples: list) -> list:
        """
        Remove (soft) duplicates from the given samples.
        The actual selection is delegated to a copy of self.dedup (utils.dedup.Dedup) -- its exact
        criterion is defined there.
        :param samples: The samples
        :return: The samples without duplicates
        """
        dd = self.dedup.copy()
        df = pd.DataFrame(samples, columns=['text'])
        df_dedup = dd.sample(df, operation_function=min)
        return df_dedup['text'].tolist()

    def _null_remove(self, samples: list) -> list:
        # Identity function that returns the input unmodified
        return samples

    def sample_records(self, n: int = None) -> pd.DataFrame:
        """
        Return a sample of the records (after semantic clustering when semantic_sampling is enabled).
        :param n: The number of samples to return (defaults to self.sample_size); fewer are
            returned when the dataset has fewer than n records
        :return: A sample of the records
        """
        n = n or self.sample_size
        if self.semantic_sampling:
            dd = self.dedup.copy()
            df_samples = dd.sample(self.records).head(n)
            # Fall back to the first n records when the semantic sample is too small.
            if len(df_samples) < n:
                df_samples = self.records.head(n)
        else:
            # DataFrame.sample raises ValueError if n exceeds the number of rows (replace=False).
            df_samples = self.records.sample(min(n, len(self.records)))
        return df_samples

    @staticmethod
    def samples_to_text(records: pd.DataFrame) -> str:
        """
        Return a string that organizes the samples for a meta-prompt.
        :param records: The samples for the step (must contain a 'text' column)
        :return: A string that contains the organized samples
        """
        body = ''.join(f"Sample:\n {text}\n#\n" for text in records['text'])
        return '##\n' + body
|