Fix bug in selecting a gene with "aggregate_data" option

#312
Files changed (1) hide show
  1. geneformer/in_silico_perturber_stats.py +396 -692
geneformer/in_silico_perturber_stats.py CHANGED
@@ -1,180 +1,131 @@
1
  """
2
  Geneformer in silico perturber stats generator.
3
 
4
- **Usage:**
5
-
6
- .. code-block :: python
7
-
8
- >>> from geneformer import InSilicoPerturberStats
9
- >>> ispstats = InSilicoPerturberStats(mode="goal_state_shift",
10
- ... cell_states_to_model={"state_key": "disease",
11
- ... "start_state": "dcm",
12
- ... "goal_state": "nf",
13
- ... "alt_states": ["hcm", "other1", "other2"]})
14
- >>> ispstats.get_stats("path/to/input_data",
15
- ... None,
16
- ... "path/to/output_directory",
17
- ... "output_prefix")
18
-
19
- **Description:**
20
-
21
- | Aggregates data or calculates stats for in silico perturbations based on type of statistics specified in InSilicoPerturberStats.
22
- | Input data is raw in silico perturbation results in the form of dictionaries outputted by ``in_silico_perturber``.
23
-
24
  """
25
 
26
 
27
- import logging
28
  import os
29
- import pickle
30
- import random
31
- from pathlib import Path
32
-
33
  import numpy as np
34
  import pandas as pd
 
 
35
  import statsmodels.stats.multitest as smt
 
36
  from scipy.stats import ranksums
37
  from sklearn.mixture import GaussianMixture
38
- from tqdm.auto import tqdm, trange
 
 
39
 
40
- from .perturber_utils import flatten_list, validate_cell_states_to_model
41
  from .tokenizer import TOKEN_DICTIONARY_FILE
42
 
43
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
44
 
45
  logger = logging.getLogger(__name__)
46
 
47
-
48
  # invert dictionary keys/values
49
  def invert_dict(dictionary):
50
  return {v: k for k, v in dictionary.items()}
51
 
52
-
53
  def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
54
  if cell_or_gene_emb == "cell":
55
- cell_emb_dict = {
56
- k: v for k, v in cos_sims_dict.items() if v and "cell_emb" in k
57
- }
58
  return [cell_emb_dict]
59
  elif cell_or_gene_emb == "gene":
60
- if anchor_token is None:
61
- gene_emb_dict = {k: v for k, v in cos_sims_dict.items() if v}
62
- else:
63
- gene_emb_dict = {
64
- k: v for k, v in cos_sims_dict.items() if v and anchor_token == k[0]
65
- }
66
  return [gene_emb_dict]
67
 
68
 
 
 
 
69
  # read raw dictionary files
70
- def read_dictionaries(
71
- input_data_directory,
72
- cell_or_gene_emb,
73
- anchor_token,
74
- cell_states_to_model,
75
- pickle_suffix,
76
- ):
77
  file_found = False
78
  file_path_list = []
79
  if cell_states_to_model is None:
80
  dict_list = []
81
  else:
82
- validate_cell_states_to_model(cell_states_to_model)
83
- cell_states_to_model_valid = {
84
- state: value
85
- for state, value in cell_states_to_model.items()
86
- if state != "state_key"
87
- and cell_states_to_model[state] is not None
88
- and cell_states_to_model[state] != []
89
- }
90
- cell_states_list = []
91
- # flatten all state values into list
92
- for state in cell_states_to_model_valid:
93
- value = cell_states_to_model_valid[state]
94
- if isinstance(value, list):
95
- cell_states_list += value
96
- else:
97
- cell_states_list.append(value)
98
- state_dict = {state_value: dict() for state_value in cell_states_list}
99
  for file in os.listdir(input_data_directory):
100
- # process only files with given suffix (e.g. "_raw.pickle")
101
  if file.endswith(pickle_suffix):
102
  file_found = True
103
  file_path_list += [f"{input_data_directory}/{file}"]
104
  for file_path in tqdm(file_path_list):
105
- with open(file_path, "rb") as fp:
106
  cos_sims_dict = pickle.load(fp)
107
  if cell_states_to_model is None:
108
  dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
109
  else:
110
- for state_value in cell_states_list:
111
- new_dict = read_dict(
112
- cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
113
- )[0]
114
- for key in new_dict:
115
- try:
116
- state_dict[state_value][key] += new_dict[key]
117
- except KeyError:
118
- state_dict[state_value][key] = new_dict[key]
119
  if not file_found:
120
  logger.error(
121
- "No raw data for processing found within provided directory. "
122
- "Please ensure data files end with '{pickle_suffix}'."
123
- )
124
  raise
125
  if cell_states_to_model is None:
126
  return dict_list
127
  else:
128
  return state_dict
129
 
130
-
131
  # get complete gene list
132
- def get_gene_list(dict_list, mode):
133
  if mode == "cell":
134
  position = 0
135
  elif mode == "gene":
136
  position = 1
137
  gene_set = set()
138
- if isinstance(dict_list, list):
139
- for dict_i in dict_list:
140
- gene_set.update([k[position] for k, v in dict_i.items() if v])
141
- elif isinstance(dict_list, dict):
142
- for state, dict_i in dict_list.items():
143
- gene_set.update([k[position] for k, v in dict_i.items() if v])
144
- else:
145
- logger.error(
146
- "dict_list should be a list, or if modeling shift to goal states, a dict. "
147
- f"{type(dict_list)} is not the correct format."
148
- )
149
- raise
150
  gene_list = list(gene_set)
151
  if mode == "gene":
152
  gene_list.remove("cell_emb")
153
  gene_list.sort()
154
  return gene_list
155
 
156
-
157
  def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
158
  try:
159
  return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
160
- except TypeError:
161
- return gene_token_id_dict.get(token_tuple, np.nan)
162
-
163
 
164
  def n_detections(token, dict_list, mode, anchor_token):
165
  cos_sim_megalist = []
166
  for dict_i in dict_list:
167
  if mode == "cell":
168
- cos_sim_megalist += dict_i.get((token, "cell_emb"), [])
169
  elif mode == "gene":
170
- cos_sim_megalist += dict_i.get((anchor_token, token), [])
171
  return len(cos_sim_megalist)
172
 
173
-
174
  def get_fdr(pvalues):
175
  return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
176
 
177
-
178
  def get_impact_component(test_value, gaussian_mixture_model):
179
  impact_border = gaussian_mixture_model.means_[0][0]
180
  nonimpact_border = gaussian_mixture_model.means_[1][0]
@@ -190,357 +141,237 @@ def get_impact_component(test_value, gaussian_mixture_model):
190
  impact_component = 1
191
  return impact_component
192
 
193
-
194
  # aggregate data for single perturbation in multiple cells
195
- def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
196
- names = ["Cosine_shift"]
197
  cos_sims_full_df = pd.DataFrame(columns=names)
198
 
199
  cos_shift_data = []
200
  token = cos_sims_df["Gene"][0]
201
  for dict_i in dict_list:
202
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
203
  cos_sims_full_df["Cosine_shift"] = cos_shift_data
