Raise error for train and validation mismatch

#459
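
Refactors geneformer/mtl/data.py into focused helpers (column validation, label-mapping creation, pickle save/load, dataset transformation) and writes the train and validation label mappings to separate pickle files so they can be compared. preload_and_process_data now raises a ValueError when the two mappings disagree, and loading errors are raised rather than printed.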
Files changed (1)
  1. geneformer/mtl/data.py +117 -105
geneformer/mtl/data.py CHANGED
@@ -1,150 +1,162 @@
 import os
-
 from .collators import DataCollatorForMultitaskCellClassification
 from .imports import *
 
 
+def validate_columns(dataset, required_columns, dataset_type):
+    """Ensures required columns are present in the dataset."""
+    missing_columns = [col for col in required_columns if col not in dataset.column_names]
+    if missing_columns:
+        raise KeyError(
+            f"Missing columns in {dataset_type} dataset: {missing_columns}. "
+            f"Available columns: {dataset.column_names}"
+        )
+
+
+def create_label_mappings(dataset, task_to_column):
+    """Creates label mappings for the dataset."""
+    task_label_mappings = {}
+    num_labels_list = []
+    for task, column in task_to_column.items():
+        unique_values = sorted(set(dataset[column]))
+        mapping = {label: idx for idx, label in enumerate(unique_values)}
+        task_label_mappings[task] = mapping
+        num_labels_list.append(len(unique_values))
+    return task_label_mappings, num_labels_list
+
+
+def save_label_mappings(mappings, path):
+    """Saves label mappings to a pickle file."""
+    with open(path, "wb") as f:
+        pickle.dump(mappings, f)
+
+
+def load_label_mappings(path):
+    """Loads label mappings from a pickle file."""
+    with open(path, "rb") as f:
+        return pickle.load(f)
+
+
+def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
+    """Transforms the dataset to the required format."""
+    transformed_dataset = []
+    cell_id_mapping = {}
+
+    for idx, record in enumerate(dataset):
+        transformed_record = {
+            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
+            "cell_id": idx,  # Index-based cell ID
+        }
+
+        if not is_test:
+            label_dict = {
+                task: task_label_mappings[task][record[column]]
+                for task, column in task_to_column.items()
+            }
+        else:
+            label_dict = {task: -1 for task in config["task_names"]}
+
+        transformed_record["label"] = label_dict
+        transformed_dataset.append(transformed_record)
+        cell_id_mapping[idx] = record.get("unique_cell_id", idx)
+
+    return transformed_dataset, cell_id_mapping
+
+
 def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
+    """Main function to load and preprocess data."""
     try:
         dataset = load_from_disk(dataset_path)
 
+        # Set up task and column mappings
         task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
         task_to_column = dict(zip(task_names, config["task_columns"]))
         config["task_names"] = task_names
 
-        if not is_test:
-            available_columns = set(dataset.column_names)
-            for column in task_to_column.values():
-                if column not in available_columns:
-                    raise KeyError(
-                        f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
-                    )
-
-        label_mappings = {}
-        task_label_mappings = {}
-        cell_id_mapping = {}
-        num_labels_list = []
-
-        # Load or create task label mappings
-        if not is_test:
-            for task, column in task_to_column.items():
-                unique_values = sorted(set(dataset[column]))  # Ensure consistency
-                label_mappings[column] = {
-                    label: idx for idx, label in enumerate(unique_values)
-                }
-                task_label_mappings[task] = label_mappings[column]
-                num_labels_list.append(len(unique_values))
-
-            # Print the mappings for each task with dataset type prefix
-            for task, mapping in task_label_mappings.items():
-                print(
-                    f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
-                )  # sanity check, for train/validation splits
-
-            # Save the task label mappings as a pickle file
-            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
-                pickle.dump(task_label_mappings, f)
-        else:
-            # Load task label mappings from pickle file for test data
-            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
-                task_label_mappings = pickle.load(f)
-
-            # Infer num_labels_list from task_label_mappings
-            for task, mapping in task_label_mappings.items():
-                num_labels_list.append(len(mapping))
-
-        # Store unique cell IDs in a separate dictionary
-        for idx, record in enumerate(dataset):
-            cell_id = record.get("unique_cell_id", idx)
-            cell_id_mapping[idx] = cell_id
-
-        # Transform records to the desired format
-        transformed_dataset = []
-        for idx, record in enumerate(dataset):
-            transformed_record = {}
-            transformed_record["input_ids"] = torch.tensor(
-                record["input_ids"], dtype=torch.long
-            )
-
-            # Use index-based cell ID for internal tracking
-            transformed_record["cell_id"] = idx
-
-            if not is_test:
-                # Prepare labels
-                label_dict = {}
-                for task, column in task_to_column.items():
-                    label_value = record[column]
-                    label_index = task_label_mappings[task][label_value]
-                    label_dict[task] = label_index
-                transformed_record["label"] = label_dict
-            else:
-                # Create dummy labels for test data
-                label_dict = {task: -1 for task in config["task_names"]}
-                transformed_record["label"] = label_dict
-
-            transformed_dataset.append(transformed_record)
+        label_mappings_path = os.path.join(
+            config["results_dir"],
+            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl",
+        )
+
+        if not is_test:
+            validate_columns(dataset, task_to_column.values(), dataset_type)
+
+            # Create and save label mappings
+            task_label_mappings, num_labels_list = create_label_mappings(
+                dataset, task_to_column
+            )
+            save_label_mappings(task_label_mappings, label_mappings_path)
+        else:
+            # Load existing mappings for test data
+            task_label_mappings = load_label_mappings(label_mappings_path)
+            num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
+
+        # Transform dataset
+        transformed_dataset, cell_id_mapping = transform_dataset(
+            dataset, task_to_column, task_label_mappings, config, is_test
+        )
 
         return transformed_dataset, cell_id_mapping, num_labels_list
+
     except KeyError as e:
-        print(f"Missing configuration or dataset key: {e}")
+        raise ValueError(f"Configuration error or dataset key missing: {e}") from e
     except Exception as e:
-        print(f"An error occurred while loading or preprocessing data: {e}")
-        return None, None, None
+        raise RuntimeError(f"Error during data loading or preprocessing: {e}") from e
 
 
 def preload_and_process_data(config):
-    # Load and preprocess data once
-    train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
-        config["train_path"], config, dataset_type="train"
-    )
-    val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
-        config["val_path"], config, dataset_type="validation"
-    )
-    return (
-        train_dataset,
-        train_cell_id_mapping,
-        val_dataset,
-        val_cell_id_mapping,
-        num_labels_list,
-    )
+    """Preloads and preprocesses train and validation datasets."""
+    # Process train data and save mappings
+    train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
+
+    # Process validation data and save mappings
+    val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
+
+    # Validate that the train and validation mappings match
+    validate_label_mappings(config)
+
+    # (train_dataset, train_cell_id_mapping, num_labels_list, val_dataset, val_cell_id_mapping)
+    return (*train_data, *val_data[:2])
+
+
+def validate_label_mappings(config):
+    """Ensures train and validation label mappings are consistent."""
+    train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
+    val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
+    train_mappings = load_label_mappings(train_mappings_path)
+    val_mappings = load_label_mappings(val_mappings_path)
+
+    for task_name in config["task_names"]:
+        if train_mappings[task_name] != val_mappings[task_name]:
+            raise ValueError(
+                f"Mismatch in label mappings for task '{task_name}'.\n"
+                f"Train Mapping: {train_mappings[task_name]}\n"
+                f"Validation Mapping: {val_mappings[task_name]}"
+            )
 
 
 def get_data_loader(preprocessed_dataset, batch_size):
-    nproc = os.cpu_count()  ### I/O operations
-
-    data_collator = DataCollatorForMultitaskCellClassification()
-
-    loader = DataLoader(
+    """Creates a DataLoader for the preprocessed dataset."""
+    return DataLoader(
         preprocessed_dataset,
         batch_size=batch_size,
         shuffle=True,
-        collate_fn=data_collator,
-        num_workers=nproc,
+        collate_fn=DataCollatorForMultitaskCellClassification(),
+        num_workers=os.cpu_count(),
         pin_memory=True,
     )
-    return loader
 
 
 def preload_data(config):
-    # Preprocessing the data before the Optuna trials start
-    train_loader = get_data_loader("train", config)
-    val_loader = get_data_loader("val", config)
+    """Preprocesses train and validation data before the Optuna trials start."""
+    train_dataset, _, _, val_dataset, _ = preload_and_process_data(config)
+    train_loader = get_data_loader(train_dataset, config["batch_size"])
+    val_loader = get_data_loader(val_dataset, config["batch_size"])
     return train_loader, val_loader
 
 
 def load_and_preprocess_test_data(config):
-    """
-    Load and preprocess test data, treating it as unlabeled.
-    """
+    """Loads and preprocesses test data, treating it as unlabeled."""
    return load_and_preprocess_data(config["test_path"], config, is_test=True)
 
 
 def prepare_test_loader(config):
-    """
-    Prepare DataLoader for the test dataset.
-    """
-    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
-        config
-    )
+    """Prepares DataLoader for the test dataset."""
+    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
     test_loader = get_data_loader(test_dataset, config["batch_size"])
     return test_loader, cell_id_mapping, num_labels_list
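
For reference, a minimal sketch of how the new check surfaces a mismatch. Everything here is hypothetical except the helper itself: the results_dir value, the task1 mappings ("disease"/"healthy"), and the import path (which assumes the package layout implied by the file path above).

import os
import pickle

from geneformer.mtl.data import validate_label_mappings

results_dir = "results"  # hypothetical output directory
os.makedirs(results_dir, exist_ok=True)

# Hypothetical mappings: the validation split lacks the "disease" label,
# so its labels were enumerated differently than the train split's.
train_mappings = {"task1": {"disease": 0, "healthy": 1}}
val_mappings = {"task1": {"healthy": 0}}

with open(os.path.join(results_dir, "task_label_mappings.pkl"), "wb") as f:
    pickle.dump(train_mappings, f)
with open(os.path.join(results_dir, "task_label_mappings_val.pkl"), "wb") as f:
    pickle.dump(val_mappings, f)

config = {"results_dir": results_dir, "task_names": ["task1"]}
validate_label_mappings(config)
# ValueError: Mismatch in label mappings for task 'task1'.
# Train Mapping: {'disease': 0, 'healthy': 1}
# Validation Mapping: {'healthy': 0}

Before this change, nothing compared the two pickles: each split enumerated its own sorted label values, so a missing or extra label in the validation split silently shifted every label index.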