Fixed error with perturbing individual genes and updated ways to specify cell_states_to_model
geneformer/in_silico_perturber.py  CHANGED  (+138 −34)
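For context: this change replaces the old single-key `cell_states_to_model` dictionary with an explicit four-key dictionary. A minimal sketch of the two formats, inferred from the validation and reformatting logic in the diff below (the `'disease'`/`'dcm'`/`'nf'`/`'hcm'` labels come from the illustrative error message in the diff and are placeholders, not required values):

```python
# Old (deprecated) format: one key (the state column) mapping to a tuple of
# three lists: ([start_states], [goal_states], [alt_states])
cell_states_to_model = {"disease": (["dcm"], ["nf"], ["hcm"])}

# New format: four explicit keys; "alt_states" must be a list (or None)
cell_states_to_model = {
    "state_key": "disease",   # column holding the cell state annotation
    "start_state": "dcm",     # state to shift cells away from
    "goal_state": "nf",       # state to shift cells toward
    "alt_states": ["hcm"],    # other states to compare against
}
```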
--- a/geneformer/in_silico_perturber.py
+++ b/geneformer/in_silico_perturber.py
@@ -105,6 +105,12 @@ def downsample_and_sort(data_shuffled, max_ncells):
     data_sorted = data_subset.sort("length",reverse=True)
     return data_sorted
 
+def get_possible_states(cell_states_to_model):
+    if list(cell_states_to_model.values())[3] is not None:
+        return list(cell_states_to_model.values())[1:3] + list(cell_states_to_model.values())[3]
+    else:
+        return list(cell_states_to_model.values())[1:3]
+
 def forward_pass_single_cell(model, example_cell, layer_to_quant):
     example_cell.set_format(type="torch")
     input_data = example_cell["input_ids"]
@@ -235,13 +241,15 @@ def get_cell_state_avg_embs(model,
                             num_proc):
 
     model_input_size = get_model_input_size(model)
-    possible_states =
+    possible_states = get_possible_states(cell_states_to_model)
     state_embs_dict = dict()
     for possible_state in possible_states:
         state_embs_list = []
+        original_lens = []
 
         def filter_states(example):
-
+            state_key = cell_states_to_model["state_key"]
+            return example[state_key] in possible_state
         filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
         total_batch_length = len(filtered_input_data_state)
         if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -254,6 +262,7 @@ def get_cell_state_avg_embs(model,
             state_minibatch.set_format(type="torch")
 
             input_data_minibatch = state_minibatch["input_ids"]
+            original_lens += [tensor.numel() for tensor in input_data_minibatch]
            input_data_minibatch = pad_tensor_list(input_data_minibatch,
                                                    max_len,
                                                    pad_token_id,
@@ -271,8 +280,12 @@ def get_cell_state_avg_embs(model,
             del input_data_minibatch
             del state_embs_i
             torch.cuda.empty_cache()
-
-
+
+        # import here to avoid circular imports
+        from .emb_extractor import mean_nonpadding_embs
+        state_embs = torch.cat(state_embs_list)
+        avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
+        avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
         state_embs_dict[possible_state] = avg_state_emb
     return state_embs_dict
 
@@ -291,7 +304,6 @@ def quant_cos_sims(model,
                    pad_token_id,
                    model_input_size,
                    nproc):
-
     cos = torch.nn.CosineSimilarity(dim=2)
     total_batch_length = len(perturbation_batch)
     if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -301,7 +313,7 @@
         comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
         cos_sims = []
     else:
-        possible_states =
+        possible_states = get_possible_states(cell_states_to_model)
         cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
 
     # measure length of each element in perturbation_batch
@@ -316,6 +328,7 @@
 
         # determine if need to pad or truncate batch
         minibatch_length_set = set(perturbation_minibatch["length"])
+        minibatch_lengths = perturbation_minibatch["length"]
        if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
             needs_pad_or_trunc = True
         else:
@@ -360,6 +373,7 @@
             # truncate to the (model input size - # tokens to overexpress) to ensure comparability
             # since max input size of perturb batch will be reduced by # tokens to overexpress
             original_minibatch = original_emb.select([i for i in range(i, max_range)])
+            original_minibatch_lengths = original_minibatch["length"]
             original_minibatch_length_set = set(original_minibatch["length"])
             if perturb_type == "overexpress":
                 new_max_len = model_input_size - len(tokens_to_perturb)
@@ -385,7 +399,32 @@
                 original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
             else:
                 original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
-
+
+            # remove perturbed index before calculating the cos sims
+            def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
+                # indices_to_remove is list of indices to remove
+                gene_dim -= 1 # removing a dim in calling the function
+                indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
+                num_dims = emb.dim()
+                emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
+                sliced_emb = emb[emb_slice]
+                return sliced_emb
+
+            # this could probably be optimized
+            gene_dim = 1
+
+            # currently, if a gene is not expressed and is being overexpressed,
+            # the dimensions will be thrown off --> not removing indices to get around that issue
+            # not sure what's the best way to handle it
+            if perturb_type != "overexpress":
+                original_minibatch_emb = torch.stack([
+                    remove_indices_from_emb(original_minibatch_emb[i, :, :], idx, gene_dim) for
+                    i, idx in enumerate(indices_to_perturb)
+                ])
+
+            # do the averaging here
+
+
         # cosine similarity between original emb and batch items
         if cell_states_to_model is None:
             if perturb_group == False:
@@ -406,7 +445,9 @@
                 cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
                                                              minibatch_emb,
                                                              state_embs_dict[state],
-                                                             perturb_group
+                                                             perturb_group,
+                                                             torch.tensor(original_minibatch_lengths, device="cuda"),
+                                                             torch.tensor(minibatch_lengths, device="cuda"))
         del outputs
         del minibatch_emb
         if cell_states_to_model is None:
@@ -421,14 +462,40 @@
     return cos_sims_vs_alt_dict
 
 # calculate cos sim shift of perturbation with respect to origin and alternative cell
-def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group):
+def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group, original_minibatch_lengths=None, minibatch_lengths=None):
     cos = torch.nn.CosineSimilarity(dim=2)
-
-
+    if not perturb_group:
+        original_emb = torch.mean(original_emb,dim=0,keepdim=True)
         original_emb = original_emb[None, :]
-
-
-
+        origin_v_end = torch.squeeze(cos(original_emb, alt_emb))
+    else:
+        if original_emb.size() != minibatch_emb.size():
+            logger.error(
+                f"Embeddings are not the same dimensions. " \
+                f"original_emb is {original_emb.size()}. " \
+                f"minibatch_emb is {minibatch_emb.size()}. "
+            )
+            raise
+        from .emb_extractor import mean_nonpadding_embs
+
+        if original_minibatch_lengths is not None:
+            original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
+        # not sure if the else is necessary, but keeping it here in case
+        else:
+            original_emb = torch.mean(original_emb,dim=1,keepdim=True)
+
+        alt_emb = torch.unsqueeze(alt_emb, 1)
+        origin_v_end = cos(original_emb, alt_emb)
+        origin_v_end = torch.squeeze(origin_v_end)
+
+        if minibatch_lengths is not None:
+            perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
+        else:
+            perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
+
+        perturb_v_end = cos(perturb_emb, alt_emb)
+        perturb_v_end = torch.squeeze(perturb_v_end)
+
     return [(perturb_v_end-origin_v_end).to("cpu")]
 
 def pad_list(input_ids, pad_token_id, max_len):
@@ -706,6 +773,12 @@ class InSilicoPerturber:
 
         if self.cell_states_to_model is not None:
             if len(self.cell_states_to_model.items()) == 1:
+                logger.warning(
+                    "The single value dictionary for cell_states_to_model will be " \
+                    "replaced with explicitly modeling start and end states. " \
+                    "Please specify state_key, start_state, goal_state, and alt_states " \
+                    "in the cell_states_to_model dictionary for future use."
+                )
                 for key,value in self.cell_states_to_model.items():
                     if (len(value) == 3) and isinstance(value, tuple):
                         if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
@@ -713,14 +786,48 @@
                             all_values = value[0]+value[1]+value[2]
                             if len(all_values) == len(set(all_values)):
                                 continue
+                # reformat to the new format
+                state_values = flatten_list(list(self.cell_states_to_model.values()))
+                self.cell_states_to_model = {
+                    "state_key": list(self.cell_states_to_model.keys())[0],
+                    "start_state": state_values[0][0],
+                    "goal_state": state_values[1][0],
+                    "alt_states": state_values[2:][0]
+                }
+            elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
+                if self.cell_states_to_model["start_state"] is None or self.cell_states_to_model["goal_state"] is None:
+                    logger.error(
+                        "Please specify 'start_state' and 'goal_state' in cell_states_to_model.")
+                    raise
+
+                if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
+                    logger.error(
+                        "All states must be unique.")
+                    raise
+
+                if self.cell_states_to_model["alt_states"] is not None:
+                    if type(self.cell_states_to_model["alt_states"]) is not list:
+                        logger.error(
+                            "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
+                        )
+                        raise
+                    if len(self.cell_states_to_model["alt_states"]) != len(set(self.cell_states_to_model["alt_states"])):
+                        logger.error(
+                            "All states must be unique.")
+                        raise
+
             else:
                 logger.error(
-                    "
-                    "
-
-
-
+                    "cell_states_to_model must only have the following four keys: 'state_key', 'start_state', 'goal_state', 'alt_states'. " \
+                    "For example, cell_states_to_model={ \
+                        'state_key': 'disease', \
+                        'start_state': 'dcm', \
+                        'goal_state': 'nf', \
+                        'alt_states': ['hcm', 'other1', 'other2'] \
+                    }"
+                )
                 raise
+
             if self.anchor_gene is not None:
                 self.anchor_gene = None
                 logger.warning(
@@ -770,6 +877,13 @@
         if self.cell_states_to_model is None:
             state_embs_dict = None
         else:
+            # make sure that all states are valid; save time on filtering
+            state_name = self.cell_states_to_model["state_key"]
+            for value in get_possible_states(self.cell_states_to_model):
+                if value not in filtered_input_data[state_name]:
+                    logger.error(
+                        f"{value} is not a valid value in {state_name}.")
+                    raise
             # get dictionary of average cell state embeddings for comparison
             downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
             state_embs_dict = get_cell_state_avg_embs(model,
@@ -780,9 +894,9 @@
                                                       self.forward_batch_size,
                                                       self.nproc)
             # filter for start state cells
-            start_state =
+            start_state = self.cell_states_to_model["start_state"]
             def filter_for_origin(example):
-                return example[
+                return example[state_name] in [start_state]
 
             filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
 
@@ -878,7 +992,6 @@
                 # or (perturbed_genes, "cell_emb") for avg cell emb change
                 cos_sims_data = cos_sims_data.to("cuda")
                 max_padded_len = cos_sims_data.shape[1]
-
                 for j in range(cos_sims_data.shape[0]):
                     # remove padding before mean pooling cell embedding
                     original_length = original_lengths[j]
@@ -900,21 +1013,13 @@
                 # update cos sims dict
                 # key is tuple of (perturbed_genes, "cell_emb")
                 # value is list of tuples of cos sims for cell_states_to_model
-                origin_state_key =
+                origin_state_key = self.cell_states_to_model["start_state"]
                 cos_sims_origin = cos_sims_data[origin_state_key]
                 for j in range(cos_sims_origin.shape[0]):
-                    original_length = original_lengths[j]
-                    max_padded_len = cos_sims_origin.shape[1]
-                    indices_removed = indices_to_perturb[j]
-                    padding_to_remove = max_padded_len - (original_length \
-                                                          - len(self.tokens_to_perturb) \
-                                                          - len(indices_removed))
                     data_list = []
                     for data in list(cos_sims_data.values()):
                         data_item = data.to("cuda")
-
-                        cell_data = torch.mean(nonpadding_data_item).item()
-                        data_list += [cell_data]
+                        data_list += [data_item]
                     cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
 
                 with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
@@ -987,7 +1092,7 @@
                 # update cos sims dict
                 # key is tuple of (perturbed_gene, "cell_emb")
                 # value is list of tuples of cos sims for cell_states_to_model
-                origin_state_key =
+                origin_state_key = self.cell_states_to_model["start_state"]
                 cos_sims_origin = cos_sims_data[origin_state_key]
 
                 for j in range(cos_sims_origin.shape[0]):
@@ -1109,4 +1214,3 @@
         # save remainder cells
         with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
            pickle.dump(cos_sims_dict, fp)
-
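Two notes on helpers the diff relies on. First, `get_possible_states` depends on the insertion order of the new dictionary (`state_key` first, then `start_state`, `goal_state`, `alt_states`); under that assumption, for the example dictionary shown above:

```python
# list(cell_states_to_model.values()) == ["disease", "dcm", "nf", ["hcm"]]
# so values[1:3] -> ["dcm", "nf"] and values[3] -> ["hcm"], giving:
get_possible_states(cell_states_to_model)  # -> ["dcm", "nf", "hcm"]
```

Second, the averaging imports `mean_nonpadding_embs` from `emb_extractor`, whose body is not part of this diff. Judging only from the call sites (a batch of embeddings plus a tensor of pre-padding lengths), a minimal sketch of the non-padding mean pooling it presumably performs; the name `mean_nonpadding_embs_sketch` and the shape assumptions are ours, not the repo's API:

```python
import torch

def mean_nonpadding_embs_sketch(embs: torch.Tensor, original_lens: torch.Tensor) -> torch.Tensor:
    # embs: (batch, seq_len, hidden); original_lens: (batch,) true lengths before padding.
    # Mask is True at real token positions, False at padding positions.
    seq_len = embs.size(1)
    mask = torch.arange(seq_len, device=embs.device)[None, :] < original_lens[:, None]
    # Zero out padded positions, then divide each cell's summed embedding
    # by its true (pre-padding) length.
    summed = (embs * mask.unsqueeze(-1)).sum(dim=1)
    # NOTE: the real helper may keep a singleton sequence dim (e.g. .unsqueeze(1));
    # this sketch returns (batch, hidden).
    return summed / original_lens.unsqueeze(-1)
```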