204
- return cos_sims_full_df
205
-
206
-
207
- def find(variable, x):
208
- try:
209
- if x in variable: # Test if variable is iterable and contains x
210
- return True
211
- except (ValueError, TypeError):
212
- return x == variable # Test if variable is x if non-iterable
213
-
214
-
215
- def isp_aggregate_gene_shifts(
216
- cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
217
- ):
218
- cos_shift_data = dict()
219
- for i in trange(cos_sims_df.shape[0]):
220
- token = cos_sims_df["Gene"][i]
221
- for dict_i in dict_list:
222
- affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
223
- for key in affected_pairs:
224
- if key in cos_shift_data.keys():
225
- cos_shift_data[key] += dict_i.get(key, [])
226
- else:
227
- cos_shift_data[key] = dict_i.get(key, [])
228
-
229
- cos_data_mean = {
230
- k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
231
- }
232
- cos_sims_full_df = pd.DataFrame()
233
- cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
234
- cos_sims_full_df["Gene_name"] = [
235
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
236
- for k, v in cos_data_mean.items()
237
- ]
238
- cos_sims_full_df["Ensembl_ID"] = [
239
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
240
- for k, v in cos_data_mean.items()
241
- ]
242
-
243
- cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
244
- cos_sims_full_df["Affected_gene_name"] = [
245
- gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
246
- for token in cos_sims_full_df["Affected"]
247
- ]
248
- cos_sims_full_df["Affected_Ensembl_ID"] = [
249
- gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
250
- ]
251
- cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
252
- cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
253
- cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
254
-
255
- specific_val = "cell_emb"
256
- cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
257
- # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
258
- cos_sims_full_df = cos_sims_full_df.sort_values(
259
- by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
260
- ).drop("temp", axis=1)
261
-
262
- return cos_sims_full_df
263
-
264
 
265
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
266
- def isp_stats_to_goal_state(
267
- cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
268
- ):
269
- if (
270
- ("alt_states" not in cell_states_to_model.keys())
271
- or (len(cell_states_to_model["alt_states"]) == 0)
272
- or (cell_states_to_model["alt_states"] == [None])
273
- ):
274
  alt_end_state_exists = False
275
- elif (len(cell_states_to_model["alt_states"]) > 0) and (
276
- cell_states_to_model["alt_states"] != [None]
277
- ):
278
  alt_end_state_exists = True
279
-
280
  # for single perturbation in multiple cells, there are no random perturbations to compare to
281
  if genes_perturbed != "all":
282
- cos_sims_full_df = pd.DataFrame()
283
-
284
- cos_shift_data_end = []
 
 
 
 
285
  token = cos_sims_df["Gene"][0]
286
- cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
287
- (token, "cell_emb"), []
288
- )
289
- cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
290
- if alt_end_state_exists is True:
291
- for alt_state in cell_states_to_model["alt_states"]:
292
- cos_shift_data_alt_state = []
293
- cos_shift_data_alt_state += result_dict.get(alt_state).get(
294
- (token, "cell_emb"), []
295
- )
296
- cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
297
- np.mean(cos_shift_data_alt_state)
298
- ]
299
-
300
  # sort by shift to desired state
301
- cos_sims_full_df = cos_sims_full_df.sort_values(
302
- by=["Shift_to_goal_end"], ascending=[False]
303
- )
304
- return cos_sims_full_df
305
-
306
  elif genes_perturbed == "all":
307
- goal_end_random_megalist = []
308
- if alt_end_state_exists is True:
309
- alt_end_state_random_dict = {
310
- alt_state: [] for alt_state in cell_states_to_model["alt_states"]
311
- }
312
  for i in trange(cos_sims_df.shape[0]):
313
  token = cos_sims_df["Gene"][i]
314
- goal_end_random_megalist += result_dict[
315
- cell_states_to_model["goal_state"]
316
- ].get((token, "cell_emb"), [])
317
- if alt_end_state_exists is True:
318
- for alt_state in cell_states_to_model["alt_states"]:
319
- alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
320
- (token, "cell_emb"), []
321
- )
322
 
323
  # downsample to improve speed of ranksums
324
  if len(goal_end_random_megalist) > 100_000:
325
  random.seed(42)
326
- goal_end_random_megalist = random.sample(
327
- goal_end_random_megalist, k=100_000
328
- )
329
- if alt_end_state_exists is True:
330
- for alt_state in cell_states_to_model["alt_states"]:
331
- if len(alt_end_state_random_dict[alt_state]) > 100_000:
332
- random.seed(42)
333
- alt_end_state_random_dict[alt_state] = random.sample(
334
- alt_end_state_random_dict[alt_state], k=100_000
335
- )
336
-
337
- names = [
338
- "Gene",
339
- "Gene_name",
340
- "Ensembl_ID",
341
- "Shift_to_goal_end",
342
- "Goal_end_vs_random_pval",
343
- ]
344
- if alt_end_state_exists is True:
345
- [
346
- names.append(f"Shift_to_alt_end_{alt_state}")
347
- for alt_state in cell_states_to_model["alt_states"]
348
- ]
349
- names.append(names.pop(names.index("Goal_end_vs_random_pval")))
350
- [
351
- names.append(f"Alt_end_vs_random_pval_{alt_state}")
352
- for alt_state in cell_states_to_model["alt_states"]
353
- ]
354
  cos_sims_full_df = pd.DataFrame(columns=names)
355
 
356
- n_detections_dict = dict()
357
  for i in trange(cos_sims_df.shape[0]):
358
  token = cos_sims_df["Gene"][i]
359
  name = cos_sims_df["Gene_name"][i]
360
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
361
- goal_end_cos_sim_megalist = result_dict[
362
- cell_states_to_model["goal_state"]
363
- ].get((token, "cell_emb"), [])
364
- n_detections_dict[token] = len(goal_end_cos_sim_megalist)
365
- mean_goal_end = np.mean(goal_end_cos_sim_megalist)
366
- pval_goal_end = ranksums(
367
- goal_end_random_megalist, goal_end_cos_sim_megalist
368
- ).pvalue
369
-
370
- if alt_end_state_exists is True:
371
- alt_end_state_dict = {
372
- alt_state: [] for alt_state in cell_states_to_model["alt_states"]
373
- }
374
- for alt_state in cell_states_to_model["alt_states"]:
375
- alt_end_state_dict[alt_state] = result_dict[alt_state].get(
376
- (token, "cell_emb"), []
377
- )
378
- alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
379
- alt_end_state_dict[alt_state]
380
- )
381
- alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
382
- alt_end_state_random_dict[alt_state],
383
- alt_end_state_dict[alt_state],
384
- ).pvalue
385
 
386
- results_dict = dict()
387
- results_dict["Gene"] = token
388
- results_dict["Gene_name"] = name
389
- results_dict["Ensembl_ID"] = ensembl_id
390
- results_dict["Shift_to_goal_end"] = mean_goal_end
391
- results_dict["Goal_end_vs_random_pval"] = pval_goal_end
392
- if alt_end_state_exists is True:
393
- for alt_state in cell_states_to_model["alt_states"]:
394
- results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
395
- f"{alt_state}_mean"
396
- ]
397
- results_dict[
398
- f"Alt_end_vs_random_pval_{alt_state}"
399
- ] = alt_end_state_dict[f"{alt_state}_pval"]
400
 
401
- cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
402
- cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
 
 
 
 
 
403
 
404
- cos_sims_full_df["Goal_end_FDR"] = get_fdr(
405
- list(cos_sims_full_df["Goal_end_vs_random_pval"])
406
- )
407
- if alt_end_state_exists is True:
408
- for alt_state in cell_states_to_model["alt_states"]:
409
- cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
410
- list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
411
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
  # quantify number of detections of each gene
414
- cos_sims_full_df["N_Detections"] = [
415
- n_detections_dict[token] for token in cos_sims_full_df["Gene"]
416
- ]
417
-
418
- # sort by shift to desired state
419
- cos_sims_full_df["Sig"] = [
420
- 1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
421
- ]
422
- cos_sims_full_df = cos_sims_full_df.sort_values(
423
- by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
424
- ascending=[False, False, True],
425
- )
426
-
427
  return cos_sims_full_df
428
 
429
-
430
  # stats comparing cos sim shifts of test perturbations vs null distribution
431
  def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
432
  cos_sims_full_df = cos_sims_df.copy()
433
 
434
  cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
435
  cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
436
- cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
437
- cos_sims_df.shape[0], dtype=float
438
- )
439
  cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
440
  cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
441
- cos_sims_full_df["N_Detections_test"] = np.zeros(
442
- cos_sims_df.shape[0], dtype="uint32"
443
- )
444
- cos_sims_full_df["N_Detections_null"] = np.zeros(
445
- cos_sims_df.shape[0], dtype="uint32"
446
- )
447
-
448
  for i in trange(cos_sims_df.shape[0]):
