Upload in_silico_perturber.py
Browse files
Fix bug related to padding when overexpressing genes
geneformer/in_silico_perturber.py
CHANGED
@@ -40,7 +40,7 @@ import pickle
|
|
40 |
from collections import defaultdict
|
41 |
|
42 |
import torch
|
43 |
-
from datasets import Dataset
|
44 |
from multiprocess import set_start_method
|
45 |
from tqdm.auto import trange
|
46 |
|
@@ -48,7 +48,9 @@ from . import TOKEN_DICTIONARY_FILE
|
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
50 |
|
51 |
-
|
|
|
|
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
@@ -794,6 +796,8 @@ class InSilicoPerturber:
|
|
794 |
return example
|
795 |
|
796 |
total_batch_length = len(filtered_input_data)
|
|
|
|
|
797 |
if self.cell_states_to_model is None:
|
798 |
cos_sims_dict = defaultdict(list)
|
799 |
else:
|
@@ -878,7 +882,7 @@ class InSilicoPerturber:
|
|
878 |
)
|
879 |
|
880 |
##### CLS and Gene Embedding Mode #####
|
881 |
-
elif self.emb_mode == "cls_and_gene":
|
882 |
full_original_emb = get_embs(
|
883 |
model,
|
884 |
minibatch,
|
@@ -891,6 +895,7 @@ class InSilicoPerturber:
|
|
891 |
silent=True,
|
892 |
)
|
893 |
indices_to_perturb = perturbation_batch["perturb_index"]
|
|
|
894 |
# remove indices that were perturbed
|
895 |
original_emb = pu.remove_perturbed_indices_set(
|
896 |
full_original_emb,
|
@@ -899,6 +904,7 @@ class InSilicoPerturber:
|
|
899 |
self.tokens_to_perturb,
|
900 |
minibatch["length"],
|
901 |
)
|
|
|
902 |
full_perturbation_emb = get_embs(
|
903 |
model,
|
904 |
perturbation_batch,
|
@@ -910,7 +916,7 @@ class InSilicoPerturber:
|
|
910 |
summary_stat=None,
|
911 |
silent=True,
|
912 |
)
|
913 |
-
|
914 |
# remove special tokens and padding
|
915 |
original_emb = original_emb[:, 1:-1, :]
|
916 |
if self.perturb_type == "overexpress":
|
@@ -921,9 +927,25 @@ class InSilicoPerturber:
|
|
921 |
perturbation_emb = full_perturbation_emb[
|
922 |
:, 1 : max(perturbation_batch["length"]) - 1, :
|
923 |
]
|
924 |
-
|
925 |
n_perturbation_genes = perturbation_emb.size()[1]
|
926 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
927 |
gene_cos_sims = pu.quant_cos_sims(
|
928 |
perturbation_emb,
|
929 |
original_emb,
|
|
|
40 |
from collections import defaultdict
|
41 |
|
42 |
import torch
|
43 |
+
from datasets import Dataset
|
44 |
from multiprocess import set_start_method
|
45 |
from tqdm.auto import trange
|
46 |
|
|
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
50 |
|
51 |
+
import datasets
|
52 |
+
datasets.logging.disable_progress_bar()
|
53 |
+
|
54 |
|
55 |
logger = logging.getLogger(__name__)
|
56 |
|
|
|
796 |
return example
|
797 |
|
798 |
total_batch_length = len(filtered_input_data)
|
799 |
+
|
800 |
+
|
801 |
if self.cell_states_to_model is None:
|
802 |
cos_sims_dict = defaultdict(list)
|
803 |
else:
|
|
|
882 |
)
|
883 |
|
884 |
##### CLS and Gene Embedding Mode #####
|
885 |
+
elif self.emb_mode == "cls_and_gene":
|
886 |
full_original_emb = get_embs(
|
887 |
model,
|
888 |
minibatch,
|
|
|
895 |
silent=True,
|
896 |
)
|
897 |
indices_to_perturb = perturbation_batch["perturb_index"]
|
898 |
+
|
899 |
# remove indices that were perturbed
|
900 |
original_emb = pu.remove_perturbed_indices_set(
|
901 |
full_original_emb,
|
|
|
904 |
self.tokens_to_perturb,
|
905 |
minibatch["length"],
|
906 |
)
|
907 |
+
|
908 |
full_perturbation_emb = get_embs(
|
909 |
model,
|
910 |
perturbation_batch,
|
|
|
916 |
summary_stat=None,
|
917 |
silent=True,
|
918 |
)
|
919 |
+
|
920 |
# remove special tokens and padding
|
921 |
original_emb = original_emb[:, 1:-1, :]
|
922 |
if self.perturb_type == "overexpress":
|
|
|
927 |
perturbation_emb = full_perturbation_emb[
|
928 |
:, 1 : max(perturbation_batch["length"]) - 1, :
|
929 |
]
|
930 |
+
|
931 |
n_perturbation_genes = perturbation_emb.size()[1]
|
932 |
|
933 |
+
# truncate the original embedding as necessary
|
934 |
+
if self.perturb_type == "overexpress":
|
935 |
+
def calc_perturbation_length(ids):
|
936 |
+
if ids == [-100]:
|
937 |
+
return 0
|
938 |
+
else:
|
939 |
+
return len(ids)
|
940 |
+
|
941 |
+
max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
|
942 |
+
|
943 |
+
max_n_overflow = max(minibatch["n_overflow"])
|
944 |
+
if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
|
945 |
+
original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
|
946 |
+
elif perturbation_emb.size()[1] < original_emb.size()[1]:
|
947 |
+
original_emb = original_emb[:, 0:max_tensor_size, :]
|
948 |
+
|
949 |
gene_cos_sims = pu.quant_cos_sims(
|
950 |
perturbation_emb,
|
951 |
original_emb,
|