449
  token = cos_sims_df["Gene"][i]
450
  test_shifts = []
451
  null_shifts = []
452
-
453
  for dict_i in dict_list:
454
- test_shifts += dict_i.get((token, "cell_emb"), [])
455
 
456
  for dict_i in null_dict_list:
457
- null_shifts += dict_i.get((token, "cell_emb"), [])
458
-
459
  cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
460
  cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
461
- cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
462
- test_shifts
463
- ) - np.mean(null_shifts)
464
- cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
465
- test_shifts, null_shifts, nan_policy="omit"
466
- ).pvalue
467
  # remove nan values
468
- cos_sims_full_df.Test_vs_null_pval = np.where(
469
- np.isnan(cos_sims_full_df.Test_vs_null_pval),
470
- 1,
471
- cos_sims_full_df.Test_vs_null_pval,
472
- )
473
  cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
474
  cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
475
 
476
- cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
477
- cos_sims_full_df["Test_vs_null_pval"]
478
- )
479
-
480
- cos_sims_full_df["Sig"] = [
481
- 1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
482
- ]
483
- cos_sims_full_df = cos_sims_full_df.sort_values(
484
- by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
485
- ascending=[False, False, True],
486
- )
487
  return cos_sims_full_df
488
 
489
-
490
  # stats for identifying perturbations with largest effect within a given set of cells
491
  # fits a mixture model to 2 components (impact vs. non-impact) and
492
  # reports the most likely component for each test perturbation
493
  # Note: because assumes given perturbation has a consistent effect in the cells tested,
494
  # we recommend only using the mixture model strategy with uniform cell populations
495
  def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
496
- names = ["Gene", "Gene_name", "Ensembl_ID"]
497
-
 
 
 
498
  if combos == 0:
499
  names += ["Test_avg_shift"]
500
  elif combos == 1:
501
- names += [
502
- "Anchor_shift",
503
- "Test_token_shift",
504
- "Sum_of_indiv_shifts",
505
- "Combo_shift",
506
- "Combo_minus_sum_shift",
507
- ]
508
-
509
- names += ["Impact_component", "Impact_component_percent"]
510
 
511
  cos_sims_full_df = pd.DataFrame(columns=names)
512
  avg_values = []
513
  gene_names = []
514
-
515
  for i in trange(cos_sims_df.shape[0]):
516
  token = cos_sims_df["Gene"][i]
517
  name = cos_sims_df["Gene_name"][i]
518
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
519
  cos_shift_data = []
520
-
521
  for dict_i in dict_list:
522
  if (combos == 0) and (anchor_token is not None):
523
- cos_shift_data += dict_i.get((anchor_token, token), [])
524
  else:
525
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
526
-
527
  # Extract values for current gene
528
  if combos == 0:
529
  test_values = cos_shift_data
530
  elif combos == 1:
531
  test_values = []
532
  for tup in cos_shift_data:
533
- test_values.append(tup[2])
534
-
535
  if len(test_values) > 0:
536
  avg_value = np.mean(test_values)
537
  avg_values.append(avg_value)
538
  gene_names.append(name)
539
-
540
  # fit Gaussian mixture model to dataset of mean for each gene
541
  avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
542
  gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
543
-
544
  for i in trange(cos_sims_df.shape[0]):
545
  token = cos_sims_df["Gene"][i]
546
  name = cos_sims_df["Gene_name"][i]
@@ -549,95 +380,72 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
549
 
550
  for dict_i in dict_list:
551
  if (combos == 0) and (anchor_token is not None):
552
- cos_shift_data += dict_i.get((anchor_token, token), [])
553
  else:
554
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
555
-
556
  if combos == 0:
557
  mean_test = np.mean(cos_shift_data)
558
- impact_components = [
559
- get_impact_component(value, gm) for value in cos_shift_data
560
- ]
561
  elif combos == 1:
562
- anchor_cos_sim_megalist = [
563
- anchor for anchor, token, combo in cos_shift_data
564
- ]
565
- token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
566
- anchor_plus_token_cos_sim_megalist = [
567
- 1 - ((1 - anchor) + (1 - token))
568
- for anchor, token, combo in cos_shift_data
569
- ]
570
- combo_anchor_token_cos_sim_megalist = [
571
- combo for anchor, token, combo in cos_shift_data
572
- ]
573
- combo_minus_sum_cos_sim_megalist = [
574
- combo - (1 - ((1 - anchor) + (1 - token)))
575
- for anchor, token, combo in cos_shift_data
576
- ]
577
 
578
  mean_anchor = np.mean(anchor_cos_sim_megalist)
579
  mean_token = np.mean(token_cos_sim_megalist)
580
  mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
581
  mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
582
  mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
583
-
584
- impact_components = [
585
- get_impact_component(value, gm)
586
- for value in combo_anchor_token_cos_sim_megalist
587
- ]
588
-
589
- impact_component = get_impact_component(mean_test, gm)
590
- impact_component_percent = np.mean(impact_components) * 100
591
-
592
- data_i = [token, name, ensembl_id]
593
  if combos == 0:
594
  data_i += [mean_test]
595
  elif combos == 1:
596
- data_i += [
597
- mean_anchor,
598
- mean_token,
599
- mean_sum,
600
- mean_test,
601
- mean_combo_minus_sum,
602
- ]
603
- data_i += [impact_component, impact_component_percent]
604
-
605
- cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
606
- cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
607
-
608
  # quantify number of detections of each gene
609
- cos_sims_full_df["N_Detections"] = [
610
- n_detections(i, dict_list, "gene", anchor_token)
611
- for i in cos_sims_full_df["Gene"]
612
- ]
613
-
614
  if combos == 0:
615
- cos_sims_full_df = cos_sims_full_df.sort_values(
616
- by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
617
- )
618
  elif combos == 1:
619
- cos_sims_full_df = cos_sims_full_df.sort_values(
620
- by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
621
- )
622
  return cos_sims_full_df
623
 
624
-
625
  class InSilicoPerturberStats:
626
  valid_option_dict = {
627
- "mode": {
628
- "goal_state_shift",
629
- "vs_null",
630
- "mixture_model",
631
- "aggregate_data",
632
- "aggregate_gene_shifts",
633
- },
634
- "genes_perturbed": {"all", list},
635
- "combos": {0, 1},
636
  "anchor_gene": {None, str},
637
  "cell_states_to_model": {None, dict},
638
- "pickle_suffix": {None, str},
639
  }
640
-
641
  def __init__(
642
  self,
643
  mode="mixture_model",
@@ -652,42 +460,41 @@ class InSilicoPerturberStats:
652
  """
653
  Initialize in silico perturber stats generator.
654
 
655
- **Parameters:**
656
-
657
- mode : {"goal_state_shift", "vs_null", "mixture_model", "aggregate_data", "aggregate_gene_shifts"}
658
- | Type of stats.
659
- | "goal_state_shift": perturbation vs. random for desired cell state shift
660
- | "vs_null": perturbation vs. null from provided null distribution dataset
661
- | "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
662
- | "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
663
- | "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
664
  genes_perturbed : "all", list
665
- | Genes perturbed in isp experiment.
666
- | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
667
- | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
668
  combos : {0,1,2}
669
- | Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
670
  anchor_gene : None, str
671
- | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
672
- | For example, if combos=1 and anchor_gene="ENSG00000136574":
673
- | analyzes data for anchor gene perturbed in combination with each other gene.
674
- | However, if combos=0 and anchor_gene="ENSG00000136574":
675
- | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
676
  cell_states_to_model: None, dict
677
- | Cell states to model if testing perturbations that achieve goal state change.
678
- | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
679
- | state_key: key specifying name of column in .dataset that defines the start/goal states
680
- | start_state: value in the state_key column that specifies the start state
681
- | goal_state: value in the state_key column taht specifies the goal end state
682
- | alt_states: list of values in the state_key column that specify the alternate end states
683
- | For example: {"state_key": "disease",
684
- | "start_state": "dcm",
685
- | "goal_state": "nf",
686
- | "alt_states": ["hcm", "other1", "other2"]}
687
  token_dictionary_file : Path
688
- | Path to pickle file containing token dictionary (Ensembl ID:token).
689
  gene_name_id_dictionary_file : Path
690
- | Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
691
  """
692
 
693
  self.mode = mode
@@ -696,13 +503,13 @@ class InSilicoPerturberStats:
696
  self.anchor_gene = anchor_gene
697
  self.cell_states_to_model = cell_states_to_model
698
  self.pickle_suffix = pickle_suffix
699
-
700
  self.validate_options()
701
 
702
  # load token dictionary (Ensembl IDs:token)
703
  with open(token_dictionary_file, "rb") as f:
704
  self.gene_token_dict = pickle.load(f)
705
-
706
  # load gene name dictionary (gene name:Ensembl ID)
707
  with open(gene_name_id_dictionary_file, "rb") as f:
708
  self.gene_name_id_dict = pickle.load(f)
@@ -713,7 +520,7 @@ class InSilicoPerturberStats:
713
  self.anchor_token = self.gene_token_dict[self.anchor_gene]
714
 
715
  def validate_options(self):
716
- for attr_name, valid_options in self.valid_option_dict.items():
717
  attr_value = self.__dict__[attr_name]
718
  if type(attr_value) not in {list, dict}:
719
  if attr_name in {"anchor_gene"}:
@@ -722,40 +529,35 @@ class InSilicoPerturberStats:
722
  continue
723
  valid_type = False
724
  for option in valid_options:
725
- if (option in [str, int, list, dict]) and isinstance(
726
- attr_value, option
727
- ):
728
  valid_type = True
729
  break
730
  if not valid_type:
731
  logger.error(
732
- f"Invalid option for {attr_name}. "
733
  f"Valid options for {attr_name}: {valid_options}"
734
  )
735
  raise
736
-
737
  if self.cell_states_to_model is not None:
738
  if len(self.cell_states_to_model.items()) == 1:
739
  logger.warning(
740
- "The single value dictionary for cell_states_to_model will be "
741
- "replaced with a dictionary with named keys for start, goal, and alternate states. "
742
- "Please specify state_key, start_state, goal_state, and alt_states "
743
- "in the cell_states_to_model dictionary for future use. "
744
- "For example, cell_states_to_model={"
745
- "'state_key': 'disease', "
746
- "'start_state': 'dcm', "
747
- "'goal_state': 'nf', "
748
- "'alt_states': ['hcm', 'other1', 'other2']}"
749
  )
750
- for key, value in self.cell_states_to_model.items():
751
  if (len(value) == 3) and isinstance(value, tuple):
752
- if (
753
- isinstance(value[0], list)
754
- and isinstance(value[1], list)
755
- and isinstance(value[2], list)
756
- ):
757
  if len(value[0]) == 1 and len(value[1]) == 1:
758
- all_values = value[0] + value[1] + value[2]
759
  if len(all_values) == len(set(all_values)):
760
  continue
761
  # reformat to the new named key format
@@ -764,176 +566,140 @@ class InSilicoPerturberStats:
764
  "state_key": list(self.cell_states_to_model.keys())[0],
765
  "start_state": state_values[0][0],
766
  "goal_state": state_values[1][0],
767
- "alt_states": state_values[2:][0],
768
  }
769
- elif set(self.cell_states_to_model.keys()) == {
770
- "state_key",
771
- "start_state",
772
- "goal_state",
773
- "alt_states",
774
- }:
775
- if (
776
- (self.cell_states_to_model["state_key"] is None)
777
- or (self.cell_states_to_model["start_state"] is None)
778
- or (self.cell_states_to_model["goal_state"] is None)
779
- ):
780
  logger.error(
781
- "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
782
- )
783
  raise
784
-
785
- if (
786
- self.cell_states_to_model["start_state"]
787
- == self.cell_states_to_model["goal_state"]
788
- ):
789
- logger.error("All states must be unique.")
790
  raise
791
 
792
  if self.cell_states_to_model["alt_states"] is not None:
793
- if not isinstance(self.cell_states_to_model["alt_states"], list):
794
  logger.error(
795
  "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
796
  )
797
  raise
798
- if len(self.cell_states_to_model["alt_states"]) != len(
799
- set(self.cell_states_to_model["alt_states"])
800
- ):
801
- logger.error("All states must be unique.")
802
  raise
803
 
804
- elif set(self.cell_states_to_model.keys()) == {
805
- "state_key",
806
- "start_state",
807
- "goal_state",
808
- }:
809
- self.cell_states_to_model["alt_states"] = []
810
  else:
811
  logger.error(
812
- "cell_states_to_model must only have the following four keys: "
813
- "'state_key', 'start_state', 'goal_state', 'alt_states'."
814
- "For example, cell_states_to_model={"
815
- "'state_key': 'disease', "
816
- "'start_state': 'dcm', "
817
- "'goal_state': 'nf', "
818
- "'alt_states': ['hcm', 'other1', 'other2']}"
819
  )
820
  raise
821
 
822
  if self.anchor_gene is not None:
823
  self.anchor_gene = None
824
  logger.warning(
825
- "anchor_gene set to None. "
826
- "Currently, anchor gene not available "
827
- "when modeling multiple cell states."
828
- )
829
-
830
  if self.combos > 0:
831
  if self.anchor_gene is None:
832
  logger.error(
833
- "Currently, stats are only supported for combination "
834
- "in silico perturbation run with anchor gene. Please add "
835
- "anchor gene when using with combos > 0. "
836
- )
837
  raise
838
-
839
  if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
840
  logger.error(
841
- "Mixture model mode requires multiple gene perturbations to fit model "
842
- "so is incompatible with a single grouped perturbation."
843
- )
844
  raise
845
  if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
846
  logger.error(
847
- "Simple data aggregation mode is for single perturbation in multiple cells "
848
- "so is incompatible with a genes_perturbed being 'all'."
849
- )
850
- raise
851
-
852
- def get_stats(
853
- self,
854
- input_data_directory,
855
- null_dist_data_directory,
856
- output_directory,
857
- output_prefix,
858
- null_dict_list=None,
859
- ):
860
  """
861
  Get stats for in silico perturbation data and save as results in output_directory.
862
 
863
- **Parameters:**
864
-
865
  input_data_directory : Path
866
- | Path to directory containing cos_sim dictionary inputs
867
  null_dist_data_directory : Path
868
- | Path to directory containing null distribution cos_sim dictionary inputs
869
  output_directory : Path
870
- | Path to directory where perturbation data will be saved as .csv
871
  output_prefix : str
872
- | Prefix for output .csv
873
- null_dict_list: list[dict]
874
- | List of loaded null distribution dictionary if more than one comparison vs. the null is to be performed
875
-
876
- **Outputs:**
877
-
878
  Definition of possible columns in .csv output file.
879
-
880
- | Of note, not all columns will be present in all output files.
881
- | Some columns are specific to particular perturbation modes.
882
-
883
- | "Gene": gene token
884
- | "Gene_name": gene name
885
- | "Ensembl_ID": gene Ensembl ID
886
- | "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
887
- | "Sig": 1 if FDR<0.05, otherwise 0
888
-
889
- | "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
890
- | "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
891
- | "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
892
- | pvalue compares shift caused by perturbing given gene compared to random genes
893
- | "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
894
- | pvalue compares shift caused by perturbing given gene compared to random genes
895
- | "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
896
- | "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
897
-
898
- | "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
899
- | "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
900
- | "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
901
- | (i.e. "Test_avg_shift" minus "Null_avg_shift")
902
- | "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
903
- | "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
904
- | "N_Detections_test": "N_Detections" in cells from test distribution
905
- | "N_Detections_null": "N_Detections" in cells from null distribution
906
-
907
- | "Anchor_shift": cosine shift in response to given perturbation of anchor gene
908
- | "Test_token_shift": cosine shift in response to given perturbation of test gene
909
- | "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
910
- | "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
911
- | "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
912
- | (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
913
- | "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
914
- | 1: within impact component; 0: not within impact component
915
- | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
916
-
917
- | In case of aggregating gene shifts:
918
- | "Perturbed": ID(s) of gene(s) being perturbed
919
- | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
920
- | "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
921
- | "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
922
  """
923
 
924
- if self.mode not in [
925
- "goal_state_shift",
926
- "vs_null",
927
- "mixture_model",
928
- "aggregate_data",
929
- "aggregate_gene_shifts",
930
- ]:
931
  logger.error(
932
- "Currently, only modes available are stats for goal_state_shift, "
933
- "vs_null (comparing to null distribution), "
934
- "mixture_model (fitting mixture model for perturbations with or without impact), "
935
- "and aggregating data for single perturbations or for gene embedding shifts."
936
- )
937
  raise
938
 
939
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
@@ -942,107 +708,45 @@ class InSilicoPerturberStats:
942
  # obtain total gene list
943
  if (self.combos == 0) and (self.anchor_token is not None):
944
  # cos sim data for effect of gene perturbation on the embedding of each other gene
945
- dict_list = read_dictionaries(
946
- input_data_directory,
947
- "gene",
948
- self.anchor_token,
949
- self.cell_states_to_model,
950
- self.pickle_suffix,
951
- )
952
  gene_list = get_gene_list(dict_list, "gene")
953
- elif (
954
- (self.combos == 0)
955
- and (self.anchor_token is None)
956
- and (self.mode == "aggregate_gene_shifts")
957
- ):
958
- dict_list = read_dictionaries(
959
- input_data_directory,
960
- "gene",
961
- self.anchor_token,
962
- self.cell_states_to_model,
963
- self.pickle_suffix,
964
- )
965
- gene_list = get_gene_list(dict_list, "cell")
966
  else:
967
  # cos sim data for effect of gene perturbation on the embedding of each cell
968
- dict_list = read_dictionaries(
969
- input_data_directory,
970
- "cell",
971
- self.anchor_token,
972
- self.cell_states_to_model,
973
- self.pickle_suffix,
974
- )
975
  gene_list = get_gene_list(dict_list, "cell")
976
-
977
  # initiate results dataframe
978
- cos_sims_df_initial = pd.DataFrame(
979
- {
980
- "Gene": gene_list,
981
- "Gene_name": [self.token_to_gene_name(item) for item in gene_list],
982
- "Ensembl_ID": [
983
- token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
984
- if self.genes_perturbed != "all"
985
- else self.gene_token_id_dict[genes[1]]
986
- if isinstance(genes, tuple)
987
- else self.gene_token_id_dict[genes]
988
- for genes in gene_list
989
- ],
990
- },
991
- index=[i for i in range(len(gene_list))],
992
- )
993
 
994
  if self.mode == "goal_state_shift":
995
- cos_sims_df = isp_stats_to_goal_state(
996
- cos_sims_df_initial,
997
- dict_list,
998
- self.cell_states_to_model,
999
- self.genes_perturbed,
1000
- )
1001
-
1002
  elif self.mode == "vs_null":
1003
  if null_dict_list is None:
1004
- null_dict_list = read_dictionaries(
1005
- null_dist_data_directory,
1006
- "cell",
1007
- self.anchor_token,
1008
- self.cell_states_to_model,
1009
- self.pickle_suffix,
1010
- )
1011
- cos_sims_df = isp_stats_vs_null(
1012
- cos_sims_df_initial, dict_list, null_dict_list
1013
- )
1014
 
1015
  elif self.mode == "mixture_model":
1016
- cos_sims_df = isp_stats_mixture_model(
1017
- cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1018
- )
1019
-
1020
  elif self.mode == "aggregate_data":
1021
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
1022
 
1023
- elif self.mode == "aggregate_gene_shifts":
1024
- cos_sims_df = isp_aggregate_gene_shifts(
1025
- cos_sims_df_initial,
1026
- dict_list,
1027
- self.gene_token_id_dict,
1028
- self.gene_id_name_dict,
1029
- )
1030
-
1031
  # save perturbation stats to output_path
1032
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
1033
  cos_sims_df.to_csv(output_path)
1034
 
1035
  def token_to_gene_name(self, item):
1036
- if np.issubdtype(type(item), np.integer):
1037
- return self.gene_id_name_dict.get(
1038
- self.gene_token_id_dict.get(item, np.nan), np.nan
1039
- )
1040
- if isinstance(item, tuple):
1041
- return tuple(
1042
- [
1043
- self.gene_id_name_dict.get(
1044
- self.gene_token_id_dict.get(i, np.nan), np.nan
1045
- )
1046
- for i in item
1047
- ]
1048
- )
 
1
  """
2
  Geneformer in silico perturber stats generator.
3
 
4
+ Usage:
5
+ from geneformer import InSilicoPerturberStats
6
+ ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
+ combos=0,
8
+ anchor_gene=None,
9
+ cell_states_to_model={"state_key": "disease",
10
+ "start_state": "dcm",
11
+ "goal_state": "nf",
12
+ "alt_states": ["hcm", "other1", "other2"]})
13
+ ispstats.get_stats("path/to/input_data",
14
+ None,
15
+ "path/to/output_directory",
16
+ "output_prefix")
 
 
 
 
 
 
 
17
  """
18
 
19
 
 
20
  import os
21
+ import logging
 
 
 
22
  import numpy as np
23
  import pandas as pd
24
+ import pickle
25
+ import random
26
  import statsmodels.stats.multitest as smt
27
+ from pathlib import Path
28
  from scipy.stats import ranksums
29
  from sklearn.mixture import GaussianMixture
30
+ from tqdm.auto import trange, tqdm
31
+
32
+ from .perturber_helpers import flatten_list
33
 
 
34
  from .tokenizer import TOKEN_DICTIONARY_FILE
35
 
36
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
37
 
38
  logger = logging.getLogger(__name__)
39
 
 
40
  # invert dictionary keys/values
41
  def invert_dict(dictionary):
42
  return {v: k for k, v in dictionary.items()}
43
 
 
44
  def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
45
  if cell_or_gene_emb == "cell":
46
+ cell_emb_dict = {k: v for k,
47
+ v in cos_sims_dict.items() if v and "cell_emb" in k}
 
48
  return [cell_emb_dict]
49
  elif cell_or_gene_emb == "gene":
50
+ gene_emb_dict = {k: v for k,
51
+ v in cos_sims_dict.items() if v and anchor_token == k[0]}
 
 
 
 
52
  return [gene_emb_dict]
53
 
54
 
55
+ def recursive_search_dir(dir, pickle_suffix):
56
+
57
+
58
  # read raw dictionary files
59
+ def read_dictionaries(input_data_directory,
60
+ cell_or_gene_emb,
61
+ anchor_token,
62
+ cell_states_to_model,
63
+ pickle_suffix,
64
+ recursive=False):
65
+
66
  file_found = False
67
  file_path_list = []
68
  if cell_states_to_model is None:
69
  dict_list = []
70
  else:
71
+ state_dict = {state: [] for state in cell_states_to_model}
72
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  for file in os.listdir(input_data_directory):
74
+ # process only _raw.pickle files
75
  if file.endswith(pickle_suffix):
76
  file_found = True
77
  file_path_list += [f"{input_data_directory}/{file}"]
78
  for file_path in tqdm(file_path_list):
79
+ with open(file_path, 'rb') as fp:
80
  cos_sims_dict = pickle.load(fp)
81
  if cell_states_to_model is None:
82
  dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
83
  else:
84
+ for state in cell_states_to_model:
85
+ state_dict[state] += read_dict(cos_sims_dict[state], cell_or_gene_emb, anchor_token)
 
 
 
 
 
 
 
86
  if not file_found:
87
  logger.error(
88
+ f"No raw data for processing found within provided directory. " \
89
+ "Please ensure data files end with '{pickle_suffix}'.")
 
90
  raise
91
  if cell_states_to_model is None:
92
  return dict_list
93
  else:
94
  return state_dict
95
 
 
96
  # get complete gene list
97
+ def get_gene_list(dict_list,mode):
98
  if mode == "cell":
99
  position = 0
100
  elif mode == "gene":
101
  position = 1
102
  gene_set = set()
103
+ for dict_i in dict_list:
104
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
 
 
 
 
 
 
 
 
 
 
105
  gene_list = list(gene_set)
106
  if mode == "gene":
107
  gene_list.remove("cell_emb")
108
  gene_list.sort()
109
  return gene_list
110
 
 
111
  def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
112
  try:
113
  return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
114
+ except TypeError as te:
115
+ return tuple(gene_token_id_dict.get(token_tuple, np.nan))
 
116
 
117
  def n_detections(token, dict_list, mode, anchor_token):
118
  cos_sim_megalist = []
119
  for dict_i in dict_list:
120
  if mode == "cell":
121
+ cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
122
  elif mode == "gene":
123
+ cos_sim_megalist += dict_i.get((anchor_token, token),[])
124
  return len(cos_sim_megalist)
125
 
 
126
  def get_fdr(pvalues):
127
  return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
128
 
 
129
  def get_impact_component(test_value, gaussian_mixture_model):
130
  impact_border = gaussian_mixture_model.means_[0][0]
131
  nonimpact_border = gaussian_mixture_model.means_[1][0]
 
141
  impact_component = 1
142
  return impact_component
143
 
 
144
  # aggregate data for single perturbation in multiple cells
145
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
146
+ names=["Cosine_shift"]
147
  cos_sims_full_df = pd.DataFrame(columns=names)
148
 
149
  cos_shift_data = []
150
  token = cos_sims_df["Gene"][0]
151
  for dict_i in dict_list:
152
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
153
  cos_sims_full_df["Cosine_shift"] = cos_shift_data
154
+ return cos_sims_full_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
157
+ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
158
+ cell_state_key = cell_states_to_model["start_state"]
159
+ if ("alt_states" not in cell_states_to_model.keys()) \
160
+ or (len(cell_states_to_model["alt_states"]) == 0) \
161
+ or (cell_states_to_model["alt_states"] == [None]):
 
 
 
162
  alt_end_state_exists = False
163
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
 
 
164
  alt_end_state_exists = True
165
+
166
  # for single perturbation in multiple cells, there are no random perturbations to compare to
167
  if genes_perturbed != "all":
168
+ names=["Shift_to_goal_end",
169
+ "Shift_to_alt_end"]
170
+ if alt_end_state_exists == False:
171
+ names.remove("Shift_to_alt_end")
172
+ cos_sims_full_df = pd.DataFrame(columns=names)
173
+
174
+ cos_shift_data = []
175
  token = cos_sims_df["Gene"][0]
176
+ for dict_i in dict_list:
177
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
178
+ if alt_end_state_exists == False:
179
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data]
180
+ if alt_end_state_exists == True:
181
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
182
+ cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
183
+
 
 
 
 
 
 
184
  # sort by shift to desired state
185
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"],
186
+ ascending=[False])
187
+ return cos_sims_full_df
188
+
 
189
  elif genes_perturbed == "all":
190
+ random_tuples = []
 
 
 
 
191
  for i in trange(cos_sims_df.shape[0]):
192
  token = cos_sims_df["Gene"][i]
193
+ for dict_i in dict_list:
194
+ random_tuples += dict_i.get((token, "cell_emb"),[])
195
+
196
+ if alt_end_state_exists == False:
197
+ goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
198
+ elif alt_end_state_exists == True:
199
+ goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
200
+ alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
201
 
202
  # downsample to improve speed of ranksums
203
  if len(goal_end_random_megalist) > 100_000:
204
  random.seed(42)
205
+ goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
206
+ if alt_end_state_exists == True:
207
+ if len(alt_end_random_megalist) > 100_000:
208
+ random.seed(42)
209
+ alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
210
+
211
+ names=["Gene",
212
+ "Gene_name",
213
+ "Ensembl_ID",
214
+ "Shift_to_goal_end",
215
+ "Shift_to_alt_end",
216
+ "Goal_end_vs_random_pval",
217
+ "Alt_end_vs_random_pval"]
218
+ if alt_end_state_exists == False:
219
+ names.remove("Shift_to_alt_end")
220
+ names.remove("Alt_end_vs_random_pval")
 
 
 
 
 
 
 
 
 
 
 
 
221
  cos_sims_full_df = pd.DataFrame(columns=names)
222
 
 
223
  for i in trange(cos_sims_df.shape[0]):
224
  token = cos_sims_df["Gene"][i]
225
  name = cos_sims_df["Gene_name"][i]
226
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
227
+ cos_shift_data = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ for dict_i in dict_list:
230
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
+ if alt_end_state_exists == False:
233
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
234
+ elif alt_end_state_exists == True:
235
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
236
+ alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
237
+ mean_alt_end = np.mean(alt_end_cos_sim_megalist)
238
+ pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
239
 
240
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
241
+ pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
242
+
243
+ if alt_end_state_exists == False:
244
+ data_i = [token,
245
+ name,
246
+ ensembl_id,
247
+ mean_goal_end,
248
+ pval_goal_end]
249
+ elif alt_end_state_exists == True:
250
+ data_i = [token,
251
+ name,
252
+ ensembl_id,
253
+ mean_goal_end,
254
+ mean_alt_end,
255
+ pval_goal_end,
256
+ pval_alt_end]
257
+
258
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
259
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
260
+
261
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
262
+ if alt_end_state_exists == True:
263
+ cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
264
 
265
  # quantify number of detections of each gene
266
+ cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
267
+
268
+ # sort by shift to desired state\
269
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
270
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
271
+ "Shift_to_goal_end",
272
+ "Goal_end_FDR"],
273
+ ascending=[False,False,True])
274
+
 
 
 
 
275
  return cos_sims_full_df
276
 
 
277
  # stats comparing cos sim shifts of test perturbations vs null distribution
278
  def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
279
  cos_sims_full_df = cos_sims_df.copy()
280
 
281
  cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
282
  cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
283
+ cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
 
 
284
  cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
285
  cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
286
+ cos_sims_full_df["N_Detections_test"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
287
+ cos_sims_full_df["N_Detections_null"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
288
+
 
 
 
 
289
  for i in trange(cos_sims_df.shape[0]):
290
  token = cos_sims_df["Gene"][i]
291
  test_shifts = []
292
  null_shifts = []
293
+
294
  for dict_i in dict_list:
295
+ test_shifts += dict_i.get((token, "cell_emb"),[])
296
 
297
  for dict_i in null_dict_list:
298
+ null_shifts += dict_i.get((token, "cell_emb"),[])
299
+
300
  cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
301
  cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
302
+ cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(test_shifts)-np.mean(null_shifts)
303
+ cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(test_shifts,
304
+ null_shifts, nan_policy="omit").pvalue
 
 
 
305
  # remove nan values
306
+ cos_sims_full_df.Test_vs_null_pval = np.where(np.isnan(cos_sims_full_df.Test_vs_null_pval), 1, cos_sims_full_df.Test_vs_null_pval)
 
 
 
 
307
  cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
308
  cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
309
 
310
+ cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
311
+
312
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
313
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
314
+ "Test_vs_null_avg_shift",
315
+ "Test_vs_null_FDR"],
316
+ ascending=[False,False,True])
 
 
 
 
317
  return cos_sims_full_df
318
 
 
319
  # stats for identifying perturbations with largest effect within a given set of cells
320
  # fits a mixture model to 2 components (impact vs. non-impact) and
321
  # reports the most likely component for each test perturbation
322
  # Note: because assumes given perturbation has a consistent effect in the cells tested,
323
  # we recommend only using the mixture model strategy with uniform cell populations
324
  def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
325
+
326
+ names=["Gene",
327
+ "Gene_name",
328
+ "Ensembl_ID"]
329
+
330
  if combos == 0:
331
  names += ["Test_avg_shift"]
332
  elif combos == 1:
333
+ names += ["Anchor_shift",
334
+ "Test_token_shift",
335
+ "Sum_of_indiv_shifts",
336
+ "Combo_shift",
337
+ "Combo_minus_sum_shift"]
338
+
339
+ names += ["Impact_component",
340
+ "Impact_component_percent"]
 
341
 
342
  cos_sims_full_df = pd.DataFrame(columns=names)
343
  avg_values = []
344
  gene_names = []
345
+
346
  for i in trange(cos_sims_df.shape[0]):
347
  token = cos_sims_df["Gene"][i]
348
  name = cos_sims_df["Gene_name"][i]
349
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
350
  cos_shift_data = []
351
+
352
  for dict_i in dict_list:
353
  if (combos == 0) and (anchor_token is not None):
354
+ cos_shift_data += dict_i.get((anchor_token, token),[])
355
  else:
356
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
357
+
358
  # Extract values for current gene
359
  if combos == 0:
360
  test_values = cos_shift_data
361
  elif combos == 1:
362
  test_values = []
363
  for tup in cos_shift_data:
364
+ test_values.append(tup[2])
365
+
366
  if len(test_values) > 0:
367
  avg_value = np.mean(test_values)
368
  avg_values.append(avg_value)
369
  gene_names.append(name)
370
+
371
  # fit Gaussian mixture model to dataset of mean for each gene
372
  avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
373
  gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
374
+
375
  for i in trange(cos_sims_df.shape[0]):
376
  token = cos_sims_df["Gene"][i]
377
  name = cos_sims_df["Gene_name"][i]
 
380
 
381
  for dict_i in dict_list:
382
  if (combos == 0) and (anchor_token is not None):
383
+ cos_shift_data += dict_i.get((anchor_token, token),[])
384
  else:
385
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
386
+
387
  if combos == 0:
388
  mean_test = np.mean(cos_shift_data)
389
+ impact_components = [get_impact_component(value,gm) for value in cos_shift_data]
 
 
390
  elif combos == 1:
391
+ anchor_cos_sim_megalist = [anchor for anchor,token,combo in cos_shift_data]
392
+ token_cos_sim_megalist = [token for anchor,token,combo in cos_shift_data]
393
+ anchor_plus_token_cos_sim_megalist = [1-((1-anchor)+(1-token)) for anchor,token,combo in cos_shift_data]
394
+ combo_anchor_token_cos_sim_megalist = [combo for anchor,token,combo in cos_shift_data]
395
+ combo_minus_sum_cos_sim_megalist = [combo-(1-((1-anchor)+(1-token))) for anchor,token,combo in cos_shift_data]
 
 
 
 
 
 
 
 
 
 
396
 
397
  mean_anchor = np.mean(anchor_cos_sim_megalist)
398
  mean_token = np.mean(token_cos_sim_megalist)
399
  mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
400
  mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
401
  mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
402
+
403
+ impact_components = [get_impact_component(value,gm) for value in combo_anchor_token_cos_sim_megalist]
404
+
405
+ impact_component = get_impact_component(mean_test,gm)
406
+ impact_component_percent = np.mean(impact_components)*100
407
+
408
+ data_i = [token,
409
+ name,
410
+ ensembl_id]
 
411
  if combos == 0:
412
  data_i += [mean_test]
413
  elif combos == 1:
414
+ data_i += [mean_anchor,
415
+ mean_token,
416
+ mean_sum,
417
+ mean_test,
418
+ mean_combo_minus_sum]
419
+ data_i += [impact_component,
420
+ impact_component_percent]
421
+
422
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
423
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
424
+
 
425
  # quantify number of detections of each gene
426
+ cos_sims_full_df["N_Detections"] = [n_detections(i,
427
+ dict_list,
428
+ "gene",
429
+ anchor_token) for i in cos_sims_full_df["Gene"]]
430
+
431
  if combos == 0:
432
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
433
+ "Test_avg_shift"],
434
+ ascending=[False,True])
435
  elif combos == 1:
436
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
437
+ "Combo_minus_sum_shift"],
438
+ ascending=[False,True])
439
  return cos_sims_full_df
440
 
 
441
  class InSilicoPerturberStats:
442
  valid_option_dict = {
443
+ "mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"},
444
+ "combos": {0,1},
 
 
 
 
 
 
 
445
  "anchor_gene": {None, str},
446
  "cell_states_to_model": {None, dict},
447
+ "pickle_suffix": {None, str}
448
  }
 
449
  def __init__(
450
  self,
451
  mode="mixture_model",
 
460
  """
461
  Initialize in silico perturber stats generator.
462
 
463
+ Parameters
464
+ ----------
465
+ mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
466
+ Type of stats.
467
+ "goal_state_shift": perturbation vs. random for desired cell state shift
468
+ "vs_null": perturbation vs. null from provided null distribution dataset
469
+ "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
470
+ "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
 
471
  genes_perturbed : "all", list
472
+ Genes perturbed in isp experiment.
473
+ Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
474
+ Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
475
  combos : {0,1,2}
476
+ Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
477
  anchor_gene : None, str
478
+ ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
479
+ For example, if combos=1 and anchor_gene="ENSG00000136574":
480
+ analyzes data for anchor gene perturbed in combination with each other gene.
481
+ However, if combos=0 and anchor_gene="ENSG00000136574":
482
+ analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
483
  cell_states_to_model: None, dict
484
+ Cell states to model if testing perturbations that achieve goal state change.
485
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
486
+ state_key: key specifying name of column in .dataset that defines the start/goal states
487
+ start_state: value in the state_key column that specifies the start state
488
+ goal_state: value in the state_key column taht specifies the goal end state
489
+ alt_states: list of values in the state_key column that specify the alternate end states
490
+ For example: {"state_key": "disease",
491
+ "start_state": "dcm",
492
+ "goal_state": "nf",
493
+ "alt_states": ["hcm", "other1", "other2"]}
494
  token_dictionary_file : Path
495
+ Path to pickle file containing token dictionary (Ensembl ID:token).
496
  gene_name_id_dictionary_file : Path
497
+ Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
498
  """
499
 
500
  self.mode = mode
 
503
  self.anchor_gene = anchor_gene
504
  self.cell_states_to_model = cell_states_to_model
505
  self.pickle_suffix = pickle_suffix
506
+
507
  self.validate_options()
508
 
509
  # load token dictionary (Ensembl IDs:token)
510
  with open(token_dictionary_file, "rb") as f:
511
  self.gene_token_dict = pickle.load(f)
512
+
513
  # load gene name dictionary (gene name:Ensembl ID)
514
  with open(gene_name_id_dictionary_file, "rb") as f:
515
  self.gene_name_id_dict = pickle.load(f)
 
520
  self.anchor_token = self.gene_token_dict[self.anchor_gene]
521
 
522
  def validate_options(self):
523
+ for attr_name,valid_options in self.valid_option_dict.items():
524
  attr_value = self.__dict__[attr_name]
525
  if type(attr_value) not in {list, dict}:
526
  if attr_name in {"anchor_gene"}:
 
529
  continue
530
  valid_type = False
531
  for option in valid_options:
532
+ # not sure what the last check is for?
533
+ if isinstance(attr_value, option): # and (option in [int,list,dict]):
 
534
  valid_type = True
535
  break
536
  if not valid_type:
537
  logger.error(
538
+ f"Invalid option for {attr_name}. " \
539
  f"Valid options for {attr_name}: {valid_options}"
540
  )
541
  raise
542
+
543
  if self.cell_states_to_model is not None:
544
  if len(self.cell_states_to_model.items()) == 1:
545
  logger.warning(
546
+ "The single value dictionary for cell_states_to_model will be " \
547
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
548
+ "Please specify state_key, start_state, goal_state, and alt_states " \
549
+ "in the cell_states_to_model dictionary for future use. " \
550
+ "For example, cell_states_to_model={" \
551
+ "'state_key': 'disease', " \
552
+ "'start_state': 'dcm', " \
553
+ "'goal_state': 'nf', " \
554
+ "'alt_states': ['hcm', 'other1', 'other2']}"
555
  )
556
+ for key,value in self.cell_states_to_model.items():
557
  if (len(value) == 3) and isinstance(value, tuple):
558
+ if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
 
 
 
559
  if len(value[0]) == 1 and len(value[1]) == 1:
560
+ all_values = value[0]+value[1]+value[2]
561
  if len(all_values) == len(set(all_values)):
562
  continue
563
  # reformat to the new named key format
 
566
  "state_key": list(self.cell_states_to_model.keys())[0],
567
  "start_state": state_values[0][0],
568
  "goal_state": state_values[1][0],
569
+ "alt_states": state_values[2:][0]
570
  }
571
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
572
+ if (self.cell_states_to_model["state_key"] is None) \
573
+ or (self.cell_states_to_model["start_state"] is None) \
574
+ or (self.cell_states_to_model["goal_state"] is None):
 
 
 
 
 
 
 
575
  logger.error(
576
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
 
577
  raise
578
+
579
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
580
+ logger.error(
581
+ "All states must be unique.")
 
 
582
  raise
583
 
584
  if self.cell_states_to_model["alt_states"] is not None:
585
+ if type(self.cell_states_to_model["alt_states"]) is not list:
586
  logger.error(
587
  "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
588
  )
589
  raise
590
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
591
+ logger.error(
592
+ "All states must be unique.")
 
593
  raise
594
 
 
 
 
 
 
 
595
  else:
596
  logger.error(
597
+ "cell_states_to_model must only have the following four keys: " \
598
+ "'state_key', 'start_state', 'goal_state', 'alt_states'." \
599
+ "For example, cell_states_to_model={" \
600
+ "'state_key': 'disease', " \
601
+ "'start_state': 'dcm', " \
602
+ "'goal_state': 'nf', " \
603
+ "'alt_states': ['hcm', 'other1', 'other2']}"
604
  )
605
  raise
606
 
607
  if self.anchor_gene is not None:
608
  self.anchor_gene = None
609
  logger.warning(
610
+ "anchor_gene set to None. " \
611
+ "Currently, anchor gene not available " \
612
+ "when modeling multiple cell states.")
613
+
 
614
  if self.combos > 0:
615
  if self.anchor_gene is None:
616
  logger.error(
617
+ "Currently, stats are only supported for combination " \
618
+ "in silico perturbation run with anchor gene. Please add " \
619
+ "anchor gene when using with combos > 0. ")
 
620
  raise
621
+
622
  if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
623
  logger.error(
624
+ "Mixture model mode requires multiple gene perturbations to fit model " \
625
+ "so is incompatible with a single grouped perturbation.")
 
626
  raise
627
  if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
628
  logger.error(
629
+ "Simple data aggregation mode is for single perturbation in multiple cells " \
630
+ "so is incompatible with a genes_perturbed being 'all'.")
631
+ raise
632
+
633
+ def get_stats(self,
634
+ input_data_directory,
635
+ null_dist_data_directory,
636
+ output_directory,
637
+ output_prefix,
638
+ null_dict_list=None,
639
+ recursive=False):
 
 
640
  """
641
  Get stats for in silico perturbation data and save as results in output_directory.
642
 
643
+ Parameters
644
+ ----------
645
  input_data_directory : Path
646
+ Path to directory containing cos_sim dictionary inputs
647
  null_dist_data_directory : Path
648
+ Path to directory containing null distribution cos_sim dictionary inputs
649
  output_directory : Path
650
+ Path to directory where perturbation data will be saved as .csv
651
  output_prefix : str
652
+ Prefix for output .csv
653
+ null_dict_list: dict
654
+ List of loaded null distribtion dictionary if more than one comparison vs. the null is to be performed
655
+
656
+ Outputs
657
+ ----------
658
  Definition of possible columns in .csv output file.
659
+
660
+ Of note, not all columns will be present in all output files.
661
+ Some columns are specific to particular perturbation modes.
662
+
663
+ "Gene": gene token
664
+ "Gene_name": gene name
665
+ "Ensembl_ID": gene Ensembl ID
666
+ "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
667
+ "Sig": 1 if FDR<0.05, otherwise 0
668
+
669
+ "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
670
+ "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
671
+ "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
672
+ pvalue compares shift caused by perturbing given gene compared to random genes
673
+ "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
674
+ pvalue compares shift caused by perturbing given gene compared to random genes
675
+ "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
676
+ "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
677
+
678
+ "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
679
+ "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
680
+ "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
681
+ (i.e. "Test_avg_shift" minus "Null_avg_shift")
682
+ "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
683
+ "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
684
+ "N_Detections_test": "N_Detections" in cells from test distribution
685
+ "N_Detections_null": "N_Detections" in cells from null distribution
686
+
687
+ "Anchor_shift": cosine shift in response to given perturbation of anchor gene
688
+ "Test_token_shift": cosine shift in response to given perturbation of test gene
689
+ "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
690
+ "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
691
+ "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
692
+ (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
693
+ "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
694
+ 1: within impact component; 0: not within impact component
695
+ "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
 
 
 
 
 
 
696
  """
697
 
698
+ if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
 
 
 
 
 
 
699
  logger.error(
700
+ "Currently, only modes available are stats for goal_state_shift, " \
701
+ "vs_null (comparing to null distribution), and " \
702
+ "mixture_model (fitting mixture model for perturbations with or without impact).")
 
 
703
  raise
704
 
705
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
 
708
  # obtain total gene list
709
  if (self.combos == 0) and (self.anchor_token is not None):
710
  # cos sim data for effect of gene perturbation on the embedding of each other gene
711
+ dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token, self.cell_states_to_model, self.pickle_suffix, recursive=recursive)
 
 
 
 
 
 
712
  gene_list = get_gene_list(dict_list, "gene")
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  else:
714
  # cos sim data for effect of gene perturbation on the embedding of each cell
715
+ dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token, self.cell_states_to_model, self.pickle_suffix, recursive=recursive)
 
 
 
 
 
 
716
  gene_list = get_gene_list(dict_list, "cell")
717
+
718
  # initiate results dataframe
719
+ cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
720
+ "Gene_name": [self.token_to_gene_name(item) \
721
+ for item in gene_list],
722
+ "Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
723
+ if self.genes_perturbed != "all" else \
724
+ self.gene_token_id_dict[genes[1]] \
725
+ if isinstance(genes,tuple) else \
726
+ self.gene_token_id_dict[genes] \
727
+ for genes in gene_list]}, \
728
+ index=[i for i in range(len(gene_list))])
 
 
 
 
 
729
 
730
  if self.mode == "goal_state_shift":
731
+ cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
732
+
 
 
 
 
 
733
  elif self.mode == "vs_null":
734
  if null_dict_list is None:
735
+ null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token, self.cell_states_to_model, self.pickle_suffix)
736
+ cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
 
 
 
 
 
 
 
 
737
 
738
  elif self.mode == "mixture_model":
739
+ cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
740
+
 
 
741
  elif self.mode == "aggregate_data":
742
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
743
 
 
 
 
 
 
 
 
 
744
  # save perturbation stats to output_path
745
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
746
  cos_sims_df.to_csv(output_path)
747
 
748
  def token_to_gene_name(self, item):
749
+ if isinstance(item,int):
750
+ return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
751
+ if isinstance(item,tuple):
752
+ return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])