gf_v1: filled in isp single gene code

#237
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +2 -2
  2. .pre-commit-config.yaml +0 -26
  3. .readthedocs.yaml +0 -19
  4. MANIFEST.in +3 -4
  5. README.md +11 -38
  6. config.json +8 -9
  7. docs/Makefile +0 -20
  8. docs/make.bat +0 -35
  9. docs/requirements.txt +0 -3
  10. docs/source/_static/css/custom.css +0 -40
  11. docs/source/_static/gf_logo.png +0 -0
  12. docs/source/about.rst +0 -49
  13. docs/source/api.rst +0 -51
  14. docs/source/conf.py +0 -80
  15. docs/source/geneformer.classifier.rst +0 -10
  16. docs/source/geneformer.emb_extractor.rst +0 -26
  17. docs/source/geneformer.in_silico_perturber.rst +0 -8
  18. docs/source/geneformer.in_silico_perturber_stats.rst +0 -25
  19. docs/source/geneformer.mtl_classifier.rst +0 -11
  20. docs/source/geneformer.tokenizer.rst +0 -15
  21. docs/source/getstarted.rst +0 -36
  22. docs/source/index.rst +0 -16
  23. examples/cell_classification.ipynb +0 -0
  24. examples/extract_and_plot_cell_embeddings.ipynb +4 -8
  25. examples/gene_classification.ipynb +0 -0
  26. examples/hyperparam_optimiz_for_disease_classifier.py +226 -0
  27. examples/in_silico_perturbation.ipynb +17 -66
  28. examples/multitask_cell_classification.ipynb +0 -420
  29. examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +1 -3
  30. examples/tokenizing_scRNAseq_data.ipynb +8 -27
  31. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
  32. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
  33. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
  34. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
  35. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
  36. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
  37. fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
  38. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +0 -24
  39. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +0 -3
  40. {gf-12L-30M-i2048 → geneformer-12L-30M}/config.json +0 -0
  41. {gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin +0 -0
  42. {gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin +0 -0
  43. geneformer/__init__.py +11 -33
  44. geneformer/classifier.py +0 -1563
  45. geneformer/classifier_utils.py +0 -648
  46. geneformer/collator_for_classification.py +74 -139
  47. geneformer/emb_extractor.py +279 -649
  48. geneformer/ensembl_mapping_dict_gc95M.pkl +0 -3
  49. geneformer/evaluation_utils.py +0 -287
  50. geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl +0 -3
.gitattributes CHANGED
@@ -14,11 +14,10 @@
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
@@ -26,4 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+geneformer/gene_name_id_dict.pkl filter=lfs diff=lfs merge=lfs -text
 model.safetensors filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml DELETED
@@ -1,26 +0,0 @@
-# See https://pre-commit.com for more information
-# See https://pre-commit.com/hooks.html for more hooks
-repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
-    hooks:
-      - id: trailing-whitespace
-      - id: end-of-file-fixer
-      - id: check-yaml
-      - id: check-added-large-files
-      - id: check-merge-conflict
-      - id: mixed-line-ending
-      - id: check-docstring-first
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        args: ["--profile", "black"]
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    # Ruff version.
-    rev: v0.1.4
-    hooks:
-      # Run the Ruff linter.
-      - id: ruff
-      # Run the Ruff formatter.
-      - id: ruff-format
.readthedocs.yaml DELETED
@@ -1,19 +0,0 @@
-# Read the Docs configuration file
-
-# Required
-version: 2
-
-# Set the OS, Python version and other tools you might need
-build:
-  os: ubuntu-22.04
-  tools:
-    python: "3.10"
-
-# Build documentation in the "docs/" directory with Sphinx
-sphinx:
-  configuration: docs/source/conf.py
-
-# Python requirements required build your documentation
-python:
-  install:
-    - requirements: docs/requirements.txt
MANIFEST.in CHANGED
@@ -1,4 +1,3 @@
-include geneformer/gene_median_dictionary_gc95M.pkl
-include geneformer/gene_name_id_dict_gc95M.pkl
-include geneformer/ensembl_mapping_dict_gc95M.pkl
-include geneformer/token_dictionary_gc95M.pkl
+include geneformer/gene_median_dictionary.pkl
+include geneformer/token_dictionary.pkl
+include geneformer/gene_name_id_dict.pkl
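
Note: the MANIFEST.in entries above determine which non-Python data files ship with the pip-installed package. As a minimal sketch (an assumed illustration, not the package's actual loading logic), a bundled dictionary such as token_dictionary.pkl could be read at runtime like this:

```python
# Sketch only: load a packaged .pkl resource; the real loading code
# lives inside the geneformer package itself.
import pickle
from importlib.resources import files  # Python 3.9+

token_dictionary = pickle.loads(
    files("geneformer").joinpath("token_dictionary.pkl").read_bytes()
)
```
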
 
README.md CHANGED
@@ -1,43 +1,22 @@
 ---
 datasets: ctheodoris/Genecorpus-30M
 license: apache-2.0
-tags:
-- single-cell
-- genomics
 ---
 # Geneformer
-Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
+Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
 
-- See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
-- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
-- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
+See [our manuscript](https://rdcu.be/ddrx0) for details.
 
 # Model Description
-Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
+Geneformer is a foundation transformer model pretrained on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
 
-Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
-
-The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
+The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
 
 We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
 
-During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
-
-The repository includes the following pretrained models:
-
-L=layers\
-M=millions of cells used for pretraining\
-i=input size\
-(pretraining date)
-
-- GF-6L-30M-i2048 (June 2021)
-- GF-12L-30M-i2048 (June 2021)
-- GF-12L-95M-i4096 (April 2024)
-- GF-20L-95M-i4096 (April 2024)
-
-The current default model in the main directory of the repository is GF-12L-95M-i4096.
-
-The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
+During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents a pretrained deep learning model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
+
+In [our manuscript](https://rdcu.be/ddrx0), we report results for the 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within this repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
 
 # Application
 The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
@@ -45,7 +24,7 @@ The pretrained Geneformer model can be used directly for zero-shot learning, for
 Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:
 
 *Fine-tuning*:
-- transcription factor dosage sensitivity
+- transcription factor dosage sensitivity
 - chromatin dynamics (bivalently marked promoters)
 - transcription factor regulatory range
 - gene network centrality
@@ -67,11 +46,9 @@ Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) inc
 - in silico perturbation to determine transcription factor cooperativity
 
 # Installation
-In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install (~20s):
+In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install:
 
 ```bash
-# Make sure you have git-lfs installed (https://git-lfs.com)
-git lfs install
 git clone https://huggingface.co/ctheodoris/Geneformer
 cd Geneformer
 pip install .
@@ -85,10 +62,6 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
 - extracting and plotting cell embeddings
 - in silico perturbation
 
-Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
-
-Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
+Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
 
-# Citations
-- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
-- H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
+Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
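
To make the rank value encoding described in the README diff above concrete, here is a minimal illustrative sketch. The function name, inputs, and toy values are hypothetical and are not the repository's tokenizer API; see geneformer/tokenizer.py for the actual implementation.

```python
# Hypothetical sketch of rank value encoding: rank genes by expression in
# a cell after normalizing by each gene's expression across the corpus.
def rank_value_encode(cell_counts: dict, gene_norms: dict, max_len: int = 2048) -> list:
    """Return gene IDs ordered by corpus-normalized expression, highest first."""
    normed = {
        gene: count / gene_norms[gene]
        for gene, count in cell_counts.items()
        if gene in gene_norms and count > 0
    }
    ranked = sorted(normed, key=normed.get, reverse=True)
    return ranked[:max_len]  # truncate to the model input size

# A ubiquitous housekeeping gene with a high corpus-wide value drops in rank,
# while a lowly but specifically expressed transcription factor rises.
cell = {"ACTB": 500.0, "GATA4": 3.0}
norms = {"ACTB": 480.0, "GATA4": 0.5}
print(rank_value_encode(cell, norms))  # ['GATA4', 'ACTB']
```
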
 
 
config.json CHANGED
@@ -3,22 +3,21 @@
     "BertForMaskedLM"
   ],
   "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
+  "gradient_checkpointing": false,
   "hidden_act": "relu",
   "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
+  "hidden_size": 256,
   "initializer_range": 0.02,
-  "intermediate_size": 1024,
+  "intermediate_size": 512,
   "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
+  "max_position_embeddings": 2048,
   "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
+  "num_attention_heads": 4,
+  "num_hidden_layers": 6,
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
+  "transformers_version": "4.6.0",
   "type_vocab_size": 2,
   "use_cache": true,
-  "vocab_size": 20275
+  "vocab_size": 25426
 }
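
The updated config above describes the 6-layer, 256-hidden-dimension model with a 2048-token input size. As a quick sanity check (a sketch using the Hugging Face transformers API, not code from this repository), the architecture can be instantiated directly from these values:

```python
# Sketch: build the BERT architecture described by the new config.json
# and count its parameters (roughly 10M for these dimensions).
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(
    vocab_size=25426,
    hidden_size=256,
    num_hidden_layers=6,
    num_attention_heads=4,
    intermediate_size=512,
    max_position_embeddings=2048,
    hidden_act="relu",
    hidden_dropout_prob=0.02,
    attention_probs_dropout_prob=0.02,
    pad_token_id=0,
)
model = BertForMaskedLM(config)
print(sum(p.numel() for p in model.parameters()))
```
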
docs/Makefile DELETED
@@ -1,20 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS    ?=
-SPHINXBUILD   ?= sphinx-build
-SOURCEDIR     = source
-BUILDDIR      = build
-
-# Put it first so that "make" without argument is like "make help".
-help:
-	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/make.bat DELETED
@@ -1,35 +0,0 @@
-@ECHO OFF
-
-pushd %~dp0
-
-REM Command file for Sphinx documentation
-
-if "%SPHINXBUILD%" == "" (
-	set SPHINXBUILD=sphinx-build
-)
-set SOURCEDIR=source
-set BUILDDIR=build
-
-%SPHINXBUILD% >NUL 2>NUL
-if errorlevel 9009 (
-	echo.
-	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
-	echo.installed, then set the SPHINXBUILD environment variable to point
-	echo.to the full path of the 'sphinx-build' executable. Alternatively you
-	echo.may add the Sphinx directory to PATH.
-	echo.
-	echo.If you don't have Sphinx installed, grab it from
-	echo.https://www.sphinx-doc.org/
-	exit /b 1
-)
-
-if "%1" == "" goto help
-
-%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-goto end
-
-:help
-%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-
-:end
-popd
docs/requirements.txt DELETED
@@ -1,3 +0,0 @@
-.
-sphinx_rtd_theme==2.0.0
-nbsphinx==0.9.3
docs/source/_static/css/custom.css DELETED
@@ -1,40 +0,0 @@
-/* top left logo */
-.wy-side-nav-search, .wy-nav-top {
-  background: linear-gradient(15deg, #13547a 0%, #80d0c7 100%);
-}
-
-
-/* unvisited link */
-.wy-nav-content a:link {
-  color: #067abd;
-}
-
-/* visited link */
-.wy-nav-content a:visited {
-  color: #4b827c;
-}
-
-/* mouse over link */
-.wy-nav-content a:hover {
-  color: #80d0c7;
-}
-
-/* selected link */
-.wy-nav-content a:active {
-  color: #4b827c;
-}
-
-/* class object */
-.sig.sig-object {
-  padding: 5px 5px 5px 5px;
-  background-color: #ececec;
-  border-style: solid;
-  border-color: black;
-  border-width: 1px 0;
-}
-
-/* parameter object */
-dt {
-  padding: 5px 5px 5px 5px;
-  background-color: #ececec;
-}
docs/source/_static/gf_logo.png DELETED
Binary file (48.2 kB)
 
docs/source/about.rst DELETED
@@ -1,49 +0,0 @@
-About
-=====
-
-Model Description
------------------
-
-**Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
-
-In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the original 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
-
-Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-12L-30M-i2048/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
-
-Also see `our 2024 manuscript <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_, for details of the `expanded model <https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors>`_ trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
-
-Application
------------
-
-The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
-
-Example applications demonstrated in `our manuscript <https://rdcu.be/ddrx0>`_ include:
-
-| *Fine-tuning*:
-| - transcription factor dosage sensitivity
-| - chromatin dynamics (bivalently marked promoters)
-| - transcription factor regulatory range
-| - gene network centrality
-| - transcription factor targets
-| - cell type annotation
-| - batch integration
-| - cell state classification across differentiation
-| - disease classification
-| - in silico perturbation to determine disease-driving genes
-| - in silico treatment to determine candidate therapeutic targets
-
-| *Zero-shot learning*:
-| - batch integration
-| - gene context specificity
-| - in silico reprogramming
-| - in silico differentiation
-| - in silico perturbation to determine impact on cell state
-| - in silico perturbation to determine transcription factor targets
-| - in silico perturbation to determine transcription factor cooperativity
-
-Citations
----------
-
-| C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
-
-| H Chen \*, M S Venkatesh \*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka †, C V Theodoris † #. `Quantized multi-task learning for context-specific representations of gene network dynamics. <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_ *bioRxiv*, 19 Aug 2024. (\* co-first authors, † co-senior authors, # corresponding author)
docs/source/api.rst DELETED
@@ -1,51 +0,0 @@
-API
-===
-
-Tokenizer
----------
-
-.. toctree::
-   :maxdepth: 1
-
-   geneformer.tokenizer
-
-Classifier
-----------
-
-.. toctree::
-   :maxdepth: 1
-
-   geneformer.classifier
-
-Multitask Classifier
---------------------
-
-.. toctree::
-   :maxdepth: 1
-
-   geneformer.mtl_classifier
-
-Embedding Extractor
--------------------
-
-.. toctree::
-   :maxdepth: 1
-
-   geneformer.emb_extractor
-
-In Silico Perturber
--------------------
-
-.. toctree::
-   :maxdepth: 1
-
-   geneformer.in_silico_perturber
-
-
-In Silico Perturber Stats
--------------------------
-
-.. toctree::
-   :maxdepth: 1
-
-   geneformer.in_silico_perturber_stats
docs/source/conf.py DELETED
@@ -1,80 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# For the full list of built-in configuration values, see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-import pathlib
-import re
-import sys
-
-from sphinx.ext import autodoc
-
-sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix())
-
-
-# -- Project information -----------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
-
-project = "geneformer"
-copyright = "2024, Christina Theodoris"
-author = "Christina Theodoris"
-release = "0.1.0"
-repository_url = "https://huggingface.co/ctheodoris/Geneformer"
-
-# -- General configuration ---------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
-
-extensions = [
-    "sphinx.ext.autodoc",
-    "sphinx.ext.autosummary",
-    "nbsphinx",
-    "sphinx.ext.viewcode",
-    "sphinx.ext.doctest",
-]
-
-templates_path = ["_templates"]
-exclude_patterns = [
-    "**.ipynb_checkpoints",
-]
-autoclass_content = "both"
-
-
-class MockedClassDocumenter(autodoc.ClassDocumenter):
-    def add_line(self, line: str, source: str, *lineno: int) -> None:
-        if line == "   Bases: :py:class:`object`":
-            return
-        super().add_line(line, source, *lineno)
-
-
-autodoc.ClassDocumenter = MockedClassDocumenter
-add_module_names = False
-
-
-def process_signature(app, what, name, obj, options, signature, return_annotation):
-    # loop through each line in the docstring and replace path with
-    # the generic path text
-    signature = re.sub(r"PosixPath\(.*?\)", "FILEPATH", signature)
-    return (signature, None)
-
-
-def setup(app):
-    app.connect("autodoc-process-signature", process_signature)
-
-
-# -- Options for HTML output -------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-
-html_theme = "sphinx_rtd_theme"
-html_show_sphinx = False
-html_static_path = ["_static"]
-html_logo = "_static/gf_logo.png"
-html_theme_options = {
-    "collapse_navigation": False,
-    "sticky_navigation": True,
-    "navigation_depth": 3,
-    "logo_only": True,
-}
-html_css_files = [
-    "css/custom.css",
-]
-html_show_sourcelink = False
docs/source/geneformer.classifier.rst DELETED
@@ -1,10 +0,0 @@
-geneformer.classifier
-=====================
-
-.. automodule:: geneformer.classifier
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :exclude-members:
-        valid_option_dict,
-        validate_options
docs/source/geneformer.emb_extractor.rst DELETED
@@ -1,26 +0,0 @@
-geneformer.emb\_extractor
-=========================
-
-.. automodule:: geneformer.emb_extractor
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :exclude-members:
-        accumulate_tdigests,
-        gen_heatmap_class_colors,
-        gen_heatmap_class_dict,
-        get_embs,
-        label_cell_embs,
-        label_gene_embs,
-        make_colorbar,
-        plot_heatmap,
-        plot_umap,
-        summarize_gene_embs,
-        tdigest_mean,
-        tdigest_median,
-        test_emb,
-        update_tdigest_dict,
-        update_tdigest_dict_mean,
-        update_tdigest_dict_median,
-        valid_option_dict,
-        validate_options
docs/source/geneformer.in_silico_perturber.rst DELETED
@@ -1,8 +0,0 @@
-geneformer.in\_silico\_perturber
-=======================================
-
-.. automodule:: geneformer.in_silico_perturber
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
docs/source/geneformer.in_silico_perturber_stats.rst DELETED
@@ -1,25 +0,0 @@
-geneformer.in\_silico\_perturber\_stats
-==============================================
-
-.. automodule:: geneformer.in_silico_perturber_stats
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :exclude-members:
-        find,
-        get_fdr,
-        get_gene_list,
-        get_impact_component,
-        invert_dict,
-        isp_aggregate_gene_shifts,
-        isp_aggregate_grouped_perturb,
-        isp_stats_mixture_model,
-        isp_stats_to_goal_state,
-        isp_stats_vs_null,
-        n_detections,
-        read_dict,
-        read_dictionaries,
-        token_to_gene_name,
-        token_tuple_to_ensembl_ids,
-        valid_option_dict,
-        validate_options
docs/source/geneformer.mtl_classifier.rst DELETED
@@ -1,11 +0,0 @@
-geneformer.mtl\_classifier
-==========================
-
-.. automodule:: geneformer.mtl_classifier
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :exclude-members:
-        valid_option_dict,
-        validate_options,
-        validate_additional_options
docs/source/geneformer.tokenizer.rst DELETED
@@ -1,15 +0,0 @@
-geneformer.tokenizer
-====================
-
-.. automodule:: geneformer.tokenizer
-    :members:
-    :undoc-members:
-    :show-inheritance:
-    :exclude-members:
-        create_dataset,
-        tokenize_anndata,
-        tokenize_files,
-        tokenize_loom,
-        rank_genes,
-        tokenize_cell,
-        sum_ensembl_ids
docs/source/getstarted.rst DELETED
@@ -1,36 +0,0 @@
-Getting Started
-===============
-
-Installation
-------------
-
-Geneformer installation instructions.
-
-Make sure you have git-lfs installed (https://git-lfs.com).
-
-.. code-block:: bash
-
-    git lfs install
-    git clone https://huggingface.co/ctheodoris/Geneformer
-    cd Geneformer
-    pip install .
-
-
-Tutorials
----------
-
-| See `examples <https://huggingface.co/ctheodoris/Geneformer/tree/main/examples>`_ for:
-| - tokenizing transcriptomes
-| - pretraining
-| - hyperparameter tuning
-| - fine-tuning
-| - extracting and plotting cell embeddings
-| - in silico perturbation
-
-Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the `example_input_files directory <https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files>`_ in the dataset repository, but these only represent a few example fine-tuning applications.
-
-
-Tips
-----
-
-Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
docs/source/index.rst DELETED
@@ -1,16 +0,0 @@
-Geneformer
-==========
-
-Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in network biology.
-
-See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
-
-Table of Contents
------------------
-
-.. toctree::
-   :maxdepth: 2
-
-   about
-   getstarted
-   api
examples/cell_classification.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -18,8 +18,6 @@
    "outputs": [],
    "source": [
     "# initiate EmbExtractor\n",
-    "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
-    "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
     "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
     "                     num_classes=3,\n",
     "                     filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -28,13 +26,11 @@
     "                     emb_label=[\"disease\",\"cell_type\"],\n",
     "                     labels_to_plot=[\"disease\"],\n",
     "                     forward_batch_size=200,\n",
-    "                     nproc=16,\n",
-    "                     token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n",
+    "                     nproc=16)\n",
     "\n",
     "# extracts embedding from input data\n",
-    "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
-    "# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
-    "embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
+    "# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+    "embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
     "                          \"path/to/input_data/\",\n",
     "                          \"path/to/output_directory/\",\n",
     "                          \"output_prefix\")\n"
@@ -132,7 +128,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.15"
+   "version": "3.10.11"
   }
  },
 "nbformat": 4,
examples/gene_classification.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
examples/hyperparam_optimiz_for_disease_classifier.py ADDED
@@ -0,0 +1,226 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# hyperparameter optimization with raytune for disease classification
+
+# imports
+import os
+import subprocess
+GPU_NUMBER = [0,1,2,3]
+os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
+os.environ["NCCL_DEBUG"] = "INFO"
+os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
+os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
+
+# initiate runtime environment for raytune
+import pyarrow # must occur prior to ray import
+import ray
+from ray import tune
+from ray.tune import ExperimentAnalysis
+from ray.tune.suggest.hyperopt import HyperOptSearch
+ray.shutdown() #engage new ray session
+runtime_env = {"conda": "base",
+               "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
+ray.init(runtime_env=runtime_env)
+
+def initialize_ray_with_check(ip_address):
+    """
+    Initialize Ray with a specified IP address and check its status and accessibility.
+
+    Args:
+    - ip_address (str): The IP address (with port) to initialize Ray.
+
+    Returns:
+    - bool: True if initialization was successful and dashboard is accessible, False otherwise.
+    """
+    try:
+        ray.init(address=ip_address)
+        print(ray.nodes())
+
+        services = ray.get_webui_url()
+        if not services:
+            raise RuntimeError("Ray dashboard is not accessible.")
+        else:
+            print(f"Ray dashboard is accessible at: {services}")
+            return True
+    except Exception as e:
+        print(f"Error initializing Ray: {e}")
+        return False
+
+# Usage:
+ip = 'your_ip:xxxx' # Replace with your actual IP address and port
+if initialize_ray_with_check(ip):
+    print("Ray initialized successfully.")
+else:
+    print("Error during Ray initialization.")
+
+import datetime
+import numpy as np
+import pandas as pd
+import random
+import seaborn as sns; sns.set()
+from collections import Counter
+from datasets import load_from_disk
+from scipy.stats import ranksums
+from sklearn.metrics import accuracy_score
+from transformers import BertForSequenceClassification
+from transformers import Trainer
+from transformers.training_args import TrainingArguments
+
+from geneformer import DataCollatorForCellClassification
+
+# number of CPU cores
+num_proc=30
+
+# load train dataset with columns:
+# cell_type (annotation of each cell's type)
+# disease (healthy or disease state)
+# individual (unique ID for each patient)
+# length (length of that cell's rank value encoding)
+train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
+
+# filter dataset for given cell_type
+def if_cell_type(example):
+    return example["cell_type"].startswith("Cardiomyocyte")
+
+trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
+
+# create dictionary of disease states : label ids
+target_names = ["healthy", "disease1", "disease2"]
+target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
+
+trainset_v3 = trainset_v2.rename_column("disease","label")
+
+# change labels to numerical ids
+def classes_to_ids(example):
+    example["label"] = target_name_id_dict[example["label"]]
+    return example
+
+trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
+
+# separate into train, validation, test sets
+indiv_set = set(trainset_v4["individual"])
+random.seed(42)
+train_indiv = random.sample(indiv_set,round(0.7*len(indiv_set)))
+eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
+valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
+test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
+
+def if_train(example):
+    return example["individual"] in train_indiv
+
+classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
+
+def if_valid(example):
+    return example["individual"] in valid_indiv
+
+classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
+
+# define output directory path
+current_date = datetime.datetime.now()
+datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
+output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
+
+# ensure not overwriting previously saved model
+saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
+if os.path.isfile(saved_model_test) == True:
+    raise Exception("Model already saved to this directory.")
+
+# make output directory
+subprocess.call(f'mkdir {output_dir}', shell=True)
+
+# set training parameters
+# how many pretrained layers to freeze
+freeze_layers = 2
+# batch size for training and eval
+geneformer_batch_size = 12
+# number of epochs
+epochs = 1
+# logging steps
+logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
+
+# define function to initiate model
+def model_init():
+    model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
+                                                          num_labels=len(target_names),
+                                                          output_attentions = False,
+                                                          output_hidden_states = False)
+    if freeze_layers is not None:
+        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
+        for module in modules_to_freeze:
+            for param in module.parameters():
+                param.requires_grad = False
+
+    model = model.to("cuda:0")
+    return model
+
+# define metrics
+# note: macro f1 score recommended for imbalanced multiclass classifiers
+def compute_metrics(pred):
+    labels = pred.label_ids
+    preds = pred.predictions.argmax(-1)
+    # calculate accuracy using sklearn's function
+    acc = accuracy_score(labels, preds)
+    return {
+      'accuracy': acc,
+    }
+
+# set training arguments
+training_args = {
+    "do_train": True,
+    "do_eval": True,
+    "evaluation_strategy": "steps",
+    "eval_steps": logging_steps,
+    "logging_steps": logging_steps,
+    "group_by_length": True,
+    "length_column_name": "length",
+    "disable_tqdm": True,
+    "skip_memory_metrics": True, # memory tracker causes errors in raytune
+    "per_device_train_batch_size": geneformer_batch_size,
+    "per_device_eval_batch_size": geneformer_batch_size,
+    "num_train_epochs": epochs,
+    "load_best_model_at_end": True,
+    "output_dir": output_dir,
+}
+
+training_args_init = TrainingArguments(**training_args)
+
+# create the trainer
+trainer = Trainer(
+    model_init=model_init,
+    args=training_args_init,
+    data_collator=DataCollatorForCellClassification(),
+    train_dataset=classifier_trainset,
+    eval_dataset=classifier_validset,
+    compute_metrics=compute_metrics,
+)
+
+# specify raytune hyperparameter search space
+ray_config = {
+    "num_train_epochs": tune.choice([epochs]),
+    "learning_rate": tune.loguniform(1e-6, 1e-3),
+    "weight_decay": tune.uniform(0.0, 0.3),
+    "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
+    "warmup_steps": tune.uniform(100, 2000),
+    "seed": tune.uniform(0,100),
+    "per_device_train_batch_size": tune.choice([geneformer_batch_size])
+}
+
+hyperopt_search = HyperOptSearch(
+    metric="eval_accuracy", mode="max")
+
+# optimize hyperparameters
+trainer.hyperparameter_search(
+    direction="maximize",
+    backend="ray",
+    resources_per_trial={"cpu":8,"gpu":1},
+    hp_space=lambda _: ray_config,
+    search_alg=hyperopt_search,
+    n_trials=100, # number of trials
+    progress_reporter=tune.CLIReporter(max_report_frequency=600,
+                                       sort_by_metric=True,
+                                       max_progress_rows=100,
+                                       mode="max",
+                                       metric="eval_accuracy",
+                                       metric_columns=["loss", "eval_loss", "eval_accuracy"])
+)
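
One portability note on the script above: random.sample() on a set (as in the train/validation split) was deprecated in Python 3.9 and removed in 3.11. On newer interpreters, an equivalent deterministic split requires converting to a sequence first, as in this sketch:

```python
# Sketch: set-to-sequence conversion needed for random.sample on Python 3.11+.
import random

indiv_set = {"patient1", "patient2", "patient3", "patient4"}  # placeholder IDs
random.seed(42)
train_indiv = random.sample(sorted(indiv_set), round(0.7 * len(indiv_set)))
```
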
examples/in_silico_perturbation.ipynb CHANGED
@@ -8,80 +8,35 @@
8
  "outputs": [],
9
  "source": [
10
  "from geneformer import InSilicoPerturber\n",
11
- "from geneformer import InSilicoPerturberStats\n",
12
- "from geneformer import EmbExtractor"
13
- ]
14
- },
15
- {
16
- "cell_type": "markdown",
17
- "id": "cbd6851c-060e-4967-b816-e605ffe58b23",
18
- "metadata": {
19
- "tags": []
20
- },
21
- "source": [
22
- "### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
23
- ]
24
- },
25
- {
26
- "cell_type": "code",
27
- "execution_count": null,
28
- "id": "c53e98cd-c603-4878-82ba-db471181bb55",
29
- "metadata": {},
30
- "outputs": [],
31
- "source": [
32
- "# first obtain start, goal, and alt embedding positions\n",
33
- "# this function was changed to be separate from perturb_data\n",
34
- "# to avoid repeating calcuations when parallelizing perturb_data\n",
35
- "cell_states_to_model={\"state_key\": \"disease\", \n",
36
- " \"start_state\": \"dcm\", \n",
37
- " \"goal_state\": \"nf\", \n",
38
- " \"alt_states\": [\"hcm\"]}\n",
39
- "\n",
40
- "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
- "\n",
42
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
43
- "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
44
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
45
- "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
46
- " num_classes=3,\n",
47
- " filter_data=filter_data_dict,\n",
48
- " max_ncells=1000,\n",
49
- " emb_layer=0,\n",
50
- " summary_stat=\"exact_mean\",\n",
51
- " forward_batch_size=256,\n",
52
- " nproc=16)\n",
53
- "\n",
54
- "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
55
- " \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
56
- " \"path/to/input_data\",\n",
57
- " \"path/to/output_directory\",\n",
58
- " \"output_prefix\")"
59
  ]
60
  },
61
  {
62
  "cell_type": "code",
63
  "execution_count": null,
64
- "id": "981e1190-62da-4543-b7d3-6e2a2d6a6d56",
65
  "metadata": {
66
  "tags": []
67
  },
68
  "outputs": [],
69
  "source": [
70
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
71
- "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
72
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
73
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
74
  " perturb_rank_shift=None,\n",
75
  " genes_to_perturb=\"all\",\n",
76
  " combos=0,\n",
77
  " anchor_gene=None,\n",
78
- " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
79
  " num_classes=3,\n",
80
  " emb_mode=\"cell\",\n",
81
  " cell_emb_style=\"mean_pool\",\n",
82
- " filter_data=filter_data_dict,\n",
83
- " cell_states_to_model=cell_states_to_model,\n",
84
- " state_embs_dict=state_embs_dict,\n",
 
 
85
  " max_ncells=2000,\n",
86
  " emb_layer=0,\n",
87
  " forward_batch_size=400,\n",
@@ -96,10 +51,9 @@
96
  "outputs": [],
97
  "source": [
98
  "# outputs intermediate files from in silico perturbation\n",
99
- "\n",
100
- "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
101
  " \"path/to/input_data\",\n",
102
- " \"path/to/isp_output_directory\",\n",
103
  " \"output_prefix\")"
104
  ]
105
  },
@@ -110,14 +64,11 @@
110
  "metadata": {},
111
  "outputs": [],
112
  "source": [
113
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
114
- "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
115
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
116
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
117
  " genes_perturbed=\"all\",\n",
118
  " combos=0,\n",
119
  " anchor_gene=None,\n",
120
- " cell_states_to_model=cell_states_to_model)"
121
  ]
122
  },
123
  {
@@ -128,9 +79,9 @@
128
  "outputs": [],
129
  "source": [
130
  "# extracts data from intermediate files and processes stats to output in final .csv\n",
131
- "ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n",
132
  " None,\n",
133
- " \"path/to/isp_stats_output_directory\",\n",
134
  " \"output_prefix\")"
135
  ]
136
  }
@@ -151,7 +102,7 @@
151
  "name": "python",
152
  "nbconvert_exporter": "python",
153
  "pygments_lexer": "ipython3",
154
- "version": "3.10.15"
155
  }
156
  },
157
  "nbformat": 4,
 
8
  "outputs": [],
9
  "source": [
10
  "from geneformer import InSilicoPerturber\n",
11
+ "from geneformer import InSilicoPerturberStats"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ]
13
  },
14
  {
15
  "cell_type": "code",
16
  "execution_count": null,
17
+ "id": "67b44366-f255-4415-a865-6a27a8ffcce7",
18
  "metadata": {
19
  "tags": []
20
  },
21
  "outputs": [],
22
  "source": [
23
+ "# in silico perturbation in deletion mode to determine genes whose \n",
24
+ "# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
25
+ "# the embedding towards non-failing (nf) state\n",
26
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
27
  " perturb_rank_shift=None,\n",
28
  " genes_to_perturb=\"all\",\n",
29
  " combos=0,\n",
30
  " anchor_gene=None,\n",
31
+ " model_type=\"CellClassifier\",\n",
32
  " num_classes=3,\n",
33
  " emb_mode=\"cell\",\n",
34
  " cell_emb_style=\"mean_pool\",\n",
35
+ " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
+ " cell_states_to_model={'state_key': 'disease', \n",
37
+ " 'start_state': 'dcm', \n",
38
+ " 'goal_state': 'nf', \n",
39
+ " 'alt_states': ['hcm']},\n",
40
  " max_ncells=2000,\n",
41
  " emb_layer=0,\n",
42
  " forward_batch_size=400,\n",
 
51
  "outputs": [],
52
  "source": [
53
  "# outputs intermediate files from in silico perturbation\n",
54
+ "isp.perturb_data(\"path/to/model\",\n",
 
55
  " \"path/to/input_data\",\n",
56
+ " \"path/to/output_directory\",\n",
57
  " \"output_prefix\")"
58
  ]
59
  },
 
64
  "metadata": {},
65
  "outputs": [],
66
  "source": [
67
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
68
  " genes_perturbed=\"all\",\n",
69
  " combos=0,\n",
70
  " anchor_gene=None,\n",
71
+ " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
72
  ]
73
  },
74
  {
 
79
  "outputs": [],
80
  "source": [
81
  "# extracts data from intermediate files and processes stats to output in final .csv\n",
82
+ "ispstats.get_stats(\"path/to/input_data\",\n",
83
  " None,\n",
84
+ " \"path/to/output_directory\",\n",
85
  " \"output_prefix\")"
86
  ]
87
  }
 
102
  "name": "python",
103
  "nbconvert_exporter": "python",
104
  "pygments_lexer": "ipython3",
105
+ "version": "3.10.11"
106
  }
107
  },
108
  "nbformat": 4,
examples/multitask_cell_classification.ipynb DELETED
@@ -1,420 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "866f100c-e11a-4e7b-a37c-831775d845a7",
6
- "metadata": {},
7
- "source": [
8
- "# Geneformer Multi-Task Cell Classifier Tutorial\n",
9
- "\n",
10
- "This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
11
- ]
12
- },
13
- {
14
- "cell_type": "markdown",
15
- "id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
16
- "metadata": {},
17
- "source": [
18
- "## 1. Installation and Imports\n",
19
- "\n",
20
- "First import the necessary modules."
21
- ]
22
- },
23
- {
24
- "cell_type": "code",
25
- "execution_count": 3,
26
- "id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
27
- "metadata": {},
28
- "outputs": [],
29
- "source": [
30
- "from geneformer import MTLClassifier"
31
- ]
32
- },
33
- {
34
- "cell_type": "markdown",
35
- "id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
36
- "metadata": {},
37
- "source": [
38
- "## 2. Set up Paths and Parameters\n",
39
- "\n",
40
- "Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": null,
46
- "id": "04a04197-8e45-47f8-a86f-202209ea10ae",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "# Define paths\n",
51
- "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
52
- "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
53
- "train_path = \"/path/to/train/data.dataset\"\n",
54
- "val_path = \"/path/to/val/data.dataset\"\n",
55
- "test_path = \"/path/to/test/data.dataset\"\n",
56
- "results_dir = \"/path/to/results/directory\"\n",
57
- "model_save_path = \"/path/to/model/save/path\"\n",
58
- "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
59
- "\n",
60
- "# Define tasks and hyperparameters\n",
61
- "# task_columns should be a list of column names from your dataset\n",
62
- "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
63
- "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
64
- "\n",
65
- "hyperparameters = {\n",
66
- " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
67
- " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
68
- " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
69
- " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
70
- " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
71
- " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
72
- "}"
73
- ]
74
- },
75
- {
76
- "cell_type": "markdown",
77
- "id": "31857690-a739-435a-aefd-f171fafc1b78",
78
- "metadata": {},
79
- "source": [
80
- "In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
81
- "1. Identifying the cell type\n",
82
- "2. Determining the disease state\n",
83
- "3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
84
- "\n",
85
- "These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
86
- "\n",
87
- "For example, your dataset might look something like this:\n",
88
- "\n",
89
- " | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
90
- " |----------------|-----------|-----|-----------|---------------|\n",
91
- " | cell1 | ... | ... | neuron | healthy |\n",
92
- " | cell2 | ... | ... | astrocyte | diseased |\n",
93
- " | ... | ... | ... | ... | ... |\n",
94
- "The model will learn to predict classes within 'cell_type' and 'disease_state' "
95
- ]
96
- },
97
- {
98
- "cell_type": "markdown",
99
- "id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
100
- "metadata": {},
101
- "source": [
102
- "## 3. Initialize the MTLClassifier\n",
103
- "\n",
104
- "Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": null,
110
- "id": "e27caac9-670c-409d-9313-50201c665cb9",
111
- "metadata": {},
112
- "outputs": [],
113
- "source": [
114
- "mc = MTLClassifier(\n",
115
- " task_columns=task_columns, # Our defined classification tasks\n",
116
- " study_name=\"MTLClassifier_example\",\n",
117
- " pretrained_path=pretrained_path,\n",
118
- " train_path=train_path,\n",
119
- " val_path=val_path,\n",
120
- " test_path=test_path,\n",
121
- " model_save_path=model_save_path,\n",
122
- " results_dir=results_dir,\n",
123
- " tensorboard_log_dir=tensorboard_log_dir,\n",
124
- " hyperparameters=hyperparameters,\n",
125
- " n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
126
- " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
127
- " batch_size=8, # Adjust based on available GPU memory\n",
128
- " seed=42\n",
129
- ")"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "0d729444-e3ad-4584-9659-0c464ac97462",
135
- "metadata": {},
136
- "source": [
137
- "## 4. Run Hyperparameter Optimization\n",
138
- "\n",
139
- "Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": null,
145
- "id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "mc.run_optuna_study()"
150
- ]
151
- },
152
- {
153
- "cell_type": "markdown",
154
- "id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
155
- "metadata": {},
156
- "source": [
157
- "## 5. Evaluate the Model on Test Data\n",
158
- "\n",
159
- "After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
166
- "metadata": {},
167
- "outputs": [],
168
- "source": [
169
- "mc.load_and_evaluate_test_model()"
170
- ]
171
- },
172
- {
173
- "cell_type": "markdown",
174
- "id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
175
- "metadata": {},
176
- "source": [
177
- "## 6. (Optional) Manual Hyperparameter Tuning\n",
178
- "\n",
179
- "If you prefer to set hyperparameters manually, you can use the following approach:"
180
- ]
181
- },
182
- {
183
- "cell_type": "code",
184
- "execution_count": null,
185
- "id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
186
- "metadata": {},
187
- "outputs": [],
188
- "source": [
189
- "manual_hyperparameters = {\n",
190
- " \"learning_rate\": 0.001,\n",
191
- " \"warmup_ratio\": 0.01,\n",
192
- " \"weight_decay\": 0.1,\n",
193
- " \"dropout_rate\": 0.1,\n",
194
- " \"lr_scheduler_type\": \"cosine\",\n",
195
- " \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
196
- " \"max_layers_to_freeze\": 2\n",
197
- "}\n",
198
- "\n",
199
- "mc_manual = MTLClassifier(\n",
200
- " task_columns=task_columns,\n",
201
- " study_name=\"mtl_manual\",\n",
202
- " pretrained_path=pretrained_path,\n",
203
- " train_path=train_path,\n",
204
- " val_path=val_path,\n",
205
- " test_path=test_path,\n",
206
- " model_save_path=model_save_path,\n",
207
- " results_dir=results_dir,\n",
208
- " tensorboard_log_dir=tensorboard_log_dir,\n",
209
- " manual_hyperparameters=manual_hyperparameters,\n",
210
- " use_manual_hyperparameters=True,\n",
211
- " epochs=10,\n",
212
- " batch_size=32,\n",
213
- " seed=42\n",
214
- ")\n",
215
- "\n",
216
- "mc_manual.run_manual_tuning()"
217
- ]
218
- },
219
- {
220
- "cell_type": "markdown",
221
- "id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
222
- "metadata": {},
223
- "source": [
224
- "# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
225
- "This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
226
- ]
227
- },
228
- {
229
- "cell_type": "code",
230
- "execution_count": null,
231
- "id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
232
- "metadata": {},
233
- "outputs": [],
234
- "source": [
235
- "from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": null,
241
- "id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
242
- "metadata": {},
243
- "outputs": [],
244
- "source": [
245
- "# Define paths\n",
246
- "model_directory = \"/path/to/model/save/path\"\n",
247
- "input_data_file = \"/path/to/input/data.dataset\"\n",
248
- "output_directory = \"/path/to/output/directory\"\n",
249
- "output_prefix = \"mtl_quantized_perturbation\"\n",
250
- "\n",
251
- "# Define parameters\n",
252
- "perturb_type = \"delete\" # or \"overexpress\"\n",
253
- "\n",
254
- "# Define cell states to model\n",
255
- "cell_states_to_model = {\n",
256
- " \"state_key\": \"disease_state\", \n",
257
- " \"start_state\": \"disease\", \n",
258
- " \"goal_state\": \"control\"\n",
259
- "}\n",
260
- "\n",
261
- "# Define filter data\n",
262
- "filter_data_dict = {\n",
263
- " \"cell_type\": [\"Fibroblast\"]\n",
264
- "}"
265
- ]
266
- },
267
- {
268
- "cell_type": "markdown",
269
- "id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
270
- "metadata": {},
271
- "source": [
272
- "## 3. Extract State Embeddings\n",
273
- "\n",
274
- "Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
275
- ]
276
- },
277
- {
278
- "cell_type": "code",
279
- "execution_count": null,
280
- "id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
281
- "metadata": {},
282
- "outputs": [],
283
- "source": [
284
- "# Initialize EmbExtractor\n",
285
- "embex = EmbExtractor(\n",
286
- " filter_data_dict=filter_data_dict,\n",
287
- " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
- " emb_layer=0, # Use the second to last layer\n",
289
- " emb_mode = \"cls\",\n",
290
- " summary_stat=\"exact_mean\",\n",
291
- " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
- " nproc=4\n",
293
- ")\n",
294
- "\n",
295
- "# Extract state embeddings\n",
296
- "state_embs_dict = embex.get_state_embs(\n",
297
- " cell_states_to_model,\n",
298
- " model_directory=model_directory,\n",
299
- " input_data_file=input_data_file,\n",
300
- " output_directory=output_directory,\n",
301
- " output_prefix=output_prefix\n",
302
- ")"
303
- ]
304
- },
305
- {
306
- "cell_type": "markdown",
307
- "id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
308
- "metadata": {},
309
- "source": [
310
- "## 4. Initialize the InSilicoPerturber\n",
311
- "\n",
312
- "Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
313
- ]
314
- },
315
- {
316
- "cell_type": "code",
317
- "execution_count": null,
318
- "id": "09f985a1-91bc-4e8d-8001-a3663531b570",
319
- "metadata": {},
320
- "outputs": [],
321
- "source": [
322
- "# Initialize InSilicoPerturber\n",
323
- "isp = InSilicoPerturber(\n",
324
- " perturb_type=perturb_type,\n",
325
- " genes_to_perturb=\"all\", # Perturb all genes\n",
326
- " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
- " emb_mode=\"cls\", # Use CLS token embedding\n",
328
- " cell_states_to_model=cell_states_to_model,\n",
329
- " state_embs_dict=state_embs_dict,\n",
330
- " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
331
- " emb_layer=0, \n",
332
- " forward_batch_size=8, # Adjust based on available GPU memory\n",
333
- " nproc=1\n",
334
- ")"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
340
- "metadata": {},
341
- "source": [
342
- "## 5. Run In Silico Perturbation\n",
343
- "\n",
344
- "Run the in silico perturbation on the dataset."
345
- ]
346
- },
347
- {
348
- "cell_type": "code",
349
- "execution_count": null,
350
- "id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
351
- "metadata": {},
352
- "outputs": [],
353
- "source": [
354
- "# Run perturbation and output intermediate files\n",
355
- "isp.perturb_data(\n",
356
- " model_directory=model_directory,\n",
357
- " input_data_file=input_data_file,\n",
358
- " output_directory=output_directory,\n",
359
- " output_prefix=output_prefix\n",
360
- ")"
361
- ]
362
- },
363
- {
364
- "cell_type": "markdown",
365
- "id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
366
- "metadata": {},
367
- "source": [
368
- "## 6. Process Results with InSilicoPerturberStats\n",
369
- "\n",
370
- "After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
371
- ]
372
- },
373
- {
374
- "cell_type": "code",
375
- "execution_count": null,
376
- "id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
377
- "metadata": {},
378
- "outputs": [],
379
- "source": [
380
- "# Initialize InSilicoPerturberStats\n",
381
- "ispstats = InSilicoPerturberStats(\n",
382
- " mode=\"goal_state_shift\",\n",
383
- " genes_perturbed=\"all\",\n",
384
- " combos=0,\n",
385
- " anchor_gene=None,\n",
386
- " cell_states_to_model=cell_states_to_model\n",
387
- ")\n",
388
- "\n",
389
- "# Process stats and output final .csv\n",
390
- "ispstats.get_stats(\n",
391
- " input_data_file,\n",
392
- " None,\n",
393
- " output_directory,\n",
394
- " output_prefix\n",
395
- ")"
396
- ]
397
- }
398
- ],
399
- "metadata": {
400
- "kernelspec": {
401
- "display_name": "Python 3 (ipykernel)",
402
- "language": "python",
403
- "name": "python3"
404
- },
405
- "language_info": {
406
- "codemirror_mode": {
407
- "name": "ipython",
408
- "version": 3
409
- },
410
- "file_extension": ".py",
411
- "mimetype": "text/x-python",
412
- "name": "python",
413
- "nbconvert_exporter": "python",
414
- "pygments_lexer": "ipython3",
415
- "version": "3.11.5"
416
- }
417
- },
418
- "nbformat": 4,
419
- "nbformat_minor": 5
420
- }
 
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py CHANGED
@@ -138,9 +138,7 @@ training_args = {
138
  "per_device_train_batch_size": geneformer_batch_size,
139
  "num_train_epochs": epochs,
140
  "save_strategy": "steps",
141
- "save_steps": np.floor(
142
- num_examples / geneformer_batch_size / 8
143
- ), # 8 saves per epoch
144
  "logging_steps": 1000,
145
  "output_dir": training_output_dir,
146
  "logging_dir": logging_dir,
 
138
  "per_device_train_batch_size": geneformer_batch_size,
139
  "num_train_epochs": epochs,
140
  "save_strategy": "steps",
141
+ "save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
142
  "logging_steps": 1000,
143
  "output_dir": training_output_dir,
144
  "logging_dir": logging_dir,
examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -7,39 +7,23 @@
7
  "tags": []
8
  },
9
  "source": [
10
- "## Tokenizing .loom or .h5ad single cell RNA-seq data to rank value encoding .dataset format"
11
  ]
12
  },
13
  {
14
  "cell_type": "markdown",
15
- "id": "1fe86f48-5578-47df-b373-58c21ec170ab",
16
  "metadata": {},
17
  "source": [
18
- "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
19
  "\n",
20
- "#### The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.\n",
21
- "\n",
22
- "#### Genes should be labeled with Ensembl IDs (loom row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute \"n_counts\") to be used for normalization.\n",
23
  "\n",
24
  "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
25
  "\n",
26
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
  "\n",
28
- "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
29
- ]
30
- },
31
- {
32
- "cell_type": "markdown",
33
- "id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
34
- "metadata": {},
35
- "source": [
36
- "**********************************************************************************************************\n",
37
- "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
38
- "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
39
- "\n",
40
- "#### ADDITIONALLY:\n",
41
- "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
42
- "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
43
  ]
44
  },
45
  {
@@ -59,11 +43,8 @@
59
  "metadata": {},
60
  "outputs": [],
61
  "source": [
62
- "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
63
- "tk.tokenize_data(\"loom_data_directory\", \n",
64
- " \"output_directory\", \n",
65
- " \"output_prefix\", \n",
66
- " file_format=\"loom\")"
67
  ]
68
  }
69
  ],
@@ -83,7 +64,7 @@
83
  "name": "python",
84
  "nbconvert_exporter": "python",
85
  "pygments_lexer": "ipython3",
86
- "version": "3.10.15"
87
  }
88
  },
89
  "nbformat": 4,
 
7
  "tags": []
8
  },
9
  "source": [
10
+ "## Tokenizing .loom single cell RNA-seq data to rank value encoding .dataset format"
11
  ]
12
  },
13
  {
14
  "cell_type": "markdown",
15
+ "id": "350e6252-b783-494b-9767-f087eb868a15",
16
  "metadata": {},
17
  "source": [
18
+ "#### Input data is a directory with .loom files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. \n",
19
  "\n",
20
+ "#### Genes should be labeled with Ensembl IDs (row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (column attribute \"n_counts\") to be used for normalization.\n",
21
  "\n",
22
  "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
23
  "\n",
24
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
25
  "\n",
26
+ "#### If one's data is in other formats besides .loom, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom format prior to running the transcriptome tokenizer."
27
  ]
28
  },
29
  {
 
43
  "metadata": {},
44
  "outputs": [],
45
  "source": [
46
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ_major\"}, nproc=4)\n",
47
+ "tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\")"
48
  ]
49
  }
50
  ],
 
64
  "name": "python",
65
  "nbconvert_exporter": "python",
66
  "pygments_lexer": "ipython3",
67
+ "version": "3.10.11"
68
  }
69
  },
70
  "nbformat": 4,
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json RENAMED
File without changes
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin RENAMED
File without changes
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "architectures": [
3
- "BertForMaskedLM"
4
- ],
5
- "attention_probs_dropout_prob": 0.02,
6
- "classifier_dropout": null,
7
- "hidden_act": "relu",
8
- "hidden_dropout_prob": 0.02,
9
- "hidden_size": 512,
10
- "initializer_range": 0.02,
11
- "intermediate_size": 1024,
12
- "layer_norm_eps": 1e-12,
13
- "max_position_embeddings": 4096,
14
- "model_type": "bert",
15
- "num_attention_heads": 8,
16
- "num_hidden_layers": 12,
17
- "pad_token_id": 0,
18
- "position_embedding_type": "absolute",
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.37.2",
21
- "type_vocab_size": 2,
22
- "use_cache": true,
23
- "vocab_size": 20275
24
- }
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
3
- size 152363342
{gf-12L-30M-i2048 → geneformer-12L-30M}/config.json RENAMED
File without changes
{gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin RENAMED
File without changes
{gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin RENAMED
File without changes
geneformer/__init__.py CHANGED
@@ -1,34 +1,12 @@
1
- # ruff: noqa: F401
2
- import warnings
3
- from pathlib import Path
4
-
5
- warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
-
7
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
8
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
9
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
10
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
11
-
12
- from . import (
13
- collator_for_classification,
14
- emb_extractor,
15
- in_silico_perturber,
16
- in_silico_perturber_stats,
17
- pretrainer,
18
- tokenizer,
19
- )
20
- from .collator_for_classification import (
21
- DataCollatorForCellClassification,
22
- DataCollatorForGeneClassification,
23
- )
24
- from .emb_extractor import EmbExtractor, get_embs
25
- from .in_silico_perturber import InSilicoPerturber
26
- from .in_silico_perturber_stats import InSilicoPerturberStats
27
- from .pretrainer import GeneformerPretrainer
28
  from .tokenizer import TranscriptomeTokenizer
29
-
30
- from . import classifier # noqa # isort:skip
31
- from .classifier import Classifier # noqa # isort:skip
32
-
33
- from . import mtl_classifier # noqa # isort:skip
34
- from .mtl_classifier import MTLClassifier # noqa # isort:skip
 
1
+ from . import tokenizer
2
+ from . import pretrainer
3
+ from . import collator_for_classification
4
+ from . import in_silico_perturber
5
+ from . import in_silico_perturber_stats
6
  from .tokenizer import TranscriptomeTokenizer
7
+ from .pretrainer import GeneformerPretrainer
8
+ from .collator_for_classification import DataCollatorForGeneClassification
9
+ from .collator_for_classification import DataCollatorForCellClassification
10
+ from .emb_extractor import EmbExtractor
11
+ from .in_silico_perturber import InSilicoPerturber
12
+ from .in_silico_perturber_stats import InSilicoPerturberStats
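One practical difference between the two layouts above: the removed __init__.py exposed package-level dictionary paths (TOKEN_DICTIONARY_FILE and friends) plus the Classifier and MTLClassifier entry points, while the replacement exports only the core classes. A sketch of loading the token dictionary under the removed layout, mirroring the pickle-loading code in classifier.py below:

    import pickle

    from geneformer import TOKEN_DICTIONARY_FILE  # only importable under the removed layout

    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
        gene_token_dict = pickle.load(f)  # Ensembl ID -> token
    token_gene_dict = {v: k for k, v in gene_token_dict.items()}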
geneformer/classifier.py DELETED
@@ -1,1563 +0,0 @@
1
- """
2
- Geneformer classifier.
3
-
4
- **Input data:**
5
-
6
- | Cell state classifier:
7
- | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
8
-
9
- | Gene classifier:
10
- | Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
11
-
12
- **Usage:**
13
-
14
- .. code-block :: python
15
-
16
- >>> from geneformer import Classifier
17
- >>> cc = Classifier(classifier="cell", # example of cell state classifier
18
- ... cell_state_dict={"state_key": "disease", "states": "all"},
19
- ... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
20
- ... training_args=training_args,
21
- ... freeze_layers = 2,
22
- ... num_crossval_splits = 1,
23
- ... forward_batch_size=200,
24
- ... nproc=16)
25
- >>> cc.prepare_data(input_data_file="path/to/input_data",
26
- ... output_directory="path/to/output_directory",
27
- ... output_prefix="output_prefix")
28
- >>> all_metrics = cc.validate(model_directory="path/to/model",
29
- ... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
30
- ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
31
- ... output_directory="path/to/output_directory",
32
- ... output_prefix="output_prefix",
33
- ... predict_eval=True)
34
- >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
35
- ... output_directory="path/to/output_directory",
36
- ... output_prefix="output_prefix",
37
- ... custom_class_order=["healthy","disease1","disease2"])
38
- >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
39
- ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
40
- ... title="disease",
41
- ... output_directory="path/to/output_directory",
42
- ... output_prefix="output_prefix",
43
- ... custom_class_order=["healthy","disease1","disease2"])
44
- """
45
-
46
- import datetime
47
- import logging
48
- import os
49
- import pickle
50
- import subprocess
51
- from pathlib import Path
52
-
53
- import numpy as np
54
- import pandas as pd
55
- import seaborn as sns
56
- from tqdm.auto import tqdm, trange
57
- from transformers import Trainer
58
- from transformers.training_args import TrainingArguments
59
-
60
- from . import (
61
- TOKEN_DICTIONARY_FILE,
62
- DataCollatorForCellClassification,
63
- DataCollatorForGeneClassification,
64
- )
65
- from . import classifier_utils as cu
66
- from . import evaluation_utils as eu
67
- from . import perturber_utils as pu
68
-
69
- sns.set()
70
-
71
-
72
- logger = logging.getLogger(__name__)
73
-
74
-
75
- class Classifier:
76
- valid_option_dict = {
77
- "classifier": {"cell", "gene"},
78
- "quantize": {bool, dict},
79
- "cell_state_dict": {None, dict},
80
- "gene_class_dict": {None, dict},
81
- "filter_data": {None, dict},
82
- "rare_threshold": {int, float},
83
- "max_ncells": {None, int},
84
- "max_ncells_per_class": {None, int},
85
- "training_args": {None, dict},
86
- "freeze_layers": {int},
87
- "num_crossval_splits": {0, 1, 5},
88
- "split_sizes": {None, dict},
89
- "no_eval": {bool},
90
- "stratify_splits_col": {None, str},
91
- "forward_batch_size": {int},
92
- "token_dictionary_file": {None, str},
93
- "nproc": {int},
94
- "ngpu": {int},
95
- }
96
-
97
- def __init__(
98
- self,
99
- classifier=None,
100
- quantize=False,
101
- cell_state_dict=None,
102
- gene_class_dict=None,
103
- filter_data=None,
104
- rare_threshold=0,
105
- max_ncells=None,
106
- max_ncells_per_class=None,
107
- training_args=None,
108
- ray_config=None,
109
- freeze_layers=0,
110
- num_crossval_splits=1,
111
- split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
112
- stratify_splits_col=None,
113
- no_eval=False,
114
- forward_batch_size=100,
115
- token_dictionary_file=None,
116
- nproc=4,
117
- ngpu=1,
118
- ):
119
- """
120
- Initialize Geneformer classifier.
121
-
122
- **Parameters:**
123
-
124
- classifier : {"cell", "gene"}
125
- | Whether to fine-tune a cell state or gene classifier.
126
- quantize : bool, dict
127
- | Whether to fine-tune a quantized model.
128
- | If True and no config provided, will use default.
129
- | Will use custom config if provided.
130
- | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
131
- | For example: {"bnb_config": BitsAndBytesConfig(...),
132
- | "peft_config": LoraConfig(...)}
133
- cell_state_dict : None, dict
134
- | Cell states to fine-tune model to distinguish.
135
- | Two-item dictionary with keys: state_key and states
136
- | state_key: key specifying name of column in .dataset that defines the states to model
137
- | states: list of values in the state_key column that specifies the states to model
138
- | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
139
- | Of note, if using "all", states will be defined after data is filtered.
140
- | Must have at least 2 states to model.
141
- | For example: {"state_key": "disease",
142
- | "states": ["nf", "hcm", "dcm"]}
143
- | or
144
- | {"state_key": "disease",
145
- | "states": "all"}
146
- gene_class_dict : None, dict
147
- | Gene classes to fine-tune model to distinguish.
148
- | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
149
- | Gene_label_B: list(geneB1, geneB2, ...)}
150
- | Gene values should be Ensembl IDs.
151
- filter_data : None, dict
152
- | Default is to fine-tune with all input data.
153
- | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
154
- rare_threshold : float
155
- | Threshold below which rare cell states should be removed.
156
- | For example, setting to 0.05 will remove cell states representing
157
- | < 5% of the total cells from the cell state classifier's possible classes.
158
- max_ncells : None, int
159
- | Maximum number of cells to use for fine-tuning.
160
- | Default is to fine-tune with all input data.
161
- max_ncells_per_class : None, int
162
- | Maximum number of cells per cell class to use for fine-tuning.
163
- | Of note, will be applied after max_ncells above.
164
- | (Only valid for cell classification.)
165
- training_args : None, dict
166
- | Training arguments for fine-tuning.
167
- | If None, defaults will be inferred for 6 layer Geneformer.
168
- | Otherwise, will use the Hugging Face defaults:
169
- | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
170
- | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
171
- ray_config : None, dict
172
- | Training argument ranges for tuning hyperparameters with Ray.
173
- freeze_layers : int
174
- | Number of layers to freeze from fine-tuning.
175
- | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
176
- num_crossval_splits : {0, 1, 5}
177
- | 0: train on all data without splitting
178
- | 1: split data into train and eval sets by designated split_sizes["valid"]
179
- | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
180
- split_sizes : None, dict
181
- | Dictionary of proportion of data to hold out for train, validation, and test sets
182
- | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
183
- stratify_splits_col : None, str
184
- | Name of column in .dataset to be used for stratified splitting.
185
- | Proportion of each class in this column will be the same in the splits as in the original dataset.
186
- no_eval : bool
187
- | If True, will skip eval step and use all data for training.
188
- | Otherwise, will perform eval during training.
189
- forward_batch_size : int
190
- | Batch size for forward pass (for evaluation, not training).
191
- token_dictionary_file : None, str
192
- | Default is to use token dictionary file from Geneformer
193
- | Otherwise, will load custom gene token dictionary.
194
- nproc : int
195
- | Number of CPU processes to use.
196
- ngpu : int
197
- | Number of GPUs available.
198
-
199
- """
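As an aside on the split arithmetic assigned a few lines below, a worked example with the default split_sizes (illustration only, not part of the original file):

    split_sizes = {"train": 0.8, "valid": 0.1, "test": 0.1}
    # after holding out the test fraction, eval_size is the share of the
    # remaining (train + valid) data used for validation
    eval_size = split_sizes["valid"] / (split_sizes["train"] + split_sizes["valid"])
    # 0.1 / 0.9 ≈ 0.111, i.e. ~11.1% of the non-test cells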
200
-
201
- self.classifier = classifier
202
- if self.classifier == "cell":
203
- self.model_type = "CellClassifier"
204
- elif self.classifier == "gene":
205
- self.model_type = "GeneClassifier"
206
- self.quantize = quantize
207
- self.cell_state_dict = cell_state_dict
208
- self.gene_class_dict = gene_class_dict
209
- self.filter_data = filter_data
210
- self.rare_threshold = rare_threshold
211
- self.max_ncells = max_ncells
212
- self.max_ncells_per_class = max_ncells_per_class
213
- self.training_args = training_args
214
- self.ray_config = ray_config
215
- self.freeze_layers = freeze_layers
216
- self.num_crossval_splits = num_crossval_splits
217
- self.split_sizes = split_sizes
218
- self.train_size = self.split_sizes["train"]
219
- self.valid_size = self.split_sizes["valid"]
220
- self.oos_test_size = self.split_sizes["test"]
221
- self.eval_size = self.valid_size / (self.train_size + self.valid_size)
222
- self.stratify_splits_col = stratify_splits_col
223
- self.no_eval = no_eval
224
- self.forward_batch_size = forward_batch_size
225
- self.token_dictionary_file = token_dictionary_file
226
- self.nproc = nproc
227
- self.ngpu = ngpu
228
-
229
- if self.training_args is None:
230
- logger.warning(
231
- "Hyperparameter tuning is highly recommended for optimal results. "
232
- "No training_args provided; using default hyperparameters."
233
- )
234
-
235
- self.validate_options()
236
-
237
- if self.filter_data is None:
238
- self.filter_data = dict()
239
-
240
- if self.classifier == "cell":
241
- if self.cell_state_dict["states"] != "all":
242
- self.filter_data[
243
- self.cell_state_dict["state_key"]
244
- ] = self.cell_state_dict["states"]
245
-
246
- # load token dictionary (Ensembl IDs:token)
247
- if self.token_dictionary_file is None:
248
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE
249
- with open(self.token_dictionary_file, "rb") as f:
250
- self.gene_token_dict = pickle.load(f)
251
-
252
- self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
253
-
254
- # filter genes for gene classification for those in token dictionary
255
- if self.classifier == "gene":
256
- all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
257
- missing_genes = [
258
- gene
259
- for gene in all_gene_class_values
260
- if gene not in self.gene_token_dict.keys()
261
- ]
262
- if len(missing_genes) == len(all_gene_class_values):
263
- logger.error(
264
- "None of the provided genes to classify are in token dictionary."
265
- )
266
- raise
267
- elif len(missing_genes) > 0:
268
- logger.warning(
269
- f"Genes to classify {missing_genes} are not in token dictionary."
270
- )
271
- self.gene_class_dict = {
272
- k: list(set([self.gene_token_dict.get(gene) for gene in v]))
273
- for k, v in self.gene_class_dict.items()
274
- }
275
- empty_classes = []
276
- for k, v in self.gene_class_dict.items():
277
- if len(v) == 0:
278
- empty_classes += [k]
279
- if len(empty_classes) > 0:
280
- logger.error(
281
- f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
282
- )
283
- raise
284
-
285
- def validate_options(self):
286
- # confirm arguments are within valid options and compatible with each other
287
- for attr_name, valid_options in self.valid_option_dict.items():
288
- attr_value = self.__dict__[attr_name]
289
- if not isinstance(attr_value, (list, dict)):
290
- if attr_value in valid_options:
291
- continue
292
- valid_type = False
293
- for option in valid_options:
294
- if (option in [int, float, list, dict, bool, str]) and isinstance(
295
- attr_value, option
296
- ):
297
- valid_type = True
298
- break
299
- if valid_type:
300
- continue
301
- logger.error(
302
- f"Invalid option for {attr_name}. "
303
- f"Valid options for {attr_name}: {valid_options}"
304
- )
305
- raise
306
-
307
- if self.filter_data is not None:
308
- for key, value in self.filter_data.items():
309
- if not isinstance(value, list):
310
- self.filter_data[key] = [value]
311
- logger.warning(
312
- "Values in filter_data dict must be lists. "
313
- f"Changing {key} value to list ([{value}])."
314
- )
315
-
316
- if self.classifier == "cell":
317
- if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
318
- logger.error(
319
- "Invalid keys for cell_state_dict. "
320
- "The cell_state_dict should have only 2 keys: state_key and states"
321
- )
322
- raise
323
-
324
- if self.cell_state_dict["states"] != "all":
325
- if not isinstance(self.cell_state_dict["states"], list):
326
- logger.error(
327
- "States in cell_state_dict should be list of states to model."
328
- )
329
- raise
330
- if len(self.cell_state_dict["states"]) < 2:
331
- logger.error(
332
- "States in cell_state_dict should contain at least 2 states to classify."
333
- )
334
- raise
335
-
336
- if self.classifier == "gene":
337
- if len(self.gene_class_dict.keys()) < 2:
338
- logger.error(
339
- "Gene_class_dict should contain at least 2 gene classes to classify."
340
- )
341
- raise
342
- if sum(self.split_sizes.values()) != 1:
343
- logger.error("Train, validation, and test proportions should sum to 1.")
344
- raise
345
-
346
- def prepare_data(
347
- self,
348
- input_data_file,
349
- output_directory,
350
- output_prefix,
351
- split_id_dict=None,
352
- test_size=None,
353
- attr_to_split=None,
354
- attr_to_balance=None,
355
- max_trials=100,
356
- pval_threshold=0.1,
357
- ):
358
- """
359
- Prepare data for cell state or gene classification.
360
-
361
- **Parameters**
362
-
363
- input_data_file : Path
364
- | Path to directory containing .dataset input
365
- output_directory : Path
366
- | Path to directory where prepared data will be saved
367
- output_prefix : str
368
- | Prefix for output file
369
- split_id_dict : None, dict
370
- | Dictionary of IDs for train and test splits
371
- | Three-item dictionary with keys: attr_key, train, test
372
- | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
373
- | train: list of IDs in the attr_key column to include in the train split
374
- | test: list of IDs in the attr_key column to include in the test split
375
- | For example: {"attr_key": "individual",
376
- | "train": ["patient1", "patient2", "patient3", "patient4"],
377
- | "test": ["patient5", "patient6"]}
378
- test_size : None, float
379
- | Proportion of data to be saved separately and held out for test set
380
- | (e.g. 0.2 if intending hold out 20%)
381
- | If None, will inherit from split_sizes["test"] from Classifier
382
- | The training set will be further split to train / validation in self.validate
383
- | Note: only available for CellClassifiers
384
- attr_to_split : None, str
385
- | Key for attribute on which to split data while balancing potential confounders
386
- | e.g. "patient_id" for splitting by patient while balancing other characteristics
387
- | Note: only available for CellClassifiers
388
- attr_to_balance : None, list
389
- | List of attribute keys on which to balance data while splitting on attr_to_split
390
- | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
391
- | Note: only available for CellClassifiers
392
- max_trials : None, int
393
- | Maximum number of trials of random splitting to try to achieve balanced other attributes
394
- | If no split is found without significant (p<0.05) differences in other attributes, will select best
395
- | Note: only available for CellClassifiers
396
- pval_threshold : None, float
397
- | P-value threshold to use for attribute balancing across splits
398
- | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
399
- """
400
-
401
- if test_size is None:
402
- test_size = self.oos_test_size
403
-
404
- # prepare data and labels for classification
405
- data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
406
-
407
- if self.classifier == "cell":
408
- if "label" in data.features:
409
- logger.error(
410
- "Column name 'label' must be reserved for class IDs. Please rename column."
411
- )
412
- raise
413
- elif self.classifier == "gene":
414
- if "labels" in data.features:
415
- logger.error(
416
- "Column name 'labels' must be reserved for class IDs. Please rename column."
417
- )
418
- raise
419
-
420
- if (attr_to_split is not None) and (attr_to_balance is None):
421
- logger.error(
422
- "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
423
- )
424
- raise
425
-
426
- if not isinstance(attr_to_balance, list):
427
- attr_to_balance = [attr_to_balance]
428
-
429
- if self.classifier == "cell":
430
- # remove cell states representing < rare_threshold of cells
431
- data = cu.remove_rare(
432
- data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
433
- )
434
- # downsample max cells and max per class
435
- data = cu.downsample_and_shuffle(
436
- data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
437
- )
438
- # rename cell state column to "label"
439
- data = cu.rename_cols(data, self.cell_state_dict["state_key"])
440
-
441
- # convert classes to numerical labels and save as id_class_dict
442
- # of note, will label all genes in gene_class_dict
443
- # if (cross-)validating, genes will be relabeled in column "labels" for each split
444
- # at the time of training with Classifier.validate
445
- data, id_class_dict = cu.label_classes(
446
- self.classifier, data, self.gene_class_dict, self.nproc
447
- )
448
-
449
- # save id_class_dict for future reference
450
- id_class_output_path = (
451
- Path(output_directory) / f"{output_prefix}_id_class_dict"
452
- ).with_suffix(".pkl")
453
- with open(id_class_output_path, "wb") as f:
454
- pickle.dump(id_class_dict, f)
455
-
456
- if split_id_dict is not None:
457
- data_dict = dict()
458
- data_dict["train"] = pu.filter_by_dict(
459
- data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
460
- )
461
- data_dict["test"] = pu.filter_by_dict(
462
- data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
463
- )
464
- train_data_output_path = (
465
- Path(output_directory) / f"{output_prefix}_labeled_train"
466
- ).with_suffix(".dataset")
467
- test_data_output_path = (
468
- Path(output_directory) / f"{output_prefix}_labeled_test"
469
- ).with_suffix(".dataset")
470
- data_dict["train"].save_to_disk(str(train_data_output_path))
471
- data_dict["test"].save_to_disk(str(test_data_output_path))
472
- elif (test_size is not None) and (self.classifier == "cell"):
473
- if 1 > test_size > 0:
474
- if attr_to_split is None:
475
- data_dict = data.train_test_split(
476
- test_size=test_size,
477
- stratify_by_column=self.stratify_splits_col,
478
- seed=42,
479
- )
480
- train_data_output_path = (
481
- Path(output_directory) / f"{output_prefix}_labeled_train"
482
- ).with_suffix(".dataset")
483
- test_data_output_path = (
484
- Path(output_directory) / f"{output_prefix}_labeled_test"
485
- ).with_suffix(".dataset")
486
- data_dict["train"].save_to_disk(str(train_data_output_path))
487
- data_dict["test"].save_to_disk(str(test_data_output_path))
488
- else:
489
- data_dict, balance_df = cu.balance_attr_splits(
490
- data,
491
- attr_to_split,
492
- attr_to_balance,
493
- test_size,
494
- max_trials,
495
- pval_threshold,
496
- self.cell_state_dict["state_key"],
497
- self.nproc,
498
- )
499
- balance_df.to_csv(
500
- f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
501
- )
502
- train_data_output_path = (
503
- Path(output_directory) / f"{output_prefix}_labeled_train"
504
- ).with_suffix(".dataset")
505
- test_data_output_path = (
506
- Path(output_directory) / f"{output_prefix}_labeled_test"
507
- ).with_suffix(".dataset")
508
- data_dict["train"].save_to_disk(str(train_data_output_path))
509
- data_dict["test"].save_to_disk(str(test_data_output_path))
510
- else:
511
- data_output_path = (
512
- Path(output_directory) / f"{output_prefix}_labeled"
513
- ).with_suffix(".dataset")
514
- data.save_to_disk(str(data_output_path))
515
- print(data_output_path)
516
- else:
517
- data_output_path = (
518
- Path(output_directory) / f"{output_prefix}_labeled"
519
- ).with_suffix(".dataset")
520
- data.save_to_disk(str(data_output_path))
521
-
522
- def train_all_data(
523
- self,
524
- model_directory,
525
- prepared_input_data_file,
526
- id_class_dict_file,
527
- output_directory,
528
- output_prefix,
529
- save_eval_output=True,
530
- gene_balance=False,
531
- ):
532
- """
533
- Train cell state or gene classifier using all data.
534
-
535
- **Parameters**
536
-
537
- model_directory : Path
538
- | Path to directory containing model
539
- prepared_input_data_file : Path
540
- | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
541
- id_class_dict_file : Path
542
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
543
- | (dictionary of format: numerical IDs: class_labels)
544
- output_directory : Path
545
- | Path to directory where model and eval data will be saved
546
- output_prefix : str
547
- | Prefix for output files
548
- save_eval_output : bool
549
- | Whether to save cross-fold eval output
550
- | Saves as pickle file of dictionary of eval metrics
551
- gene_balance : None, bool
552
- | Whether to automatically balance genes in training set.
553
- | Only available for binary gene classifications.
554
-
555
- **Output**
556
-
557
- Returns trainer after fine-tuning with all data.
558
-
559
- """
560
-
561
- if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
562
- logger.error(
563
- "Automatically balancing gene sets for training is only available for binary gene classifications."
564
- )
565
- raise
566
-
567
- ##### Load data and prepare output directory #####
568
- # load numerical id to class dictionary (id:class)
569
- with open(id_class_dict_file, "rb") as f:
570
- id_class_dict = pickle.load(f)
571
- class_id_dict = {v: k for k, v in id_class_dict.items()}
572
-
573
- # load previously filtered and prepared data
574
- data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
575
- data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
576
-
577
- # define output directory path
578
- current_date = datetime.datetime.now()
579
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
580
- if output_directory[-1:] != "/": # add slash for dir if not present
581
- output_directory = output_directory + "/"
582
- output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
583
- subprocess.call(f"mkdir {output_dir}", shell=True)
584
-
585
- # get number of classes for classifier
586
- num_classes = cu.get_num_classes(id_class_dict)
587
-
588
- if self.classifier == "gene":
589
- targets = pu.flatten_list(self.gene_class_dict.values())
590
- labels = pu.flatten_list(
591
- [
592
- [class_id_dict[label]] * len(targets)
593
- for label, targets in self.gene_class_dict.items()
594
- ]
595
- )
596
- assert len(targets) == len(labels)
597
- data = cu.prep_gene_classifier_all_data(
598
- data, targets, labels, self.max_ncells, self.nproc, gene_balance
599
- )
600
-
601
- trainer = self.train_classifier(
602
- model_directory, num_classes, data, None, output_dir
603
- )
604
-
605
- return trainer
606
-
607
- def validate(
608
- self,
609
- model_directory,
610
- prepared_input_data_file,
611
- id_class_dict_file,
612
- output_directory,
613
- output_prefix,
614
- split_id_dict=None,
615
- attr_to_split=None,
616
- attr_to_balance=None,
617
- gene_balance=False,
618
- max_trials=100,
619
- pval_threshold=0.1,
620
- save_eval_output=True,
621
- predict_eval=True,
622
- predict_trainer=False,
623
- n_hyperopt_trials=0,
624
- save_gene_split_datasets=True,
625
- debug_gene_split_datasets=False,
626
- ):
627
- """
628
- (Cross-)validate cell state or gene classifier.
629
-
630
- **Parameters**
631
-
632
- model_directory : Path
633
- | Path to directory containing model
634
- prepared_input_data_file : Path
635
- | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
636
- id_class_dict_file : Path
637
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
638
- | (dictionary of format: numerical IDs: class_labels)
639
- output_directory : Path
640
- | Path to directory where model and eval data will be saved
641
- output_prefix : str
642
- | Prefix for output files
643
- split_id_dict : None, dict
644
- | Dictionary of IDs for train and eval splits
645
- | Three-item dictionary with keys: attr_key, train, eval
646
- | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
647
- | train: list of IDs in the attr_key column to include in the train split
648
- | eval: list of IDs in the attr_key column to include in the eval split
649
- | For example: {"attr_key": "individual",
650
- | "train": ["patient1", "patient2", "patient3", "patient4"],
651
- | "eval": ["patient5", "patient6"]}
652
- | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
653
- attr_to_split : None, str
654
- | Key for attribute on which to split data while balancing potential confounders
655
- | e.g. "patient_id" for splitting by patient while balancing other characteristics
656
- | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
657
- attr_to_balance : None, list
658
- | List of attribute keys on which to balance data while splitting on attr_to_split
659
- | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
660
- gene_balance : None, bool
661
- | Whether to automatically balance genes in training set.
662
- | Only available for binary gene classifications.
663
- max_trials : None, int
664
- | Maximum number of trials of random splitting to try to achieve balanced other attributes
665
- | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
666
- pval_threshold : None, float
667
- | P-value threshold to use for attribute balancing across splits
668
- | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
669
- save_eval_output : bool
670
- | Whether to save cross-fold eval output
671
- | Saves as pickle file of dictionary of eval metrics
672
- predict_eval : bool
673
- | Whether or not to save eval predictions
674
- | Saves as a pickle file of self.evaluate predictions
675
- predict_trainer : bool
676
- | Whether or not to save eval predictions from trainer
677
- | Saves as a pickle file of trainer predictions
678
- n_hyperopt_trials : int
679
- | Number of trials to run for hyperparameter optimization
680
- | If 0, will not optimize hyperparameters
681
- save_gene_split_datasets : bool
682
- | Whether or not to save train, valid, and test gene-labeled datasets
683
- """
684
- if self.num_crossval_splits == 0:
685
- logger.error("num_crossval_splits must be 1 or 5 to validate.")
686
- raise
687
-
688
- if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
689
- logger.error(
690
- "Automatically balancing gene sets for training is only available for binary gene classifications."
691
- )
692
- raise
693
-
694
- # ensure number of genes in each class is > 5 if validating model
695
- if self.classifier == "gene":
696
- insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
697
- if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
698
- logger.error(
699
- f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
700
- )
701
- raise
702
-
703
- ##### Load data and prepare output directory #####
704
- # load numerical id to class dictionary (id:class)
705
- with open(id_class_dict_file, "rb") as f:
706
- id_class_dict = pickle.load(f)
707
- class_id_dict = {v: k for k, v in id_class_dict.items()}
708
-
709
- # load previously filtered and prepared data
710
- data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
711
- data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
712
-
713
- # define output directory path
714
- current_date = datetime.datetime.now()
715
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
716
- if output_directory[-1:] != "/": # add slash for dir if not present
717
- output_directory = output_directory + "/"
718
- output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
719
- subprocess.call(f"mkdir {output_dir}", shell=True)
720
-
721
- # get number of classes for classifier
722
- num_classes = cu.get_num_classes(id_class_dict)
723
-
724
- ##### (Cross-)validate the model #####
725
- results = []
726
- all_conf_mat = np.zeros((num_classes, num_classes))
727
- iteration_num = 1
728
- if self.classifier == "cell":
729
- for i in trange(self.num_crossval_splits):
730
- print(
731
- f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
732
- )
733
- ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
734
- if self.num_crossval_splits == 1:
735
- # single split with proportions (1 - eval_size) : eval_size
736
- if split_id_dict is not None:
737
- data_dict = dict()
738
- data_dict["train"] = pu.filter_by_dict(
739
- data,
740
- {split_id_dict["attr_key"]: split_id_dict["train"]},
741
- self.nproc,
742
- )
743
- data_dict["test"] = pu.filter_by_dict(
744
- data,
745
- {split_id_dict["attr_key"]: split_id_dict["eval"]},
746
- self.nproc,
747
- )
748
- elif attr_to_split is not None:
749
- data_dict, balance_df = cu.balance_attr_splits(
750
- data,
751
- attr_to_split,
752
- attr_to_balance,
753
- self.eval_size,
754
- max_trials,
755
- pval_threshold,
756
- self.cell_state_dict["state_key"],
757
- self.nproc,
758
- )
759
-
760
- balance_df.to_csv(
761
- f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
762
- )
763
- else:
764
- data_dict = data.train_test_split(
765
- test_size=self.eval_size,
766
- stratify_by_column=self.stratify_splits_col,
767
- seed=42,
768
- )
769
- train_data = data_dict["train"]
770
- eval_data = data_dict["test"]
771
- else:
772
- # 5-fold cross-validate
773
- num_cells = len(data)
774
- fifth_cells = int(np.floor(num_cells * 0.2))
775
- num_eval = int(min((self.eval_size * num_cells), fifth_cells))  # int so range() below accepts it
776
- start = i * fifth_cells
777
- end = start + num_eval
778
- eval_indices = [j for j in range(start, end)]
779
- train_indices = [
780
- j for j in range(num_cells) if j not in eval_indices
781
- ]
782
- eval_data = data.select(eval_indices)
783
- train_data = data.select(train_indices)
784
- if n_hyperopt_trials == 0:
785
- trainer = self.train_classifier(
786
- model_directory,
787
- num_classes,
788
- train_data,
789
- eval_data,
790
- ksplit_output_dir,
791
- predict_trainer,
792
- )
793
- else:
794
- trainer = self.hyperopt_classifier(
795
- model_directory,
796
- num_classes,
797
- train_data,
798
- eval_data,
799
- ksplit_output_dir,
800
- n_trials=n_hyperopt_trials,
801
- )
802
- if iteration_num == self.num_crossval_splits:
803
- return
804
- else:
805
- iteration_num = iteration_num + 1
806
- continue
807
-
808
- result = self.evaluate_model(
809
- trainer.model,
810
- num_classes,
811
- id_class_dict,
812
- eval_data,
813
- predict_eval,
814
- ksplit_output_dir,
815
- output_prefix,
816
- )
817
- results += [result]
818
- all_conf_mat = all_conf_mat + result["conf_mat"]
819
- iteration_num = iteration_num + 1
820
-
821
- elif self.classifier == "gene":
822
- # set up (cross-)validation splits
823
- targets = pu.flatten_list(self.gene_class_dict.values())
824
- labels = pu.flatten_list(
825
- [
826
- [class_id_dict[label]] * len(targets)
827
- for label, targets in self.gene_class_dict.items()
828
- ]
829
- )
830
- assert len(targets) == len(labels)
831
- n_splits = int(1 / (1 - self.train_size))
832
- skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
833
- # (Cross-)validate
834
- test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
835
- for train_index, eval_index, test_index in tqdm(
836
- skf.split(targets, labels, test_ratio)
837
- ):
838
- print(
839
- f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
840
- )
841
- ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
842
- # filter data for examples containing classes for this split
843
- # subsample to max_ncells and relabel data in column "labels"
844
- train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
845
- data,
846
- targets,
847
- labels,
848
- train_index,
849
- eval_index,
850
- self.max_ncells,
851
- iteration_num,
852
- self.nproc,
853
- gene_balance,
854
- )
855
-
856
- if save_gene_split_datasets is True:
857
- for split_name in ["train", "valid"]:
858
- labeled_dataset_output_path = (
859
- Path(output_dir)
860
- / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
861
- ).with_suffix(".dataset")
862
- if split_name == "train":
863
- train_data.save_to_disk(str(labeled_dataset_output_path))
864
- elif split_name == "valid":
865
- eval_data.save_to_disk(str(labeled_dataset_output_path))
866
-
867
- if self.oos_test_size > 0:
868
- test_data = cu.prep_gene_classifier_split(
869
- data,
870
- targets,
871
- labels,
872
- test_index,
873
- "test",
874
- self.max_ncells,
875
- iteration_num,
876
- self.nproc,
877
- )
878
- if save_gene_split_datasets is True:
879
- test_labeled_dataset_output_path = (
880
- Path(output_dir)
881
- / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
882
- ).with_suffix(".dataset")
883
- test_data.save_to_disk(str(test_labeled_dataset_output_path))
884
- if debug_gene_split_datasets is True:
885
- logger.error(
886
- "Exiting after saving gene split datasets given debug_gene_split_datasets = True."
887
- )
888
- raise
889
- if n_hyperopt_trials == 0:
890
- trainer = self.train_classifier(
891
- model_directory,
892
- num_classes,
893
- train_data,
894
- eval_data,
895
- ksplit_output_dir,
896
- predict_trainer,
897
- )
898
- result = self.evaluate_model(
899
- trainer.model,
900
- num_classes,
901
- id_class_dict,
902
- eval_data,
903
- predict_eval,
904
- ksplit_output_dir,
905
- output_prefix,
906
- )
907
- else:
908
- trainer = self.hyperopt_classifier(
909
- model_directory,
910
- num_classes,
911
- train_data,
912
- eval_data,
913
- ksplit_output_dir,
914
- n_trials=n_hyperopt_trials,
915
- )
916
-
917
- model = cu.load_best_model(
918
- ksplit_output_dir, self.model_type, num_classes
919
- )
920
-
921
- if self.oos_test_size > 0:
922
- result = self.evaluate_model(
923
- model,
924
- num_classes,
925
- id_class_dict,
926
- test_data,
927
- predict_eval,
928
- ksplit_output_dir,
929
- output_prefix,
930
- )
931
- else:
932
- if iteration_num == self.num_crossval_splits:
933
- return
934
- else:
935
- iteration_num = iteration_num + 1
936
- continue
937
- results += [result]
938
- all_conf_mat = all_conf_mat + result["conf_mat"]
939
- # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
940
- if iteration_num == self.num_crossval_splits:
941
- break
942
- iteration_num = iteration_num + 1
943
-
944
- all_conf_mat_df = pd.DataFrame(
945
- all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
946
- )
947
- all_metrics = {
948
- "conf_matrix": all_conf_mat_df,
949
- "macro_f1": [result["macro_f1"] for result in results],
950
- "acc": [result["acc"] for result in results],
951
- }
952
- all_roc_metrics = None # roc metrics not reported for multiclass
953
- if num_classes == 2:
954
- mean_fpr = np.linspace(0, 1, 100)
955
- all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
956
- all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
957
- all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
958
- mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
959
- all_tpr, all_roc_auc, all_tpr_wt
960
- )
961
- all_roc_metrics = {
962
- "mean_tpr": mean_tpr,
963
- "mean_fpr": mean_fpr,
964
- "all_roc_auc": all_roc_auc,
965
- "roc_auc": roc_auc,
966
- "roc_auc_sd": roc_auc_sd,
967
- }
968
- all_metrics["all_roc_metrics"] = all_roc_metrics
969
- if save_eval_output is True:
970
- eval_metrics_output_path = (
971
- Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
972
- ).with_suffix(".pkl")
973
- with open(eval_metrics_output_path, "wb") as f:
974
- pickle.dump(all_metrics, f)
975
-
976
- return all_metrics
977
-
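For orientation, a minimal usage sketch of the validate() method above, assuming a Classifier instance cc configured with classifier="cell" and num_crossval_splits=1; all paths below are hypothetical placeholders.

    split_id_dict = {
        "attr_key": "individual",
        "train": ["patient1", "patient2", "patient3", "patient4"],
        "eval": ["patient5", "patient6"],
    }
    all_metrics = cc.validate(
        model_directory="/path/to/pretrained_model",
        prepared_input_data_file="/path/to/example_labeled.dataset",
        id_class_dict_file="/path/to/example_id_class_dict.pkl",
        output_directory="/path/to/output",
        output_prefix="example",
        split_id_dict=split_id_dict,  # only valid with num_crossval_splits=1
    )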
978
- def hyperopt_classifier(
979
- self,
980
- model_directory,
981
- num_classes,
982
- train_data,
983
- eval_data,
984
- output_directory,
985
- n_trials=100,
986
- ):
987
- """
988
- Optimize hyperparameters for fine-tuning a cell state or gene classification model.
989
-
990
- **Parameters**
991
-
992
- model_directory : Path
993
- | Path to directory containing model
994
- num_classes : int
995
- | Number of classes for classifier
996
- train_data : Dataset
997
- | Loaded training .dataset input
998
- | For cell classifier, labels in column "label".
999
- | For gene classifier, labels in column "labels".
1000
- eval_data : None, Dataset
1001
- | (Optional) Loaded evaluation .dataset input
1002
- | For cell classifier, labels in column "label".
1003
- | For gene classifier, labels in column "labels".
1004
- output_directory : Path
1005
- | Path to directory where fine-tuned model will be saved
1006
- n_trials : int
1007
- | Number of trials to run for hyperparameter optimization
1008
- """
1009
-
1010
- # initiate runtime environment for raytune
1011
- import ray
1012
- from ray import tune
1013
- from ray.tune.search.hyperopt import HyperOptSearch
1014
-
1015
- ray.shutdown() # engage new ray session
1016
- ray.init()
1017
-
1018
- ##### Validate and prepare data #####
1019
- train_data, eval_data = cu.validate_and_clean_cols(
1020
- train_data, eval_data, self.classifier
1021
- )
1022
-
1023
- if (self.no_eval is True) and (eval_data is not None):
1024
- logger.warning(
1025
- "no_eval is set to True, but hyperparameter optimization requires an eval set; proceeding with evaluation."
1026
- )
1027
-
1028
- # ensure not overwriting previously saved model
1029
- saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
1030
- if os.path.isfile(saved_model_test) is True:
1031
- logger.error("Model already saved to this designated output directory.")
1032
- raise
1033
- # make output directory
1034
- subprocess.call(f"mkdir {output_directory}", shell=True)
1035
-
1036
- ##### Load model and training args #####
1037
- model = pu.load_model(
1038
- self.model_type,
1039
- num_classes,
1040
- model_directory,
1041
- "train",
1042
- quantize=self.quantize,
1043
- )
1044
- def_training_args, def_freeze_layers = cu.get_default_train_args(
1045
- model, self.classifier, train_data, output_directory
1046
- )
1047
- del model
1048
-
1049
- if self.training_args is not None:
1050
- def_training_args.update(self.training_args)
1051
- logging_steps = round(
1052
- len(train_data) / def_training_args["per_device_train_batch_size"] / 10
1053
- )
1054
- def_training_args["logging_steps"] = logging_steps
1055
- def_training_args["output_dir"] = output_directory
1056
- if eval_data is None:
1057
- def_training_args["evaluation_strategy"] = "no"
1058
- def_training_args["load_best_model_at_end"] = False
1059
- def_training_args.update(
1060
- {"save_strategy": "epoch", "save_total_limit": 1}
1061
- ) # only save last model for each run
1062
- training_args_init = TrainingArguments(**def_training_args)
1063
-
1064
- ##### Fine-tune the model #####
1065
- # define the data collator
1066
- if self.classifier == "cell":
1067
- data_collator = DataCollatorForCellClassification(
1068
- token_dictionary=self.gene_token_dict
1069
- )
1070
- elif self.classifier == "gene":
1071
- data_collator = DataCollatorForGeneClassification(
1072
- token_dictionary=self.gene_token_dict
1073
- )
1074
-
1075
- # define function to initiate model
1076
- def model_init():
1077
- model = pu.load_model(
1078
- self.model_type,
1079
- num_classes,
1080
- model_directory,
1081
- "train",
1082
- quantize=self.quantize,
1083
- )
1084
-
1085
- # bind a local name first: assigning def_freeze_layers inside model_init
- # would otherwise make it local and raise UnboundLocalError when
- # self.freeze_layers is None
- freeze_layers = def_freeze_layers
- if self.freeze_layers is not None:
- freeze_layers = self.freeze_layers
-
- if freeze_layers > 0:
- modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
1090
- for module in modules_to_freeze:
1091
- for param in module.parameters():
1092
- param.requires_grad = False
1093
-
1094
- if self.quantize is False:
1095
- model = model.to("cuda:0")
1096
- return model
1097
-
1098
- # create the trainer
1099
- trainer = Trainer(
1100
- model_init=model_init,
1101
- args=training_args_init,
1102
- data_collator=data_collator,
1103
- train_dataset=train_data,
1104
- eval_dataset=eval_data,
1105
- compute_metrics=cu.compute_metrics,
1106
- )
1107
-
1108
- # specify raytune hyperparameter search space
1109
- if self.ray_config is None:
1110
- logger.warning(
1111
- "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
1112
- )
1113
- def_ray_config = {
1114
- "num_train_epochs": tune.choice([1]),
1115
- "learning_rate": tune.loguniform(1e-6, 1e-3),
1116
- "weight_decay": tune.uniform(0.0, 0.3),
1117
- "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
1118
- "warmup_steps": tune.uniform(100, 2000),
1119
- "seed": tune.uniform(0, 100),
1120
- "per_device_train_batch_size": tune.choice(
1121
- [def_training_args["per_device_train_batch_size"]]
1122
- ),
1123
- }
1124
-
1125
- hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
1126
-
1127
- # optimize hyperparameters
1128
- trainer.hyperparameter_search(
1129
- direction="maximize",
1130
- backend="ray",
1131
- resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
1132
- hp_space=lambda _: def_ray_config
1133
- if self.ray_config is None
1134
- else self.ray_config,
1135
- search_alg=hyperopt_search,
1136
- n_trials=n_trials, # number of trials
1137
- progress_reporter=tune.CLIReporter(
1138
- max_report_frequency=600,
1139
- sort_by_metric=True,
1140
- max_progress_rows=n_trials,
1141
- mode="max",
1142
- metric="eval_macro_f1",
1143
- metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1144
- ),
1145
- storage_path=output_directory,
1146
- )
1147
-
1148
- return trainer
1149
-
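A sketch of a user-supplied ray_config mirroring the default search space above; this would be passed at Classifier construction (a hedged example, assuming ray[tune] and hyperopt are installed).

    from ray import tune

    ray_config = {
        "num_train_epochs": tune.choice([1]),
        "learning_rate": tune.loguniform(1e-6, 1e-3),
        "weight_decay": tune.uniform(0.0, 0.3),
        "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
        "warmup_steps": tune.uniform(100, 2000),
        "seed": tune.uniform(0, 100),
        "per_device_train_batch_size": tune.choice([12]),  # default batch size
    }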
1150
- def train_classifier(
1151
- self,
1152
- model_directory,
1153
- num_classes,
1154
- train_data,
1155
- eval_data,
1156
- output_directory,
1157
- predict=False,
1158
- ):
1159
- """
1160
- Fine-tune model for cell state or gene classification.
1161
-
1162
- **Parameters**
1163
-
1164
- model_directory : Path
1165
- | Path to directory containing model
1166
- num_classes : int
1167
- | Number of classes for classifier
1168
- train_data : Dataset
1169
- | Loaded training .dataset input
1170
- | For cell classifier, labels in column "label".
1171
- | For gene classifier, labels in column "labels".
1172
- eval_data : None, Dataset
1173
- | (Optional) Loaded evaluation .dataset input
1174
- | For cell classifier, labels in column "label".
1175
- | For gene classifier, labels in column "labels".
1176
- output_directory : Path
1177
- | Path to directory where fine-tuned model will be saved
1178
- predict : bool
1179
- | Whether or not to save eval predictions from trainer
1180
- """
1181
-
1182
- ##### Validate and prepare data #####
1183
- train_data, eval_data = cu.validate_and_clean_cols(
1184
- train_data, eval_data, self.classifier
1185
- )
1186
-
1187
- if (self.no_eval is True) and (eval_data is not None):
1188
- logger.warning(
1189
- "no_eval set to True; model will be trained without evaluation."
1190
- )
1191
- eval_data = None
1192
-
1193
- if (self.classifier == "gene") and (predict is True):
1194
- logger.warning(
1195
- "Predictions during training not currently available for gene classifiers; setting predict to False."
1196
- )
1197
- predict = False
1198
-
1199
- # ensure not overwriting previously saved model
1200
- saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
1201
- if os.path.isfile(saved_model_test) is True:
1202
- logger.error("Model already saved to this designated output directory.")
1203
- raise
1204
- # make output directory
1205
- subprocess.call(f"mkdir {output_directory}", shell=True)
1206
-
1207
- ##### Load model and training args #####
1208
- model = pu.load_model(
1209
- self.model_type,
1210
- num_classes,
1211
- model_directory,
1212
- "train",
1213
- quantize=self.quantize,
1214
- )
1215
-
1216
- def_training_args, def_freeze_layers = cu.get_default_train_args(
1217
- model, self.classifier, train_data, output_directory
1218
- )
1219
-
1220
- if self.training_args is not None:
1221
- def_training_args.update(self.training_args)
1222
- logging_steps = round(
1223
- len(train_data) / def_training_args["per_device_train_batch_size"] / 10
1224
- )
1225
- def_training_args["logging_steps"] = logging_steps
1226
- def_training_args["output_dir"] = output_directory
1227
- if eval_data is None:
1228
- def_training_args["evaluation_strategy"] = "no"
1229
- def_training_args["load_best_model_at_end"] = False
1230
- training_args_init = TrainingArguments(**def_training_args)
1231
-
1232
- if self.freeze_layers is not None:
1233
- def_freeze_layers = self.freeze_layers
1234
-
1235
- if def_freeze_layers > 0:
1236
- modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
1237
- for module in modules_to_freeze:
1238
- for param in module.parameters():
1239
- param.requires_grad = False
1240
-
1241
- ##### Fine-tune the model #####
1242
- # define the data collator
1243
- if self.classifier == "cell":
1244
- data_collator = DataCollatorForCellClassification(
1245
- token_dictionary=self.gene_token_dict
1246
- )
1247
- elif self.classifier == "gene":
1248
- data_collator = DataCollatorForGeneClassification(
1249
- token_dictionary=self.gene_token_dict
1250
- )
1251
-
1252
- # create the trainer
1253
- trainer = Trainer(
1254
- model=model,
1255
- args=training_args_init,
1256
- data_collator=data_collator,
1257
- train_dataset=train_data,
1258
- eval_dataset=eval_data,
1259
- compute_metrics=cu.compute_metrics,
1260
- )
1261
-
1262
- # train the classifier
1263
- trainer.train()
1264
- trainer.save_model(output_directory)
1265
- if predict is True:
1266
- # make eval predictions and save predictions and metrics
1267
- predictions = trainer.predict(eval_data)
1268
- prediction_output_path = f"{output_directory}/predictions.pkl"
1269
- with open(prediction_output_path, "wb") as f:
1270
- pickle.dump(predictions, f)
1271
- trainer.save_metrics("eval", predictions.metrics)
1272
- return trainer
1273
-
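Custom training arguments supplied at Classifier construction override the defaults via def_training_args.update(self.training_args) above; a hedged sketch, with keys following Hugging Face TrainingArguments and placeholder (untuned) values mirroring the defaults:

    training_args = {
        "num_train_epochs": 1,
        "learning_rate": 5e-5,
        "lr_scheduler_type": "linear",
        "warmup_steps": 500,
        "weight_decay": 0.001,
        "per_device_train_batch_size": 12,
        "seed": 42,
    }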
1274
- def evaluate_model(
1275
- self,
1276
- model,
1277
- num_classes,
1278
- id_class_dict,
1279
- eval_data,
1280
- predict=False,
1281
- output_directory=None,
1282
- output_prefix=None,
1283
- ):
1284
- """
1285
- Evaluate the fine-tuned model.
1286
-
1287
- **Parameters**
1288
-
1289
- model : nn.Module
1290
- | Loaded fine-tuned model (e.g. trainer.model)
1291
- num_classes : int
1292
- | Number of classes for classifier
1293
- id_class_dict : dict
1294
- | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
1295
- | (dictionary of format: numerical IDs: class_labels)
1296
- eval_data : Dataset
1297
- | Loaded evaluation .dataset input
1298
- predict : bool
1299
- | Whether or not to save eval predictions
1300
- output_directory : Path
1301
- | Path to directory where eval data will be saved
1302
- output_prefix : str
1303
- | Prefix for output files
1304
- """
1305
-
1306
- ##### Evaluate the model #####
1307
- labels = id_class_dict.keys()
1308
- y_pred, y_true, logits_list = eu.classifier_predict(
1309
- model, self.classifier, eval_data, self.forward_batch_size
1310
- )
1311
- conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1312
- y_pred, y_true, logits_list, num_classes, labels
1313
- )
1314
- if predict is True:
1315
- pred_dict = {
1316
- "pred_ids": y_pred,
1317
- "label_ids": y_true,
1318
- "predictions": logits_list,
1319
- }
1320
- pred_dict_output_path = (
1321
- Path(output_directory) / f"{output_prefix}_pred_dict"
1322
- ).with_suffix(".pkl")
1323
- with open(pred_dict_output_path, "wb") as f:
1324
- pickle.dump(pred_dict, f)
1325
- return {
1326
- "conf_mat": conf_mat,
1327
- "macro_f1": macro_f1,
1328
- "acc": acc,
1329
- "roc_metrics": roc_metrics,
1330
- }
1331
-
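The pred_dict pickle written above (when predict=True) can be reloaded as a plain dictionary; a minimal sketch with a hypothetical path:

    import pickle

    with open("/path/to/output/example_pred_dict.pkl", "rb") as f:
        pred_dict = pickle.load(f)
    y_pred = pred_dict["pred_ids"]      # predicted class IDs
    y_true = pred_dict["label_ids"]     # true class IDs
    logits = pred_dict["predictions"]   # per-class logits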
1332
- def evaluate_saved_model(
1333
- self,
1334
- model_directory,
1335
- id_class_dict_file,
1336
- test_data_file,
1337
- output_directory,
1338
- output_prefix,
1339
- predict=True,
1340
- ):
1341
- """
1342
- Evaluate the fine-tuned model.
1343
-
1344
- **Parameters**
1345
-
1346
- model_directory : Path
1347
- | Path to directory containing model
1348
- id_class_dict_file : Path
1349
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
1350
- | (dictionary of format: numerical IDs: class_labels)
1351
- test_data_file : Path
1352
- | Path to directory containing test .dataset
1353
- output_directory : Path
1354
- | Path to directory where eval data will be saved
1355
- output_prefix : str
1356
- | Prefix for output files
1357
- predict : bool
1358
- | Whether or not to save eval predictions
1359
- """
1360
-
1361
- # load numerical id to class dictionary (id:class)
1362
- with open(id_class_dict_file, "rb") as f:
1363
- id_class_dict = pickle.load(f)
1364
-
1365
- # get number of classes for classifier
1366
- num_classes = cu.get_num_classes(id_class_dict)
1367
-
1368
- # load previously filtered and prepared data
1369
- test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1370
-
1371
- # load previously fine-tuned model
1372
- model = pu.load_model(
1373
- self.model_type,
1374
- num_classes,
1375
- model_directory,
1376
- "eval",
1377
- quantize=self.quantize,
1378
- )
1379
-
1380
- # evaluate the model
1381
- result = self.evaluate_model(
1382
- model,
1383
- num_classes,
1384
- id_class_dict,
1385
- test_data,
1386
- predict=predict,
1387
- output_directory=output_directory,
1388
- output_prefix=output_prefix,
1389
- )
1390
-
1391
- all_conf_mat_df = pd.DataFrame(
1392
- result["conf_mat"],
1393
- columns=id_class_dict.values(),
1394
- index=id_class_dict.values(),
1395
- )
1396
- all_metrics = {
1397
- "conf_matrix": all_conf_mat_df,
1398
- "macro_f1": result["macro_f1"],
1399
- "acc": result["acc"],
1400
- }
1401
- all_roc_metrics = None # roc metrics not reported for multiclass
1402
-
1403
- if num_classes == 2:
1404
- mean_fpr = np.linspace(0, 1, 100)
1405
- mean_tpr = result["roc_metrics"]["interp_tpr"]
1406
- all_roc_auc = result["roc_metrics"]["auc"]
1407
- all_roc_metrics = {
1408
- "mean_tpr": mean_tpr,
1409
- "mean_fpr": mean_fpr,
1410
- "all_roc_auc": all_roc_auc,
1411
- }
1412
- all_metrics["all_roc_metrics"] = all_roc_metrics
1413
- test_metrics_output_path = (
1414
- Path(output_directory) / f"{output_prefix}_test_metrics_dict"
1415
- ).with_suffix(".pkl")
1416
- with open(test_metrics_output_path, "wb") as f:
1417
- pickle.dump(all_metrics, f)
1418
-
1419
- return all_metrics
1420
-
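A corresponding usage sketch for evaluate_saved_model (paths hypothetical; the id_class_dict pickle and test .dataset come from an earlier Classifier.prepare_data run):

    all_metrics = cc.evaluate_saved_model(
        model_directory="/path/to/fine_tuned_model",
        id_class_dict_file="/path/to/example_id_class_dict.pkl",
        test_data_file="/path/to/example_labeled_test.dataset",
        output_directory="/path/to/output",
        output_prefix="example",
    )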
1421
- def plot_conf_mat(
1422
- self,
1423
- conf_mat_dict,
1424
- output_directory,
1425
- output_prefix,
1426
- custom_class_order=None,
1427
- ):
1428
- """
1429
- Plot confusion matrix results of evaluating the fine-tuned model.
1430
-
1431
- **Parameters**
1432
-
1433
- conf_mat_dict : dict
1434
- | Dictionary of model_name : confusion_matrix_DataFrame
1435
- | (all_metrics["conf_matrix"] from self.validate)
1436
- output_directory : Path
1437
- | Path to directory where plots will be saved
1438
- output_prefix : str
1439
- | Prefix for output file
1440
- custom_class_order : None, list
1441
- | List of classes in custom order for plots.
1442
- | Same order will be used for all models.
1443
- """
1444
-
1445
- for model_name in conf_mat_dict.keys():
1446
- eu.plot_confusion_matrix(
1447
- conf_mat_dict[model_name],
1448
- model_name,
1449
- output_directory,
1450
- output_prefix,
1451
- custom_class_order,
1452
- )
1453
-
1454
- def plot_roc(
1455
- self,
1456
- roc_metric_dict,
1457
- model_style_dict,
1458
- title,
1459
- output_directory,
1460
- output_prefix,
1461
- ):
1462
- """
1463
- Plot ROC curve results of evaluating the fine-tuned model.
1464
-
1465
- **Parameters**
1466
-
1467
- roc_metric_dict : dict
1468
- | Dictionary of model_name : roc_metrics
1469
- | (all_metrics["all_roc_metrics"] from self.validate)
1470
- model_style_dict : dict[dict]
1471
- | Dictionary of model_name : dictionary of style_attribute : style
1472
- | where style includes color and linestyle
1473
- | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
1474
- title : str
1475
- | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
1476
- output_directory : Path
1477
- | Path to directory where plots will be saved
1478
- output_prefix : str
1479
- | Prefix for output file
1480
- """
1481
-
1482
- eu.plot_ROC(
1483
- roc_metric_dict, model_style_dict, title, output_directory, output_prefix
1484
- )
1485
-
1486
- def plot_predictions(
1487
- self,
1488
- predictions_file,
1489
- id_class_dict_file,
1490
- title,
1491
- output_directory,
1492
- output_prefix,
1493
- custom_class_order=None,
1494
- kwargs_dict=None,
1495
- ):
1496
- """
1497
- Plot prediction results of evaluating the fine-tuned model.
1498
-
1499
- **Parameters**
1500
-
1501
- predictions_file : path
1502
- | Path of model predictions output to plot
1503
- | (saved output from self.validate if predict_eval=True)
1504
- | (or saved output from self.evaluate_saved_model)
1505
- id_class_dict_file : Path
1506
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
1507
- | (dictionary of format: numerical IDs: class_labels)
1508
- title : str
1509
- | Title for legend containing class labels.
1510
- output_directory : Path
1511
- | Path to directory where plots will be saved
1512
- output_prefix : str
1513
- | Prefix for output file
1514
- custom_class_order : None, list
1515
- | List of classes in custom order for plots.
1516
- | Same order will be used for all models.
1517
- kwargs_dict : None, dict
1518
- | Dictionary of kwargs to pass to plotting function.
1519
- """
1520
- # load predictions
1521
- with open(predictions_file, "rb") as f:
1522
- predictions = pickle.load(f)
1523
-
1524
- # load numerical id to class dictionary (id:class)
1525
- with open(id_class_dict_file, "rb") as f:
1526
- id_class_dict = pickle.load(f)
1527
-
1528
- if isinstance(predictions, dict):
1529
- if all(
1530
- [
1531
- key in predictions.keys()
1532
- for key in ["pred_ids", "label_ids", "predictions"]
1533
- ]
1534
- ):
1535
- # format is output from self.evaluate_saved_model
1536
- predictions_logits = np.array(predictions["predictions"])
1537
- true_ids = predictions["label_ids"]
1538
- else:
1539
- # format is output from self.validate if predict_eval=True
1540
- predictions_logits = predictions.predictions
1541
- true_ids = predictions.label_ids
1542
-
1543
- num_classes = len(id_class_dict.keys())
1544
- num_predict_classes = predictions_logits.shape[1]
1545
- assert num_classes == num_predict_classes
1546
- classes = id_class_dict.values()
1547
- true_labels = [id_class_dict[idx] for idx in true_ids]
1548
- predictions_df = pd.DataFrame(predictions_logits, columns=classes)
1549
- if custom_class_order is not None:
1550
- predictions_df = predictions_df.reindex(columns=custom_class_order)
1551
- predictions_df["true"] = true_labels
1552
- custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
1553
- if custom_class_order is not None:
1554
- custom_dict = dict(
1555
- zip(custom_class_order, [i for i in range(len(custom_class_order))])
1556
- )
1557
- predictions_df = predictions_df.sort_values(
1558
- by=["true"], key=lambda x: x.map(custom_dict)
1559
- )
1560
-
1561
- eu.plot_predictions(
1562
- predictions_df, title, output_directory, output_prefix, kwargs_dict
1563
- )
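Tying the plotting helpers above to saved validate() output; a hedged sketch assuming the eval metrics pickle written by validate with save_eval_output=True (paths hypothetical):

    import pickle

    with open("/path/to/output/example_eval_metrics_dict.pkl", "rb") as f:
        all_metrics = pickle.load(f)
    cc.plot_conf_mat(
        conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
        output_directory="/path/to/output",
        output_prefix="example",
    )
    # ROC plotting applies only to binary classifiers
    # (all_roc_metrics is None for multiclass)
    cc.plot_roc(
        roc_metric_dict={"Geneformer": all_metrics["all_roc_metrics"]},
        model_style_dict={"Geneformer": {"color": "red", "linestyle": "-"}},
        title="Dosage-sensitive vs -insensitive factors",
        output_directory="/path/to/output",
        output_prefix="example",
    )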
 
 
geneformer/classifier_utils.py DELETED
@@ -1,648 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import random
5
- from collections import Counter, defaultdict
6
-
7
- import numpy as np
8
- import pandas as pd
9
- from scipy.stats import chisquare, ranksums
10
- from sklearn.metrics import accuracy_score, f1_score
11
- from sklearn.model_selection import StratifiedKFold, train_test_split
12
-
13
- from . import perturber_utils as pu
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
19
- data = data.shuffle(seed=42)
20
- num_cells = len(data)
21
- # if max number of cells is defined, then subsample to this max number
22
- if max_ncells is not None:
23
- if num_cells > max_ncells:
24
- data = data.select([i for i in range(max_ncells)])
25
- if max_ncells_per_class is not None:
26
- class_labels = data[cell_state_dict["state_key"]]
27
- random.seed(42)
28
- subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
29
- data = data.select(subsample_indices)
30
- return data
31
-
32
-
33
- # subsample labels to maximum number N per class and return indices
34
- def subsample_by_class(labels, N):
35
- label_indices = defaultdict(list)
36
- # Gather indices for each label
37
- for idx, label in enumerate(labels):
38
- label_indices[label].append(idx)
39
- selected_indices = []
40
- # Select up to N indices for each label
41
- for label, indices in label_indices.items():
42
- if len(indices) > N:
43
- selected_indices.extend(random.sample(indices, N))
44
- else:
45
- selected_indices.extend(indices)
46
- return selected_indices
47
-
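A worked toy example of subsample_by_class above (assuming the function is in scope); with N=2, each class is capped at two indices while smaller classes are kept whole:

    labels = ["B", "B", "B", "T", "T", "NK"]
    idx = subsample_by_class(labels, 2)  # 2 of the "B"s, both "T"s, the one "NK"
    assert len(idx) == 5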
48
-
49
- def rename_cols(data, state_key):
50
- data = data.rename_column(state_key, "label")
51
- return data
52
-
53
-
54
- def validate_and_clean_cols(train_data, eval_data, classifier):
55
- # validate that data has expected label column and remove others
56
- if classifier == "cell":
57
- label_col = "label"
58
- elif classifier == "gene":
59
- label_col = "labels"
60
-
61
- cols_to_keep = [label_col] + ["input_ids", "length"]
62
- if label_col not in train_data.column_names:
63
- logger.error(f"train_data must contain column {label_col} with class labels.")
64
- raise
65
- else:
66
- train_data = remove_cols(train_data, cols_to_keep)
67
-
68
- if eval_data is not None:
69
- if label_col not in eval_data.column_names:
70
- logger.error(
71
- f"eval_data must contain column {label_col} with class labels."
72
- )
73
- raise
74
- else:
75
- eval_data = remove_cols(eval_data, cols_to_keep)
76
- return train_data, eval_data
77
-
78
-
79
- def remove_cols(data, cols_to_keep):
80
- other_cols = list(data.features.keys())
81
- other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
82
- data = data.remove_columns(other_cols)
83
- return data
84
-
85
-
86
- def remove_rare(data, rare_threshold, label, nproc):
87
- if rare_threshold > 0:
88
- total_cells = len(data)
89
- label_counter = Counter(data[label])
90
- nonrare_label_dict = {
91
- label: [k for k, v in label_counter.items() if (v / total_cells) > rare_threshold]
92
- }
93
- data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
94
- return data
95
-
96
-
97
- def label_classes(classifier, data, gene_class_dict, nproc):
98
- if classifier == "cell":
99
- label_set = set(data["label"])
100
- elif classifier == "gene":
101
- # remove cells without any of the target genes
102
- def if_contains_label(example):
103
- a = pu.flatten_list(gene_class_dict.values())
104
- b = example["input_ids"]
105
- return not set(a).isdisjoint(b)
106
-
107
- data = data.filter(if_contains_label, num_proc=nproc)
108
- label_set = gene_class_dict.keys()
109
-
110
- if len(data) == 0:
111
- logger.error(
112
- "No cells remain after filtering for target genes. Check target gene list."
113
- )
114
- raise
115
-
116
- class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
117
- id_class_dict = {v: k for k, v in class_id_dict.items()}
118
-
119
- def classes_to_ids(example):
120
- if classifier == "cell":
121
- example["label"] = class_id_dict[example["label"]]
122
- elif classifier == "gene":
123
- example["labels"] = label_gene_classes(
124
- example, class_id_dict, gene_class_dict
125
- )
126
- return example
127
-
128
- data = data.map(classes_to_ids, num_proc=nproc)
129
- return data, id_class_dict
130
-
131
-
132
- def label_gene_classes(example, class_id_dict, gene_class_dict):
133
- return [
134
- class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
135
- for token_id in example["input_ids"]
136
- ]
137
-
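An illustration of the token-level labels produced by label_gene_classes above, using hypothetical toy token IDs; here gene_class_dict is assumed to map token_id -> class name (the inverted form of the user-facing class -> genes dictionary):

    gene_class_dict = {101: "dosage_sensitive", 202: "dosage_insensitive"}
    class_id_dict = {"dosage_sensitive": 0, "dosage_insensitive": 1}
    example = {"input_ids": [7, 101, 202, 9]}
    labels = label_gene_classes(example, class_id_dict, gene_class_dict)
    assert labels == [-100, 0, 1, -100]  # non-target tokens are masked with -100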
138
-
139
- def prep_gene_classifier_train_eval_split(
140
- data,
141
- targets,
142
- labels,
143
- train_index,
144
- eval_index,
145
- max_ncells,
146
- iteration_num,
147
- num_proc,
148
- balance=False,
149
- ):
150
- # generate cross-validation splits
151
- train_data = prep_gene_classifier_split(
152
- data,
153
- targets,
154
- labels,
155
- train_index,
156
- "train",
157
- max_ncells,
158
- iteration_num,
159
- num_proc,
160
- balance,
161
- )
162
- eval_data = prep_gene_classifier_split(
163
- data,
164
- targets,
165
- labels,
166
- eval_index,
167
- "eval",
168
- max_ncells,
169
- iteration_num,
170
- num_proc,
171
- balance,
172
- )
173
- return train_data, eval_data
174
-
175
-
176
- def prep_gene_classifier_split(
177
- data,
178
- targets,
179
- labels,
180
- index,
181
- subset_name,
182
- max_ncells,
183
- iteration_num,
184
- num_proc,
185
- balance=False,
186
- ):
187
- # generate cross-validation splits
188
- targets = np.array(targets)
189
- labels = np.array(labels)
190
- targets_subset = targets[index]
191
- labels_subset = labels[index]
192
- label_dict_subset = dict(zip(targets_subset, labels_subset))
193
-
194
- # function to filter by whether contains train or eval labels
195
- def if_contains_subset_label(example):
196
- a = targets_subset
197
- b = example["input_ids"]
198
- return not set(a).isdisjoint(b)
199
-
200
- # filter dataset for examples containing classes for this split
201
- logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
202
- subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
203
- logger.info(
204
- f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
205
- )
206
-
207
- # balance gene subsets if train
208
- if (subset_name == "train") and (balance is True):
209
- subset_data, label_dict_subset = balance_gene_split(
210
- subset_data, label_dict_subset, num_proc
211
- )
212
-
213
- # subsample to max_ncells
214
- subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
215
-
216
- # relabel genes for this split
217
- def subset_classes_to_ids(example):
218
- example["labels"] = [
219
- label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
220
- ]
221
- return example
222
-
223
- subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
224
-
225
- return subset_data
226
-
227
-
228
- def prep_gene_classifier_all_data(
229
- data, targets, labels, max_ncells, num_proc, balance=False
230
- ):
231
- targets = np.array(targets)
232
- labels = np.array(labels)
233
- label_dict_train = dict(zip(targets, labels))
234
-
235
- # function to filter by whether contains train labels
236
- def if_contains_train_label(example):
237
- a = targets
238
- b = example["input_ids"]
239
- return not set(a).isdisjoint(b)
240
-
241
- # filter dataset for examples containing classes for this split
242
- logger.info("Filtering training data for genes to classify.")
243
- train_data = data.filter(if_contains_train_label, num_proc=num_proc)
244
- logger.info(
245
- f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
246
- )
247
-
248
- if balance is True:
249
- train_data, label_dict_train = balance_gene_split(
250
- train_data, label_dict_train, num_proc
251
- )
252
-
253
- # subsample to max_ncells
254
- train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
255
-
256
- # relabel genes for this split
257
- def train_classes_to_ids(example):
258
- example["labels"] = [
259
- label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
260
- ]
261
- return example
262
-
263
- train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
264
-
265
- return train_data
266
-
267
-
268
- def balance_gene_split(subset_data, label_dict_subset, num_proc):
269
- # count occurrence of genes in each label category
270
- label0_counts, label1_counts = count_genes_for_balancing(
271
- subset_data, label_dict_subset, num_proc
272
- )
273
- label_ratio_0to1 = label0_counts / label1_counts
274
-
275
- if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
276
- # gene sets already balanced
277
- logger.info(
278
- "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
279
- )
280
- return subset_data, label_dict_subset
281
- else:
282
- label_ratio_0to1_orig = label_ratio_0to1 + 0
283
- label_dict_subset_orig = label_dict_subset.copy()
284
- # balance gene sets
285
- max_ntrials = 25
286
- boost = 1
287
- if label_ratio_0to1 > 10 / 8:
288
- # downsample label 0
289
- for i in range(max_ntrials):
290
- label0 = 0
291
- label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
292
- label0_ngenes = len(label0_genes)
293
- label0_nremove = max(
294
- 1,
295
- int(
296
- np.floor(
297
- label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
298
- )
299
- ),
300
- )
301
- random.seed(i)
302
- label0_remove_genes = random.sample(label0_genes, label0_nremove)
303
- label_dict_subset_new = {
304
- k: v
305
- for k, v in label_dict_subset.items()
306
- if k not in label0_remove_genes
307
- }
308
- label0_counts, label1_counts = count_genes_for_balancing(
309
- subset_data, label_dict_subset_new, num_proc
310
- )
311
- label_ratio_0to1 = label0_counts / label1_counts
312
- if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
313
- # if gene sets now balanced, return new filtered data and new label_dict_subset
314
- return filter_data_balanced_genes(
315
- subset_data, label_dict_subset_new, num_proc
316
- )
317
- elif label_ratio_0to1 > 10 / 8:
318
- boost = boost * 1.1
319
- elif label_ratio_0to1 < 8 / 10:
320
- boost = boost * 0.9
321
- else:
322
- # downsample label 1
323
- for i in range(max_ntrials):
324
- label1 = 1
325
- label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
326
- label1_ngenes = len(label1_genes)
327
- label1_nremove = max(
328
- 1,
329
- int(
330
- np.floor(
331
- label1_ngenes
332
- - label1_ngenes / ((1 / label_ratio_0to1) * boost)
333
- )
334
- ),
335
- )
336
- random.seed(i)
337
- label1_remove_genes = random.sample(label1_genes, label1_nremove)
338
- label_dict_subset_new = {
339
- k: v
340
- for k, v in label_dict_subset.items()
341
- if k not in label1_remove_genes
342
- }
343
- label0_counts, label1_counts = count_genes_for_balancing(
344
- subset_data, label_dict_subset_new, num_proc
345
- )
346
- label_ratio_0to1 = label0_counts / label1_counts
347
- if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
348
- # if gene sets now balanced, return new filtered data and new label_dict_subset
349
- return filter_data_balanced_genes(
350
- subset_data, label_dict_subset_new, num_proc
351
- )
352
- elif label_ratio_0to1 < 8 / 10:
353
- boost = boost * 1.1
354
- elif label_ratio_0to1 > 10 / 8:
355
- boost = boost * 0.9
356
-
357
- assert i + 1 == max_ntrials
358
- if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
359
- 10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
360
- ):
361
- label_ratio_0to1 = label_ratio_0to1_orig
362
- label_dict_subset_new = label_dict_subset_orig
363
- logger.warning(
364
- f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
365
- )
366
- return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
367
-
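Toy arithmetic for the 0.8-1.25-fold balance rule applied above: a ratio outside [8/10, 10/8] triggers iterative downsampling of the overrepresented label.

    label0_counts, label1_counts = 1000, 400
    label_ratio_0to1 = label0_counts / label1_counts  # 2.5 -> imbalanced
    assert not (8 / 10 <= label_ratio_0to1 <= 10 / 8)
    # the first downsampling trial removes roughly
    # label0_ngenes - label0_ngenes / label_ratio_0to1 of the label-0 genes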
368
-
369
- def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
370
- def count_targets(example):
371
- labels = [
372
- label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
373
- ]
374
- counter_labels = Counter(labels)
375
- # get count of labels 0 or 1, or if absent, return 0
376
- example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
377
- return example
378
-
379
- subset_data = subset_data.map(count_targets, num_proc=num_proc)
380
-
381
- label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
382
- label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
383
-
384
- subset_data = subset_data.remove_columns("labels_counts")
385
-
386
- return label0_counts, label1_counts
387
-
388
-
389
- def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
390
- # function to filter by whether contains labels
391
- def if_contains_subset_label(example):
392
- a = list(label_dict_subset.keys())
393
- b = example["input_ids"]
394
- return not set(a).isdisjoint(b)
395
-
396
- # filter dataset for examples containing classes for this split
397
- logger.info("Filtering data for balanced genes")
398
- subset_data_len_orig = len(subset_data)
399
- subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
400
- logger.info(
401
- f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
402
- )
403
-
404
- return subset_data, label_dict_subset
405
-
406
-
407
- def balance_attr_splits(
408
- data,
409
- attr_to_split,
410
- attr_to_balance,
411
- eval_size,
412
- max_trials,
413
- pval_threshold,
414
- state_key,
415
- nproc,
416
- ):
417
- metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
418
- for attr in attr_to_balance:
419
- if attr == state_key:
420
- metadata_df[attr] = data["label"]
421
- else:
422
- metadata_df[attr] = data[attr]
423
- metadata_df = metadata_df.drop_duplicates()
424
-
425
- split_attr_ids = list(metadata_df["split_attr_ids"])
426
- assert len(split_attr_ids) == len(set(split_attr_ids))
427
- eval_num = round(len(split_attr_ids) * eval_size)
428
- colnames = (
429
- ["trial_num", "train_ids", "eval_ids"]
430
- + pu.flatten_list(
431
- [
432
- [
433
- f"{attr}_train_mean_or_counts",
434
- f"{attr}_eval_mean_or_counts",
435
- f"{attr}_pval",
436
- ]
437
- for attr in attr_to_balance
438
- ]
439
- )
440
- + ["mean_pval"]
441
- )
442
- balance_df = pd.DataFrame(columns=colnames)
443
- data_dict = dict()
444
- trial_num = 1
445
- for i in range(max_trials):
446
- if not all(
447
- count > 1 for count in list(Counter(metadata_df[state_key]).values())
448
- ):
449
- logger.error(
450
- f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
451
- )
452
- raise
453
- eval_base = []
454
- for state in set(metadata_df[state_key]):
455
- eval_base += list(
456
- metadata_df.loc[
457
- metadata_df[state_key][metadata_df[state_key].eq(state)]
458
- .sample(1, random_state=i)
459
- .index
460
- ]["split_attr_ids"]
461
- )
462
- non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
463
- random.seed(i)
464
- eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
465
- train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
466
- df_vals = [trial_num, train_ids, eval_ids]
467
- pvals = []
468
- for attr in attr_to_balance:
469
- train_attr = list(
470
- metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
471
- )
472
- eval_attr = list(
473
- metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
474
- )
475
- if attr == state_key:
476
- # ensure IDs are interpreted as categorical
477
- train_attr = [str(item) for item in train_attr]
478
- eval_attr = [str(item) for item in eval_attr]
479
- if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
480
- train_attr_mean = np.nanmean(train_attr)
481
- eval_attr_mean = np.nanmean(eval_attr)
482
- pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
483
- df_vals += [train_attr_mean, eval_attr_mean, pval]
484
- elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
485
- obs_counts = Counter(train_attr)
486
- exp_counts = Counter(eval_attr)
487
- all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
488
- obs = [obs_counts[cat] for cat in all_categ]
489
- exp = [
490
- exp_counts[cat] * sum(obs) / sum(exp_counts.values())
491
- for cat in all_categ
492
- ]
493
- pval = chisquare(f_obs=obs, f_exp=exp).pvalue
494
- train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
495
- eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
496
- df_vals += [train_attr_counts, eval_attr_counts, pval]
497
- else:
498
- logger.error(
499
- f"Inconsistent data types in attribute {attr}. "
500
- "Cannot infer if continuous or categorical. "
501
- "Must be all numeric (continuous) or all strings (categorical) to balance."
502
- )
503
- raise
504
- pvals += [pval]
505
-
506
- df_vals += [np.nanmean(pvals)]
507
- balance_df_i = pd.DataFrame(df_vals, index=colnames).T
508
- balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
509
- valid_pvals = [
510
- pval_i
511
- for pval_i in pvals
512
- if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
513
- ]
514
- if all(i >= pval_threshold for i in valid_pvals):
515
- data_dict["train"] = pu.filter_by_dict(
516
- data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
517
- )
518
- data_dict["test"] = pu.filter_by_dict(
519
- data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
520
- )
521
- return data_dict, balance_df
522
- trial_num = trial_num + 1
523
- balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
524
- # use the optimal completed trial rather than the last one
- data_dict["train"] = pu.filter_by_dict(
- data, {attr_to_split: balance_max_df["train_ids"]}, nproc
- )
- data_dict["test"] = pu.filter_by_dict(
- data, {attr_to_split: balance_max_df["eval_ids"]}, nproc
529
- )
530
- logger.warning(
531
- f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
532
- f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
533
- )
534
- return data_dict, balance_df
535
-
536
-
537
- def get_num_classes(id_class_dict):
538
- return len(set(id_class_dict.values()))
539
-
540
-
541
- def compute_metrics(pred):
542
- labels = pred.label_ids
543
- preds = pred.predictions.argmax(-1)
544
-
545
- # calculate accuracy and macro f1 using sklearn's function
546
- if len(labels.shape) == 1:
547
- acc = accuracy_score(labels, preds)
548
- macro_f1 = f1_score(labels, preds, average="macro")
549
- else:
550
- flat_labels = labels.flatten().tolist()
551
- flat_preds = preds.flatten().tolist()
552
- logit_label_paired = [
553
- item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
554
- ]
555
- y_pred = [item[0] for item in logit_label_paired]
556
- y_true = [item[1] for item in logit_label_paired]
557
-
558
- acc = accuracy_score(y_true, y_pred)
559
- macro_f1 = f1_score(y_true, y_pred, average="macro")
560
-
561
- return {"accuracy": acc, "macro_f1": macro_f1}
562
-
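A sketch of the gene-classifier branch of compute_metrics above, showing how -100 ignore labels are dropped before scoring:

    import numpy as np
    from sklearn.metrics import accuracy_score, f1_score

    labels = np.array([[0, 1, -100], [1, -100, 0]])  # token-level labels; -100 = ignore
    preds = np.array([[0, 1, 1], [1, 0, 0]])         # argmaxed predictions
    pairs = [(p, l) for p, l in zip(preds.flatten(), labels.flatten()) if l != -100]
    y_pred = [p for p, _ in pairs]
    y_true = [l for _, l in pairs]
    assert accuracy_score(y_true, y_pred) == 1.0
    assert f1_score(y_true, y_pred, average="macro") == 1.0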
563
-
564
- def get_default_train_args(model, classifier, data, output_dir):
565
- num_layers = pu.quant_layers(model)
566
- freeze_layers = 0
567
- batch_size = 12
568
- if classifier == "cell":
569
- epochs = 10
570
- evaluation_strategy = "epoch"
571
- load_best_model_at_end = True
572
- else:
573
- epochs = 1
574
- evaluation_strategy = "no"
575
- load_best_model_at_end = False
576
-
577
- if num_layers == 6:
578
- default_training_args = {
579
- "learning_rate": 5e-5,
580
- "lr_scheduler_type": "linear",
581
- "warmup_steps": 500,
582
- "per_device_train_batch_size": batch_size,
583
- "per_device_eval_batch_size": batch_size,
584
- }
585
- else:
586
- default_training_args = {
587
- "per_device_train_batch_size": batch_size,
588
- "per_device_eval_batch_size": batch_size,
589
- }
590
-
591
- training_args = {
592
- "num_train_epochs": epochs,
593
- "do_train": True,
594
- "do_eval": True,
595
- "evaluation_strategy": evaluation_strategy,
596
- "logging_steps": np.floor(len(data) / batch_size / 8), # log 8 times per epoch
597
- "save_strategy": "epoch",
598
- "group_by_length": False,
599
- "length_column_name": "length",
600
- "disable_tqdm": False,
601
- "weight_decay": 0.001,
602
- "load_best_model_at_end": load_best_model_at_end,
603
- }
604
- training_args.update(default_training_args)
605
-
606
- return training_args, freeze_layers
607
-
608
-
609
- def load_best_model(directory, model_type, num_classes, mode="eval"):
610
- file_dict = dict()
611
- for subdir, dirs, files in os.walk(directory):
612
- for file in files:
613
- if file.endswith("result.json"):
614
- with open(f"{subdir}/{file}", "rb") as fp:
615
- result_json = json.load(fp)
616
- file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
617
- file_df = pd.DataFrame(
618
- {"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
619
- )
620
- model_superdir = (
621
- "run-"
622
- + file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
623
- .split("_objective_")[2]
624
- .split("_")[0]
625
- )
626
-
627
- for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
628
- for file in files:
629
- if file.endswith("model.safetensors"):
630
- model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
631
- return model
632
-
633
-
634
- class StratifiedKFold3(StratifiedKFold):
635
- def split(self, targets, labels, test_ratio=0.5, groups=None):
636
- s = super().split(targets, labels, groups)
637
- for train_indxs, test_indxs in s:
638
- if test_ratio == 0:
639
- yield train_indxs, test_indxs, None
640
- else:
641
- labels_test = np.array(labels)[test_indxs]
642
- valid_indxs, test_indxs = train_test_split(
643
- test_indxs,
644
- stratify=labels_test,
645
- test_size=test_ratio,
646
- random_state=0,
647
- )
648
- yield train_indxs, valid_indxs, test_indxs
 
 
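A hedged sketch of the three-way splitter above: each stratified fold's held-out set is further split into validation and out-of-sample test indices (assuming the class is importable).

    targets = list(range(20))
    labels = [0, 1] * 10
    skf = StratifiedKFold3(n_splits=5, random_state=0, shuffle=True)
    for train_idx, valid_idx, test_idx in skf.split(targets, labels, test_ratio=0.5):
        assert len(train_idx) == 16
        assert len(valid_idx) == len(test_idx) == 2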
geneformer/collator_for_classification.py CHANGED
@@ -1,22 +1,24 @@
1
  """
2
  Geneformer collator for gene and cell classification.
 
3
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
4
  """
5
-
 
6
  import warnings
7
  from enum import Enum
8
  from typing import Dict, List, Optional, Union
9
 
10
- import numpy as np
11
- import torch
12
  from transformers import (
13
- BatchEncoding,
14
  DataCollatorForTokenClassification,
15
  SpecialTokensMixin,
 
16
  )
17
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
  from transformers.utils.generic import _is_tensorflow, _is_torch
19
 
 
 
20
  EncodedInput = List[int]
21
  logger = logging.get_logger(__name__)
22
  VERY_LARGE_INTEGER = int(
@@ -28,7 +30,6 @@ LARGE_INTEGER = int(
28
 
29
  # precollator functions
30
 
31
-
32
  class ExplicitEnum(Enum):
33
  """
34
  Enum with more explicit error message for missing values.
@@ -41,7 +42,6 @@ class ExplicitEnum(Enum):
41
  % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
42
  )
43
 
44
-
45
  class TruncationStrategy(ExplicitEnum):
46
  """
47
  Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
@@ -54,6 +54,7 @@ class TruncationStrategy(ExplicitEnum):
54
  DO_NOT_TRUNCATE = "do_not_truncate"
55
 
56
 
 
57
  class PaddingStrategy(ExplicitEnum):
58
  """
59
  Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
@@ -65,6 +66,7 @@ class PaddingStrategy(ExplicitEnum):
65
  DO_NOT_PAD = "do_not_pad"
66
 
67
 
 
68
  class TensorType(ExplicitEnum):
69
  """
70
  Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
@@ -76,41 +78,21 @@ class TensorType(ExplicitEnum):
76
  NUMPY = "np"
77
  JAX = "jax"
78
 
79
-
80
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
81
- def __init__(self, *args, **kwargs) -> None:
82
- super().__init__(mask_token="<mask>", pad_token="<pad>")
83
-
84
- self.token_dictionary = kwargs.get("token_dictionary")
85
- self.padding_side = "right"
86
- self.model_input_names = ["input_ids"]
87
- self._mask_token_id = self.token_dictionary.get("<mask>")
88
- self._pad_token_id = self.token_dictionary.get("<pad>")
89
- self._all_special_ids = [
90
- self.token_dictionary.get("<mask>"),
91
- self.token_dictionary.get("<pad>"),
92
- ]
93
-
94
- @property
95
- def all_special_ids(self):
96
- return self._all_special_ids
97
-
98
- @property
99
- def mask_token_id(self):
100
- return self._mask_token_id
101
-
102
- @property
103
- def pad_token_id(self):
104
- return self._pad_token_id
105
 
106
  def _get_padding_truncation_strategies(
107
- self,
108
- padding=True,
109
- truncation=False,
110
- max_length=None,
111
- pad_to_multiple_of=None,
112
- verbose=True,
113
- **kwargs,
114
  ):
115
  """
116
  Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
@@ -123,9 +105,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
123
  # If you only set max_length, it activates truncation for max_length
124
  if max_length is not None and padding is False and truncation is False:
125
  if verbose:
126
- if not self.deprecation_warnings.get(
127
- "Truncation-not-explicitly-activated", False
128
- ):
129
  logger.warning(
130
  "Truncation was not explicitly activated but `max_length` is provided a specific value, "
131
  "please use `truncation=True` to explicitly truncate examples to max length. "
@@ -153,9 +133,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
153
  padding_strategy = PaddingStrategy.MAX_LENGTH
154
  elif padding is not False:
155
  if padding is True:
156
- padding_strategy = (
157
- PaddingStrategy.LONGEST
158
- ) # Default to pad to the longest sequence in the batch
159
  elif not isinstance(padding, PaddingStrategy):
160
  padding_strategy = PaddingStrategy(padding)
161
  elif isinstance(padding, PaddingStrategy):
@@ -195,9 +173,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
195
  if padding_strategy == PaddingStrategy.MAX_LENGTH:
196
  if self.model_max_length > LARGE_INTEGER:
197
  if verbose:
198
- if not self.deprecation_warnings.get(
199
- "Asking-to-pad-to-max_length", False
200
- ):
201
  logger.warning(
202
  "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
203
  "Default to no padding."
@@ -210,24 +186,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
210
  if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
211
  if self.model_max_length > LARGE_INTEGER:
212
  if verbose:
213
- if not self.deprecation_warnings.get(
214
- "Asking-to-truncate-to-max_length", False
215
- ):
216
  logger.warning(
217
  "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
218
  "Default to no truncation."
219
  )
220
- self.deprecation_warnings[
221
- "Asking-to-truncate-to-max_length"
222
- ] = True
223
  truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
224
  else:
225
  max_length = self.model_max_length
226
 
227
  # Test if we have a padding token
228
- if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
229
- not self.pad_token or self.pad_token_id < 0
230
- ):
231
  raise ValueError(
232
  "Asking to pad but the tokenizer does not have a padding token. "
233
  "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
@@ -258,7 +228,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
258
  Dict[str, List[EncodedInput]],
259
  List[Dict[str, EncodedInput]],
260
  ],
261
- class_type, # options: "gene" or "cell"
262
  padding: Union[bool, str, PaddingStrategy] = True,
263
  max_length: Optional[int] = None,
264
  pad_to_multiple_of: Optional[int] = None,
@@ -269,23 +239,29 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
269
  """
270
  Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
271
  in the batch.
 
272
  Padding side (left/right) and padding token ids are defined at the tokenizer level (with ``self.padding_side``,
273
  ``self.pad_token_id`` and ``self.pad_token_type_id``)
 
274
  .. note::
 
275
  If the ``encoded_inputs`` passed are a dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
276
  result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
277
  case of PyTorch tensors, however, you will lose the specific device of your tensors.
 
278
  Args:
279
  encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
280
  Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
281
  List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
282
  List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
283
  well as in a PyTorch Dataloader collate function.
 
284
  Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
285
  see the note above for the return type.
286
  padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
287
  Select a strategy to pad the returned sequences (according to the model's padding side and padding
288
  index) among:
 
289
  * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
290
  single sequence is provided).
291
  * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
@@ -296,14 +272,17 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
296
  Maximum length of the returned list and optionally padding length (see above).
297
  pad_to_multiple_of (:obj:`int`, `optional`):
298
  If set, will pad the sequence to a multiple of the provided value.
 
299
  This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
300
  >= 7.5 (Volta).
301
  return_attention_mask (:obj:`bool`, `optional`):
302
  Whether to return the attention mask. If left to the default, will return the attention mask according
303
  to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
 
304
  `What are attention masks? <../glossary.html#attention-mask>`__
305
  return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
306
  If set, will return tensors instead of list of python integers. Acceptable values are:
 
307
  * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
308
  * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
309
  * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
@@ -312,13 +291,8 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
312
  """
313
  # If we have a list of dicts, let's convert it in a dict of lists
314
  # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
315
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(
316
- encoded_inputs[0], (dict, BatchEncoding)
317
- ):
318
- encoded_inputs = {
319
- key: [example[key] for example in encoded_inputs]
320
- for key in encoded_inputs[0].keys()
321
- }
322
 
323
  # The model's main input name, usually `input_ids`, has to be passed for padding
324
  if self.model_input_names[0] not in encoded_inputs:
@@ -412,7 +386,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
412
  def _pad(
413
  self,
414
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
415
- class_type, # options: "gene" or "cell"
416
  max_length: Optional[int] = None,
417
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
418
  pad_to_multiple_of: Optional[int] = None,
@@ -420,15 +394,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
420
  ) -> dict:
421
  """
422
  Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
 
423
  Args:
424
  encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
425
  max_length: maximum length of the returned list and optionally padding length (see below).
426
  Will truncate by taking into account the special tokens.
427
  padding_strategy: PaddingStrategy to use for padding.
 
428
  - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
429
  - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
430
  - PaddingStrategy.DO_NOT_PAD: Do not pad
431
  The tokenizer padding sides are defined in self.padding_side:
 
432
  - 'left': pads on the left of the sequences
433
  - 'right': pads on the right of the sequences
434
  pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
@@ -445,73 +422,46 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
445
  if padding_strategy == PaddingStrategy.LONGEST:
446
  max_length = len(required_input)
447
 
448
- if (
449
- max_length is not None
450
- and pad_to_multiple_of is not None
451
- and (max_length % pad_to_multiple_of != 0)
452
- ):
453
  max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
454
 
455
- needs_to_be_padded = (
456
- padding_strategy != PaddingStrategy.DO_NOT_PAD
457
- and len(required_input) != max_length
458
- )
459
 
460
  if needs_to_be_padded:
461
  difference = max_length - len(required_input)
462
  if self.padding_side == "right":
463
  if return_attention_mask:
464
- encoded_inputs["attention_mask"] = [1] * len(required_input) + [
465
- 0
466
- ] * difference
467
  if "token_type_ids" in encoded_inputs:
468
  encoded_inputs["token_type_ids"] = (
469
- encoded_inputs["token_type_ids"]
470
- + [self.pad_token_type_id] * difference
471
  )
472
  if "special_tokens_mask" in encoded_inputs:
473
- encoded_inputs["special_tokens_mask"] = (
474
- encoded_inputs["special_tokens_mask"] + [1] * difference
475
- )
476
- encoded_inputs[self.model_input_names[0]] = (
477
- required_input + [self.pad_token_id] * difference
478
- )
479
  if class_type == "gene":
480
- encoded_inputs["labels"] = (
481
- encoded_inputs["labels"] + [-100] * difference
482
- )
483
  elif self.padding_side == "left":
484
  if return_attention_mask:
485
- encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
486
- required_input
487
- )
488
  if "token_type_ids" in encoded_inputs:
489
- encoded_inputs["token_type_ids"] = [
490
- self.pad_token_type_id
491
- ] * difference + encoded_inputs["token_type_ids"]
492
  if "special_tokens_mask" in encoded_inputs:
493
- encoded_inputs["special_tokens_mask"] = [
494
- 1
495
- ] * difference + encoded_inputs["special_tokens_mask"]
496
- encoded_inputs[self.model_input_names[0]] = [
497
- self.pad_token_id
498
- ] * difference + required_input
499
  if class_type == "gene":
500
- encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
501
- "labels"
502
- ]
503
  else:
504
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
505
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
506
  encoded_inputs["attention_mask"] = [1] * len(required_input)
507
-
508
  return encoded_inputs
509
 
510
  def get_special_tokens_mask(
511
- self,
512
- token_ids_0: List[int],
513
- token_ids_1: Optional[List[int]] = None,
514
- already_has_special_tokens: bool = False,
515
  ) -> List[int]:
516
  """
517
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
@@ -535,15 +485,11 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
535
 
536
  all_special_ids = self.all_special_ids # cache the property
537
 
538
- special_tokens_mask = [
539
- 1 if token in all_special_ids else 0 for token in token_ids_0
540
- ]
541
 
542
  return special_tokens_mask
543
 
544
- def convert_tokens_to_ids(
545
- self, tokens: Union[str, List[str]]
546
- ) -> Union[int, List[int]]:
547
  """
548
  Converts a token string (or a sequence of tokens) into a single integer id (or a sequence of ids), using the
549
  vocabulary.
@@ -567,15 +513,14 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
567
  if token is None:
568
  return None
569
 
570
- return self.token_dictionary.get(token)
571
 
572
  def __len__(self):
573
- return len(self.token_dictionary)
574
 
575
 
576
  # collator functions
577
 
578
-
579
  class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
580
  """
581
  Data collator that will dynamically pad the inputs received, as well as the labels.
@@ -601,33 +546,25 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
601
  The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
602
  """
603
 
 
604
  class_type = "gene"
605
  padding: Union[bool, str, PaddingStrategy] = True
606
  max_length: Optional[int] = None
607
  pad_to_multiple_of: Optional[int] = None
608
  label_pad_token_id: int = -100
609
-
610
  def __init__(self, *args, **kwargs) -> None:
611
- self.token_dictionary = kwargs.pop("token_dictionary")
612
  super().__init__(
613
- tokenizer=PrecollatorForGeneAndCellClassification(
614
- token_dictionary=self.token_dictionary
615
- ),
616
  padding=self.padding,
617
  max_length=self.max_length,
618
  pad_to_multiple_of=self.pad_to_multiple_of,
619
  label_pad_token_id=self.label_pad_token_id,
620
- *args,
621
- **kwargs,
622
- )
623
 
624
  def _prepare_batch(self, features):
625
  label_name = "label" if "label" in features[0].keys() else "labels"
626
- labels = (
627
- [feature[label_name] for feature in features]
628
- if label_name in features[0].keys()
629
- else None
630
- )
631
  batch = self.tokenizer.pad(
632
  features,
633
  class_type=self.class_type,
@@ -637,31 +574,29 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
637
  return_tensors="pt",
638
  )
639
  return batch
640
-
641
  def __call__(self, features):
642
  batch = self._prepare_batch(features)
643
 
644
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
645
  return batch
646
 
647
-
648
  class DataCollatorForCellClassification(DataCollatorForGeneClassification):
 
649
  class_type = "cell"
650
 
651
  def _prepare_batch(self, features):
 
652
  batch = super()._prepare_batch(features)
653
-
654
  # Special handling for labels.
655
  # Ensure that tensor is created with the correct type
656
  # (it should be automatically the case, but let's make sure of it.)
657
  first = features[0]
658
  if "label" in first and first["label"] is not None:
659
- label = (
660
- first["label"].item()
661
- if isinstance(first["label"], torch.Tensor)
662
- else first["label"]
663
- )
664
  dtype = torch.long if isinstance(label, int) else torch.float
665
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
666
-
667
  return batch
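
The `_get_padding_truncation_strategies` method above resolves the user-facing `padding` argument into a `PaddingStrategy` before any padding happens. A minimal standalone sketch of that resolution rule, ignoring the deprecated backward-compatibility paths and using a simplified enum rather than this module's classes (illustrative only, not part of the diff):

# Sketch of the padding-strategy resolution implemented above (simplified).
from enum import Enum

class PaddingStrategy(Enum):
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"

def resolve_padding(padding):
    if padding is True:
        # padding=True defaults to padding to the longest sequence in the batch
        return PaddingStrategy.LONGEST
    if padding is False:
        return PaddingStrategy.DO_NOT_PAD
    # strings ("longest", "max_length", "do_not_pad") map onto the enum values
    return PaddingStrategy(padding)

assert resolve_padding(True) is PaddingStrategy.LONGEST
assert resolve_padding("max_length") is PaddingStrategy.MAX_LENGTH
assert resolve_padding(False) is PaddingStrategy.DO_NOT_PAD
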
 
1
  """
2
  Geneformer collator for gene and cell classification.
3
+
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
+ import numpy as np
7
+ import torch
8
  import warnings
9
  from enum import Enum
10
  from typing import Dict, List, Optional, Union
11
 
 
 
12
  from transformers import (
 
13
  DataCollatorForTokenClassification,
14
  SpecialTokensMixin,
15
+ BatchEncoding,
16
  )
17
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
  from transformers.utils.generic import _is_tensorflow, _is_torch
19
 
20
+ from .pretrainer import token_dictionary
21
+
22
  EncodedInput = List[int]
23
  logger = logging.get_logger(__name__)
24
  VERY_LARGE_INTEGER = int(
 
30
 
31
  # precollator functions
32
 
 
33
  class ExplicitEnum(Enum):
34
  """
35
  Enum with more explicit error message for missing values.
 
42
  % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
43
  )
44
 
 
45
  class TruncationStrategy(ExplicitEnum):
46
  """
47
  Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
 
54
  DO_NOT_TRUNCATE = "do_not_truncate"
55
 
56
 
57
+
58
  class PaddingStrategy(ExplicitEnum):
59
  """
60
  Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
 
66
  DO_NOT_PAD = "do_not_pad"
67
 
68
 
69
+
70
  class TensorType(ExplicitEnum):
71
  """
72
  Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
 
78
  NUMPY = "np"
79
  JAX = "jax"
80
 
81
+
82
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
83
+ mask_token = "<mask>"
84
+ mask_token_id = token_dictionary.get("<mask>")
85
+ pad_token = "<pad>"
86
+ pad_token_id = token_dictionary.get("<pad>")
87
+ padding_side = "right"
88
+ all_special_ids = [
89
+ token_dictionary.get("<mask>"),
90
+ token_dictionary.get("<pad>")
91
+ ]
92
+ model_input_names = ["input_ids"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def _get_padding_truncation_strategies(
95
+ self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
 
 
 
 
 
 
96
  ):
97
  """
98
  Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
 
105
  # If you only set max_length, it activates truncation for max_length
106
  if max_length is not None and padding is False and truncation is False:
107
  if verbose:
108
+ if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
 
 
109
  logger.warning(
110
  "Truncation was not explicitly activated but `max_length` is provided a specific value, "
111
  "please use `truncation=True` to explicitly truncate examples to max length. "
 
133
  padding_strategy = PaddingStrategy.MAX_LENGTH
134
  elif padding is not False:
135
  if padding is True:
136
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
 
 
137
  elif not isinstance(padding, PaddingStrategy):
138
  padding_strategy = PaddingStrategy(padding)
139
  elif isinstance(padding, PaddingStrategy):
 
173
  if padding_strategy == PaddingStrategy.MAX_LENGTH:
174
  if self.model_max_length > LARGE_INTEGER:
175
  if verbose:
176
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
 
 
177
  logger.warning(
178
  "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
179
  "Default to no padding."
 
186
  if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
187
  if self.model_max_length > LARGE_INTEGER:
188
  if verbose:
189
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
 
 
190
  logger.warning(
191
  "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
192
  "Default to no truncation."
193
  )
194
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
 
 
195
  truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
196
  else:
197
  max_length = self.model_max_length
198
 
199
  # Test if we have a padding token
200
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
 
 
201
  raise ValueError(
202
  "Asking to pad but the tokenizer does not have a padding token. "
203
  "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
 
228
  Dict[str, List[EncodedInput]],
229
  List[Dict[str, EncodedInput]],
230
  ],
231
+ class_type, # options: "gene" or "cell"
232
  padding: Union[bool, str, PaddingStrategy] = True,
233
  max_length: Optional[int] = None,
234
  pad_to_multiple_of: Optional[int] = None,
 
239
  """
240
  Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
241
  in the batch.
242
+
243
  Padding side (left/right) and padding token ids are defined at the tokenizer level (with ``self.padding_side``,
244
  ``self.pad_token_id`` and ``self.pad_token_type_id``)
245
+
246
  .. note::
247
+
248
  If the ``encoded_inputs`` passed are a dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
249
  result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
250
  case of PyTorch tensors, however, you will lose the specific device of your tensors.
251
+
252
  Args:
253
  encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
254
  Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
255
  List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
256
  List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
257
  well as in a PyTorch Dataloader collate function.
258
+
259
  Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
260
  see the note above for the return type.
261
  padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
262
  Select a strategy to pad the returned sequences (according to the model's padding side and padding
263
  index) among:
264
+
265
  * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
266
  single sequence is provided).
267
  * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
 
272
  Maximum length of the returned list and optionally padding length (see above).
273
  pad_to_multiple_of (:obj:`int`, `optional`):
274
  If set, will pad the sequence to a multiple of the provided value.
275
+
276
  This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
277
  >= 7.5 (Volta).
278
  return_attention_mask (:obj:`bool`, `optional`):
279
  Whether to return the attention mask. If left to the default, will return the attention mask according
280
  to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
281
+
282
  `What are attention masks? <../glossary.html#attention-mask>`__
283
  return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
284
  If set, will return tensors instead of list of python integers. Acceptable values are:
285
+
286
  * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
287
  * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
288
  * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
 
291
  """
292
  # If we have a list of dicts, let's convert it in a dict of lists
293
  # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
294
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
295
+ encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
 
 
 
 
296
 
297
  # The model's main input name, usually `input_ids`, has to be passed for padding
298
  if self.model_input_names[0] not in encoded_inputs:
 
386
  def _pad(
387
  self,
388
  encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
389
+ class_type, # options: "gene" or "cell"
390
  max_length: Optional[int] = None,
391
  padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
392
  pad_to_multiple_of: Optional[int] = None,
 
394
  ) -> dict:
395
  """
396
  Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
397
+
398
  Args:
399
  encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
400
  max_length: maximum length of the returned list and optionally padding length (see below).
401
  Will truncate by taking into account the special tokens.
402
  padding_strategy: PaddingStrategy to use for padding.
403
+
404
  - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
405
  - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
406
  - PaddingStrategy.DO_NOT_PAD: Do not pad
407
  The tokenizer padding sides are defined in self.padding_side:
408
+
409
  - 'left': pads on the left of the sequences
410
  - 'right': pads on the right of the sequences
411
  pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
 
422
  if padding_strategy == PaddingStrategy.LONGEST:
423
  max_length = len(required_input)
424
 
425
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
 
 
 
 
426
  max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
427
 
428
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
 
 
429
 
430
  if needs_to_be_padded:
431
  difference = max_length - len(required_input)
432
  if self.padding_side == "right":
433
  if return_attention_mask:
434
+ encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
 
 
435
  if "token_type_ids" in encoded_inputs:
436
  encoded_inputs["token_type_ids"] = (
437
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
 
438
  )
439
  if "special_tokens_mask" in encoded_inputs:
440
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
441
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
 
 
 
 
442
  if class_type == "gene":
443
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
 
 
444
  elif self.padding_side == "left":
445
  if return_attention_mask:
446
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
 
 
447
  if "token_type_ids" in encoded_inputs:
448
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
449
+ "token_type_ids"
450
+ ]
451
  if "special_tokens_mask" in encoded_inputs:
452
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
453
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
 
 
 
454
  if class_type == "gene":
455
+ encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
 
 
456
  else:
457
  raise ValueError("Invalid padding strategy:" + str(self.padding_side))
458
  elif return_attention_mask and "attention_mask" not in encoded_inputs:
459
  encoded_inputs["attention_mask"] = [1] * len(required_input)
460
+
461
  return encoded_inputs
462
 
463
  def get_special_tokens_mask(
464
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
 
 
 
465
  ) -> List[int]:
466
  """
467
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
 
485
 
486
  all_special_ids = self.all_special_ids # cache the property
487
 
488
+ special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
 
 
489
 
490
  return special_tokens_mask
491
 
492
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
 
 
493
  """
494
  Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
495
  vocabulary.
 
513
  if token is None:
514
  return None
515
 
516
+ return token_dictionary.get(token)
517
 
518
  def __len__(self):
519
+ return len(token_dictionary)
520
 
521
 
522
  # collator functions
523
 
 
524
  class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
525
  """
526
  Data collator that will dynamically pad the inputs received, as well as the labels.
 
546
  The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
547
  """
548
 
549
+ tokenizer = PrecollatorForGeneAndCellClassification()
550
  class_type = "gene"
551
  padding: Union[bool, str, PaddingStrategy] = True
552
  max_length: Optional[int] = None
553
  pad_to_multiple_of: Optional[int] = None
554
  label_pad_token_id: int = -100
555
+
556
  def __init__(self, *args, **kwargs) -> None:
 
557
  super().__init__(
558
+ tokenizer=self.tokenizer,
 
 
559
  padding=self.padding,
560
  max_length=self.max_length,
561
  pad_to_multiple_of=self.pad_to_multiple_of,
562
  label_pad_token_id=self.label_pad_token_id,
563
+ *args, **kwargs)
 
 
564
 
565
  def _prepare_batch(self, features):
566
  label_name = "label" if "label" in features[0].keys() else "labels"
567
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
 
 
 
 
568
  batch = self.tokenizer.pad(
569
  features,
570
  class_type=self.class_type,
 
574
  return_tensors="pt",
575
  )
576
  return batch
577
+
578
  def __call__(self, features):
579
  batch = self._prepare_batch(features)
580
 
581
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
582
  return batch
583
 
584
+
585
  class DataCollatorForCellClassification(DataCollatorForGeneClassification):
586
+
587
  class_type = "cell"
588
 
589
  def _prepare_batch(self, features):
590
+
591
  batch = super()._prepare_batch(features)
592
+
593
  # Special handling for labels.
594
  # Ensure that tensor is created with the correct type
595
  # (it should be automatically the case, but let's make sure of it.)
596
  first = features[0]
597
  if "label" in first and first["label"] is not None:
598
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
 
 
 
 
599
  dtype = torch.long if isinstance(label, int) else torch.float
600
  batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
601
+
602
  return batch
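
Taken together, the collator version above right-pads `input_ids` with the pad token, extends `attention_mask` with zeros, and pads gene-level `labels` with -100 so that loss computation skips the padded positions. A minimal usage sketch, assuming the v1 layout in which `token_dictionary` is importable from `.pretrainer`; the feature values below are placeholders, not data from this repository:

# Illustrative only: exercises the right-padding behavior documented in _pad.
import torch

from geneformer.collator_for_classification import DataCollatorForGeneClassification

# Two pre-tokenized cells of unequal length; per-gene labels align with input_ids.
features = [
    {"input_ids": [5, 17, 42], "labels": [0, 1, 0]},
    {"input_ids": [5, 9], "labels": [1, 0]},
]

collator = DataCollatorForGeneClassification()
batch = collator(features)

# The shorter example is padded on the right: input_ids with pad_token_id,
# attention_mask with 0, and labels with -100 (ignored by PyTorch losses).
print(batch["input_ids"].shape)    # torch.Size([2, 3])
print(batch["labels"][1])          # tensor([   1,    0, -100])
print(batch["attention_mask"][1])  # tensor([1, 1, 0])
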
geneformer/emb_extractor.py CHANGED
@@ -1,419 +1,253 @@
1
  """
2
  Geneformer embedding extractor.
3
 
4
- **Description:**
5
-
6
- | Extracts gene or cell embeddings.
7
- | Plots cell embeddings as heatmaps or UMAPs.
8
- | Generates cell state embedding dictionary for use with InSilicoPerturber.
9
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
 
12
  # imports
13
  import logging
14
- import pickle
15
- from collections import Counter
16
- from pathlib import Path
17
-
18
  import anndata
19
  import matplotlib.pyplot as plt
 
20
  import pandas as pd
 
 
21
  import scanpy as sc
22
  import seaborn as sns
23
  import torch
24
- from tdigest import TDigest
25
- from tqdm.auto import trange
 
 
26
 
27
- from . import TOKEN_DICTIONARY_FILE
28
- from . import perturber_utils as pu
29
 
30
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
31
 
 
32
 
33
  # extract embeddings
34
- def get_embs(
35
- model,
36
- filtered_input_data,
37
- emb_mode,
38
- layer_to_quant,
39
- pad_token_id,
40
- forward_batch_size,
41
- token_gene_dict,
42
- special_token=False,
43
- summary_stat=None,
44
- silent=False,
45
- ):
46
- model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
48
-
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
- # get # of emb dims
53
- emb_dims = pu.get_model_emb_dims(model)
54
- if emb_mode == "cell":
55
- # initiate tdigests for # of emb dims
56
- embs_tdigests = [TDigest() for _ in range(emb_dims)]
57
- if emb_mode == "gene":
58
- gene_set = list(
59
- {
60
- element
61
- for sublist in filtered_input_data["input_ids"]
62
- for element in sublist
63
- }
64
- )
65
- # initiate dict with genes as keys and tdigests for # of emb dims as values
66
- embs_tdigests_dict = {
67
- k: [TDigest() for _ in range(emb_dims)] for k in gene_set
68
- }
69
-
70
- # Check if CLS and EOS token is present in the token dictionary
71
- cls_present = any("<cls>" in value for value in token_gene_dict.values())
72
- eos_present = any("<eos>" in value for value in token_gene_dict.values())
73
- if emb_mode == "cls":
74
- assert cls_present, "<cls> token missing in token dictionary"
75
- # Check to make sure that the first token of the filtered input data is cls token
76
- gene_token_dict = {v: k for k, v in token_gene_dict.items()}
77
- cls_token_id = gene_token_dict["<cls>"]
78
- assert (
79
- filtered_input_data["input_ids"][0][0] == cls_token_id
80
- ), "First token is not <cls> token value"
81
- elif emb_mode == "cell":
82
- if cls_present:
83
- logger.warning(
84
- "CLS token present in token dictionary, excluding from average."
85
- )
86
- if eos_present:
87
- logger.warning(
88
- "EOS token present in token dictionary, excluding from average."
89
- )
90
-
91
- overall_max_len = 0
92
 
93
- for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
94
- max_range = min(i + forward_batch_size, total_batch_length)
95
 
96
  minibatch = filtered_input_data.select([i for i in range(i, max_range)])
97
-
98
- max_len = int(max(minibatch["length"]))
99
- original_lens = torch.tensor(minibatch["length"], device="cuda")
100
  minibatch.set_format(type="torch")
101
 
102
  input_data_minibatch = minibatch["input_ids"]
103
- input_data_minibatch = pu.pad_tensor_list(
104
- input_data_minibatch, max_len, pad_token_id, model_input_size
105
- )
106
-
 
107
  with torch.no_grad():
108
  outputs = model(
109
- input_ids=input_data_minibatch.to("cuda"),
110
- attention_mask=pu.gen_attention_mask(minibatch),
111
  )
112
 
113
  embs_i = outputs.hidden_states[layer_to_quant]
114
-
115
  if emb_mode == "cell":
116
- if cls_present:
117
- non_cls_embs = embs_i[:, 1:, :]  # all token positions except the leading CLS embedding
118
- if eos_present:
119
- mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
120
- else:
121
- mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
122
- else:
123
- mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
124
  if summary_stat is None:
125
- embs_list.append(mean_embs)
126
  elif summary_stat is not None:
127
  # update tdigests with current batch for each emb dim
128
- accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
129
- del mean_embs
130
- elif emb_mode == "gene":
131
- if summary_stat is None:
132
- embs_list.append(embs_i)
133
- elif summary_stat is not None:
134
- for h in trange(len(minibatch)):
135
- length_h = minibatch[h]["length"]
136
- input_ids_h = minibatch[h]["input_ids"][0:length_h]
137
-
138
- # double check dimensions before unsqueezing
139
- embs_i_dim = embs_i.dim()
140
- if embs_i_dim != 3:
141
- logger.error(
142
- f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
143
- )
144
- raise
145
-
146
- embs_h = embs_i[h, :, :].unsqueeze(dim=1)
147
- dict_h = dict(zip(input_ids_h, embs_h))
148
- for k in dict_h.keys():
149
- accumulate_tdigests(
150
- embs_tdigests_dict[int(k)], dict_h[k], emb_dims
151
- )
152
- del embs_h
153
- del dict_h
154
- elif emb_mode == "cls":
155
- cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
156
- embs_list.append(cls_embs)
157
- del cls_embs
158
-
159
- overall_max_len = max(overall_max_len, max_len)
160
  del outputs
161
  del minibatch
162
  del input_data_minibatch
163
  del embs_i
164
-
165
- torch.cuda.empty_cache()
166
-
167
  if summary_stat is None:
168
- if (emb_mode == "cell") or (emb_mode == "cls"):
169
- embs_stack = torch.cat(embs_list, dim=0)
170
- elif emb_mode == "gene":
171
- embs_stack = pu.pad_tensor_list(
172
- embs_list,
173
- overall_max_len,
174
- pad_token_id,
175
- model_input_size,
176
- 1,
177
- pu.pad_3d_tensor,
178
- )
179
-
180
  # calculate summary stat embs from approximated tdigests
181
  elif summary_stat is not None:
182
- if emb_mode == "cell":
183
- if summary_stat == "mean":
184
- summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
185
- elif summary_stat == "median":
186
- summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
187
- embs_stack = torch.tensor(summary_emb_list)
188
- elif emb_mode == "gene":
189
- if summary_stat == "mean":
190
- [
191
- update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
192
- for gene in embs_tdigests_dict.keys()
193
- ]
194
- elif summary_stat == "median":
195
- [
196
- update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
197
- for gene in embs_tdigests_dict.keys()
198
- ]
199
- return embs_tdigests_dict
200
 
201
  return embs_stack
202
 
 
 
 
 
 
203
 
204
- def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
205
- # note: tdigest batch update known to be slow so updating serially
206
- [
207
- embs_tdigests[j].update(mean_embs[i, j].item())
208
- for i in range(mean_embs.size(0))
209
- for j in range(emb_dims)
210
- ]
211
-
212
-
213
- def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
214
- embs_tdigests_dict[gene] = accumulate_tdigests(
215
- embs_tdigests_dict[gene], gene_embs, emb_dims
216
- )
217
-
218
-
219
- def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
220
- embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
221
-
222
-
223
- def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
224
- embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
225
-
226
-
227
- def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
228
- length_h = minibatch[h]["length"]
229
- input_ids_h = minibatch[h]["input_ids"][0:length_h]
230
- embs_h = embs_i[h, :, :].unsqueeze(dim=1)
231
- dict_h = dict(zip(input_ids_h, embs_h))
232
- [
233
- update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
234
- for k in dict_h.keys()
235
- ]
236
-
237
-
238
- def tdigest_mean(embs_tdigests, emb_dims):
239
- return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
240
-
241
-
242
- def tdigest_median(embs_tdigests, emb_dims):
243
- return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
244
-
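
The tdigest helpers removed here stream per-dimension summary statistics so that mean or median embeddings can be approximated without holding every cell embedding in memory. A minimal sketch of the same pattern, assuming the `tdigest` package this code depends on; the shapes and batch loop are placeholders:

# Approximate a per-dimension median of streamed embeddings with t-digests,
# mirroring the accumulate_tdigests/tdigest_median helpers above.
import torch
from tdigest import TDigest

emb_dims = 4
embs_tdigests = [TDigest() for _ in range(emb_dims)]

for _ in range(10):                       # stand-in for forward-pass minibatches
    mean_embs = torch.randn(8, emb_dims)  # batch_size x emb_dims
    # tdigest batch updates are known to be slow, so update serially
    for i in range(mean_embs.size(0)):
        for j in range(emb_dims):
            embs_tdigests[j].update(mean_embs[i, j].item())

approx_median = [embs_tdigests[j].percentile(50) for j in range(emb_dims)]
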
245
 
246
- def label_cell_embs(embs, downsampled_data, emb_labels):
247
- embs_df = pd.DataFrame(embs.cpu().numpy())
248
  if emb_labels is not None:
249
  for label in emb_labels:
250
  emb_label = downsampled_data[label]
251
  embs_df[label] = emb_label
252
  return embs_df
253
 
254
-
255
- def label_gene_embs(embs, downsampled_data, token_gene_dict):
256
- gene_set = {
257
- element for sublist in downsampled_data["input_ids"] for element in sublist
258
- }
259
- gene_emb_dict = {k: [] for k in gene_set}
260
- for i in range(embs.size()[0]):
261
- length = downsampled_data[i]["length"]
262
- dict_i = dict(
263
- zip(
264
- downsampled_data[i]["input_ids"][0:length],
265
- embs[i, :, :].unsqueeze(dim=1),
266
- )
267
- )
268
- for k in dict_i.keys():
269
- gene_emb_dict[k].append(dict_i[k])
270
- for k in gene_emb_dict.keys():
271
- gene_emb_dict[k] = (
272
- torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
273
- .cpu()
274
- .numpy()
275
- )
276
- embs_df = pd.DataFrame(gene_emb_dict).T
277
- embs_df.index = [token_gene_dict[token] for token in embs_df.index]
278
- return embs_df
279
-
280
-
281
- def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
282
- only_embs_df = embs_df.iloc[:, :emb_dims]
283
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
284
- only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
285
- str
286
- )
287
  vars_dict = {"embs": only_embs_df.columns}
288
- obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
 
289
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
290
- sc.tl.pca(adata, svd_solver="arpack")
291
- sc.pp.neighbors(adata, random_state=seed)
292
- sc.tl.umap(adata, random_state=seed)
293
- sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
294
  sns.set_style("white")
295
- default_kwargs_dict = {"size": 200}
296
  if kwargs_dict is not None:
297
  default_kwargs_dict.update(kwargs_dict)
298
-
299
- cats = set(embs_df[label])
300
-
301
- with plt.rc_context():
302
- ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
303
- ax.legend(
304
- markerscale=2,
305
- frameon=False,
306
- loc="center left",
307
- bbox_to_anchor=(1, 0.5),
308
- ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
309
- )
310
- plt.show()
311
- plt.savefig(output_file, bbox_inches="tight")
312
-
313
 
314
  def gen_heatmap_class_colors(labels, df):
315
- pal = sns.cubehelix_palette(
316
- len(Counter(labels).keys()),
317
- light=0.9,
318
- dark=0.1,
319
- hue=1,
320
- reverse=True,
321
- start=1,
322
- rot=-2,
323
- )
324
  lut = dict(zip(map(str, Counter(labels).keys()), pal))
325
  colors = pd.Series(labels, index=df.index).map(lut)
326
  return colors
327
-
328
-
329
  def gen_heatmap_class_dict(classes, label_colors_series):
330
- class_color_dict_df = pd.DataFrame(
331
- {"classes": classes, "color": label_colors_series}
332
- )
333
  class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
334
- return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
335
-
336
-
337
  def make_colorbar(embs_df, label):
338
- labels = list(embs_df[label])
339
 
 
 
340
  cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
341
  label_colors = pd.DataFrame(cell_type_colors, columns=[label])
342
 
 
 
 
 
 
 
 
343
  # create dictionary for colors and classes
344
  label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
345
  return label_colors, label_color_dict
346
-
347
-
348
  def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
349
  sns.set_style("white")
350
  sns.set(font_scale=2)
351
  plt.figure(figsize=(15, 15), dpi=150)
352
  label_colors, label_color_dict = make_colorbar(embs_df, label)
353
-
354
- default_kwargs_dict = {
355
- "row_cluster": True,
356
- "col_cluster": True,
357
- "row_colors": label_colors,
358
- "standard_scale": 1,
359
- "linewidths": 0,
360
- "xticklabels": False,
361
- "yticklabels": False,
362
- "figsize": (15, 15),
363
- "center": 0,
364
- "cmap": "magma",
365
- }
366
-
367
  if kwargs_dict is not None:
368
  default_kwargs_dict.update(kwargs_dict)
369
- g = sns.clustermap(
370
- embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
371
- )
372
 
373
  plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
374
 
375
  for label_color in list(label_color_dict.keys()):
376
- g.ax_col_dendrogram.bar(
377
- 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
378
- )
379
 
380
- g.ax_col_dendrogram.legend(
381
- title=f"{label}",
382
- loc="lower center",
383
- ncol=4,
384
- bbox_to_anchor=(0.5, 1),
385
- facecolor="white",
386
- )
387
- plt.show()
388
- logger.info(f"Output file: {output_file}")
389
- plt.savefig(output_file, bbox_inches="tight")
390
 
 
391
 
392
  class EmbExtractor:
393
  valid_option_dict = {
394
- "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
395
  "num_classes": {int},
396
- "emb_mode": {"cls", "cell", "gene"},
397
  "cell_emb_style": {"mean_pool"},
398
- "gene_emb_style": {"mean_pool"},
399
  "filter_data": {None, dict},
400
  "max_ncells": {None, int},
401
  "emb_layer": {-1, 0},
402
  "emb_label": {None, list},
403
  "labels_to_plot": {None, list},
404
  "forward_batch_size": {int},
405
- "token_dictionary_file": {None, str},
406
  "nproc": {int},
407
- "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
408
  }
409
-
410
  def __init__(
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
- emb_mode="cls",
415
  cell_emb_style="mean_pool",
416
- gene_emb_style="mean_pool",
417
  filter_data=None,
418
  max_ncells=1000,
419
  emb_layer=-1,
@@ -422,442 +256,238 @@ class EmbExtractor:
422
  forward_batch_size=100,
423
  nproc=4,
424
  summary_stat=None,
425
- token_dictionary_file=None,
426
  ):
427
  """
428
  Initialize embedding extractor.
429
 
430
- **Parameters:**
431
-
432
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
433
- | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
434
  num_classes : int
435
- | If model is a gene or cell classifier, specify number of classes it was trained to classify.
436
- | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
437
- emb_mode : {"cls", "cell", "gene"}
438
- | Whether to output CLS, cell, or gene embeddings.
439
- | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
440
- cell_emb_style : {"mean_pool"}
441
- | Method for summarizing cell embeddings if not using CLS token.
442
- | Currently only option is mean pooling of gene embeddings for given cell.
443
- gene_emb_style : "mean_pool"
444
- | Method for summarizing gene embeddings.
445
- | Currently only option is mean pooling of contextual gene embeddings for given gene.
446
  filter_data : None, dict
447
- | Default is to extract embeddings from all input data.
448
- | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
449
  max_ncells : None, int
450
- | Maximum number of cells to extract embeddings from.
451
- | Default is 1000 cells randomly sampled from input data.
452
- | If None, will extract embeddings from all cells.
453
  emb_layer : {-1, 0}
454
- | Embedding layer to extract.
455
- | The last layer is most specifically weighted to optimize the given learning objective.
456
- | Generally, it is best to extract the 2nd to last layer to get a more general representation.
457
- | -1: 2nd to last layer
458
- | 0: last layer
459
  emb_label : None, list
460
- | List of column name(s) in .dataset to add as labels to embedding output.
461
  labels_to_plot : None, list
462
- | Cell labels to plot.
463
- | Shown as color bar in heatmap.
464
- | Shown as cell color in umap.
465
- | Plotting umap requires labels to plot.
466
  forward_batch_size : int
467
- | Batch size for forward pass.
468
  nproc : int
469
- | Number of CPU processes to use.
470
- summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
471
- | If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
472
- | If mean or median, outputs only approximated mean or median embedding of input data.
473
- | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
474
- | Non-exact is slower but more memory-efficient.
475
  token_dictionary_file : Path
476
- | Default is the Geneformer token dictionary
477
- | Path to pickle file containing token dictionary (Ensembl ID:token).
478
-
479
- **Examples:**
480
-
481
- .. code-block :: python
482
-
483
- >>> from geneformer import EmbExtractor
484
- >>> embex = EmbExtractor(model_type="CellClassifier",
485
- ... num_classes=3,
486
- ... emb_mode="cell",
487
- ... filter_data={"cell_type":["cardiomyocyte"]},
488
- ... max_ncells=1000,
489
- ... emb_layer=-1,
490
- ... emb_label=["disease", "cell_type"],
491
- ... labels_to_plot=["disease", "cell_type"])
492
-
493
  """
494
 
495
  self.model_type = model_type
496
  self.num_classes = num_classes
497
  self.emb_mode = emb_mode
498
  self.cell_emb_style = cell_emb_style
499
- self.gene_emb_style = gene_emb_style
500
  self.filter_data = filter_data
501
  self.max_ncells = max_ncells
502
  self.emb_layer = emb_layer
503
  self.emb_label = emb_label
504
  self.labels_to_plot = labels_to_plot
505
- self.token_dictionary_file = token_dictionary_file
506
  self.forward_batch_size = forward_batch_size
507
  self.nproc = nproc
508
- if (summary_stat is not None) and ("exact" in summary_stat):
509
- self.summary_stat = None
510
- self.exact_summary_stat = summary_stat
511
- else:
512
- self.summary_stat = summary_stat
513
- self.exact_summary_stat = None
514
 
515
  self.validate_options()
516
 
517
  # load token dictionary (Ensembl IDs:token)
518
- if self.token_dictionary_file is None:
519
- token_dictionary_file = TOKEN_DICTIONARY_FILE
520
  with open(token_dictionary_file, "rb") as f:
521
  self.gene_token_dict = pickle.load(f)
522
 
523
- self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
524
  self.pad_token_id = self.gene_token_dict.get("<pad>")
525
-
 
526
  def validate_options(self):
 
 
 
 
 
 
 
 
527
  # confirm arguments are within valid options and compatible with each other
528
- for attr_name, valid_options in self.valid_option_dict.items():
529
  attr_value = self.__dict__[attr_name]
530
- if not isinstance(attr_value, (list, dict)):
531
  if attr_value in valid_options:
532
  continue
533
  valid_type = False
534
  for option in valid_options:
535
- if (option in [int, list, dict, bool, str]) and isinstance(
536
- attr_value, option
537
- ):
538
  valid_type = True
539
  break
540
  if valid_type:
541
  continue
542
  logger.error(
543
- f"Invalid option for {attr_name}. "
544
  f"Valid options for {attr_name}: {valid_options}"
545
  )
546
  raise
547
-
548
  if self.filter_data is not None:
549
- for key, value in self.filter_data.items():
550
- if not isinstance(value, list):
551
  self.filter_data[key] = [value]
552
  logger.warning(
553
- "Values in filter_data dict must be lists. "
554
- f"Changing {key} value to list ([{value}])."
555
- )
556
-
557
- def extract_embs(
558
- self,
559
- model_directory,
560
- input_data_file,
561
- output_directory,
562
- output_prefix,
563
- output_torch_embs=False,
564
- cell_state=None,
565
- ):
566
  """
567
  Extract embeddings from input data and save as results in output_directory.
568
 
569
- **Parameters:**
570
-
571
  model_directory : Path
572
- | Path to directory containing model
573
  input_data_file : Path
574
- | Path to directory containing .dataset inputs
575
  output_directory : Path
576
- | Path to directory where embedding data will be saved as csv
577
  output_prefix : str
578
- | Prefix for output file
579
- output_torch_embs : bool
580
- | Whether or not to also output the embeddings as a tensor.
581
- | Note, if true, will output embeddings as both dataframe and tensor.
582
- cell_state : dict
583
- | Cell state key and value for state embedding extraction.
584
-
585
- **Examples:**
586
-
587
- .. code-block :: python
588
-
589
- >>> embs = embex.extract_embs("path/to/model",
590
- ... "path/to/input_data",
591
- ... "path/to/output_directory",
592
- ... "output_prefix")
593
-
594
  """
595
 
596
- filtered_input_data = pu.load_and_filter(
597
- self.filter_data, self.nproc, input_data_file
598
- )
599
-
600
- # Check to make sure that all the labels exist in the tokenized data:
601
- if self.emb_label is not None:
602
- for label in self.emb_label:
603
- assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
604
-
605
- if cell_state is not None:
606
- filtered_input_data = pu.filter_by_dict(
607
- filtered_input_data, cell_state, self.nproc
608
- )
609
- downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
610
- model = pu.load_model(
611
- self.model_type, self.num_classes, model_directory, mode="eval"
612
- )
613
- layer_to_quant = pu.quant_layers(model) + self.emb_layer
614
- embs = get_embs(
615
- model=model,
616
- filtered_input_data=downsampled_data,
617
- emb_mode=self.emb_mode,
618
- layer_to_quant=layer_to_quant,
619
- pad_token_id=self.pad_token_id,
620
- forward_batch_size=self.forward_batch_size,
621
- token_gene_dict=self.token_gene_dict,
622
- summary_stat=self.summary_stat,
623
- )
624
-
625
- if self.emb_mode == "cell":
626
- if self.summary_stat is None:
627
- embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
628
- elif self.summary_stat is not None:
629
- embs_df = pd.DataFrame(embs.cpu().numpy()).T
630
- elif self.emb_mode == "gene":
631
- if self.summary_stat is None:
632
- embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
633
- elif self.summary_stat is not None:
634
- embs_df = pd.DataFrame(embs).T
635
- embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
636
- elif self.emb_mode == "cls":
637
- embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
638
 
639
  # save embeddings to output_path
640
- if cell_state is None:
641
- output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
642
- embs_df.to_csv(output_path)
643
-
644
- if self.exact_summary_stat == "exact_mean":
645
- embs = embs.mean(dim=0)
646
- emb_dims = pu.get_model_emb_dims(model)
647
- embs_df = pd.DataFrame(
648
- embs_df[0 : emb_dims - 1].mean(axis="rows"),
649
- columns=[self.exact_summary_stat],
650
- ).T
651
- elif self.exact_summary_stat == "exact_median":
652
- embs = torch.median(embs, dim=0)[0]
653
- emb_dims = pu.get_model_emb_dims(model)
654
- embs_df = pd.DataFrame(
655
- embs_df[0 : emb_dims - 1].median(axis="rows"),
656
- columns=[self.exact_summary_stat],
657
- ).T
658
-
659
- if cell_state is not None:
660
- return embs
661
- else:
662
- if output_torch_embs:
663
- return embs_df, embs
664
- else:
665
- return embs_df
666
-
667
- def get_state_embs(
668
- self,
669
- cell_states_to_model,
670
- model_directory,
671
- input_data_file,
672
- output_directory,
673
- output_prefix,
674
- output_torch_embs=True,
675
- ):
676
- """
677
- Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
678
-
679
- **Parameters:**
680
-
681
- cell_states_to_model : None, dict
682
- | Cell states to model if testing perturbations that achieve goal state change.
683
- | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
684
- | state_key: key specifying name of column in .dataset that defines the start/goal states
685
- | start_state: value in the state_key column that specifies the start state
686
- | goal_state: value in the state_key column that specifies the goal end state
687
- | alt_states: list of values in the state_key column that specify the alternate end states
688
- | For example:
689
- | {"state_key": "disease",
690
- | "start_state": "dcm",
691
- | "goal_state": "nf",
692
- | "alt_states": ["hcm", "other1", "other2"]}
693
- model_directory : Path
694
- | Path to directory containing model
695
- input_data_file : Path
696
- | Path to directory containing .dataset inputs
697
- output_directory : Path
698
- | Path to directory where embedding data will be saved as csv
699
- output_prefix : str
700
- | Prefix for output file
701
- output_torch_embs : bool
702
- | Whether or not to also output the embeddings as a tensor.
703
- | Note, if true, will output embeddings as both dataframe and tensor.
704
-
705
- **Outputs**
706
-
707
- | Outputs state_embs_dict for use with in silico perturber.
708
- | Format is dictionary of embedding positions of each cell state to model shifts from/towards.
709
- | Keys specify each possible cell state to model.
710
- | Values are target embedding positions as torch.tensor.
711
- | For example:
712
- | {"nf": emb_nf,
713
- | "hcm": emb_hcm,
714
- | "dcm": emb_dcm,
715
- | "other1": emb_other1,
716
- | "other2": emb_other2}
717
- """
718
-
719
- pu.validate_cell_states_to_model(cell_states_to_model)
720
- valid_summary_stats = ["exact_mean", "exact_median"]
721
- if self.exact_summary_stat not in valid_summary_stats:
722
- logger.error(
723
- "For extracting state embs, summary_stat in EmbExtractor "
724
- f"must be set to option in {valid_summary_stats}"
725
- )
726
- raise
727
-
728
- if self.emb_label is not None:
729
- logger.error(
730
- "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
731
- )
732
- raise
733
 
734
- state_embs_dict = dict()
735
- state_key = cell_states_to_model["state_key"]
736
- for k, v in cell_states_to_model.items():
737
- if k == "state_key":
738
- continue
739
- elif (k == "start_state") or (k == "goal_state"):
740
- state_embs_dict[v] = self.extract_embs(
741
- model_directory,
742
- input_data_file,
743
- output_directory,
744
- output_prefix,
745
- output_torch_embs,
746
- cell_state={state_key: v},
747
- )
748
- else: # k == "alt_states"
749
- for alt_state in v:
750
- state_embs_dict[alt_state] = self.extract_embs(
751
- model_directory,
752
- input_data_file,
753
- output_directory,
754
- output_prefix,
755
- output_torch_embs,
756
- cell_state={state_key: alt_state},
757
- )
758
-
759
- output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
760
- with open(output_path, "wb") as fp:
761
- pickle.dump(state_embs_dict, fp)
762
-
763
- return state_embs_dict
764
-
-     def plot_embs(
-         self,
-         embs,
-         plot_style,
-         output_directory,
-         output_prefix,
-         max_ncells_to_plot=1000,
-         kwargs_dict=None,
-     ):
          """
          Plot embeddings, coloring by provided labels.

-         **Parameters:**
-
          embs : pandas.core.frame.DataFrame
-             | Pandas dataframe containing embeddings output from extract_embs
          plot_style : str
-             | Style of plot: "heatmap" or "umap"
          output_directory : Path
-             | Path to directory where plots will be saved as pdf
          output_prefix : str
-             | Prefix for output file
          max_ncells_to_plot : None, int
-             | Maximum number of cells to plot.
-             | Default is 1000 cells randomly sampled from embeddings.
-             | If None, will plot embeddings from all cells.
          kwargs_dict : dict
-             | Dictionary of kwargs to pass to plotting function.
-
-         **Examples:**
-
-         .. code-block :: python
-
-             >>> embex.plot_embs(embs=embs,
-             ...                 plot_style="heatmap",
-             ...                 output_directory="path/to/output_directory",
-             ...                 output_prefix="output_prefix")
-
          """
-
-         if plot_style not in ["heatmap", "umap"]:
              logger.error(
-                 "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
              )
              raise
-
          if (plot_style == "umap") and (self.labels_to_plot is None):
-             logger.error("Plotting UMAP requires 'labels_to_plot'. ")
              raise
-
-         if max_ncells_to_plot is not None:
-             if max_ncells_to_plot > self.max_ncells:
-                 max_ncells_to_plot = self.max_ncells
-                 logger.warning(
-                     "max_ncells_to_plot must be <= max_ncells. "
-                     f"Changing max_ncells_to_plot to {self.max_ncells}."
-                 )
-             elif max_ncells_to_plot < self.max_ncells:
-                 embs = embs.sample(max_ncells_to_plot, axis=0)
-
          if self.emb_label is None:
              label_len = 0
          else:
              label_len = len(self.emb_label)
-
          emb_dims = embs.shape[1] - label_len
-
          if self.emb_label is None:
              emb_labels = None
          else:
              emb_labels = embs.columns[emb_dims:]
-
          if plot_style == "umap":
              for label in self.labels_to_plot:
                  if label not in emb_labels:
                      logger.warning(
-                         f"Label {label} from labels_to_plot "
-                         f"not present in provided embeddings dataframe."
-                     )
                      continue
-                 output_prefix_label = output_prefix + f"_umap_{label}"
-                 output_file = (
-                     Path(output_directory) / output_prefix_label
-                 ).with_suffix(".pdf")
-                 plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
-
          if plot_style == "heatmap":
              for label in self.labels_to_plot:
                  if label not in emb_labels:
                      logger.warning(
-                         f"Label {label} from labels_to_plot "
-                         f"not present in provided embeddings dataframe."
-                     )
                      continue
                  output_prefix_label = output_prefix + f"_heatmap_{label}"
-                 output_file = (
-                     Path(output_directory) / output_prefix_label
-                 ).with_suffix(".pdf")
-                 plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
 
  """
  Geneformer embedding extractor.

+ Usage:
+   from geneformer import EmbExtractor
+   embex = EmbExtractor(model_type="CellClassifier",
+                        num_classes=3,
+                        emb_mode="cell",
+                        cell_emb_style="mean_pool",
+                        filter_data={"cell_type":["cardiomyocyte"]},
+                        max_ncells=1000,
+                        max_ncells_to_plot=1000,
+                        emb_layer=-1,
+                        emb_label=["disease","cell_type"],
+                        labels_to_plot=["disease","cell_type"],
+                        forward_batch_size=100,
+                        nproc=16,
+                        summary_stat=None)
+   embs = embex.extract_embs("path/to/model",
+                             "path/to/input_data",
+                             "path/to/output_directory",
+                             "output_prefix")
+   embex.plot_embs(embs=embs,
+                   plot_style="heatmap",
+                   output_directory="path/to/output_directory",
+                   output_prefix="output_prefix")
+
  """

  # imports
  import logging

  import anndata
  import matplotlib.pyplot as plt
+ import numpy as np
  import pandas as pd
+ import pickle
+ from tdigest import TDigest
  import scanpy as sc
  import seaborn as sns
  import torch
+ from collections import Counter
+ from pathlib import Path
+ from tqdm.notebook import trange
+ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification

+ from .tokenizer import TOKEN_DICTIONARY_FILE

+ from .in_silico_perturber import downsample_and_sort, \
+     gen_attention_mask, \
+     get_model_input_size, \
+     load_and_filter, \
+     load_model, \
+     mean_nonpadding_embs, \
+     pad_tensor_list, \
+     quant_layers

+ logger = logging.getLogger(__name__)

  # extract embeddings
+ def get_embs(model,
+              filtered_input_data,
+              emb_mode,
+              layer_to_quant,
+              pad_token_id,
+              forward_batch_size,
+              summary_stat):
+
+     model_input_size = get_model_input_size(model)
      total_batch_length = len(filtered_input_data)
+
      if summary_stat is None:
          embs_list = []
      elif summary_stat is not None:
+         # test embedding extraction for example cell and extract # emb dims
+         example = filtered_input_data.select([i for i in range(1)])
+         example.set_format(type="torch")
+         emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
+         # initiate tdigests for # of emb dims
+         embs_tdigests = [TDigest() for _ in range(emb_dims)]

+     for i in trange(0, total_batch_length, forward_batch_size):
+         max_range = min(i+forward_batch_size, total_batch_length)
          minibatch = filtered_input_data.select([i for i in range(i, max_range)])
+         max_len = max(minibatch["length"])
+         original_lens = torch.tensor(minibatch["length"]).to("cuda")
          minibatch.set_format(type="torch")

          input_data_minibatch = minibatch["input_ids"]
+         input_data_minibatch = pad_tensor_list(input_data_minibatch,
+                                                max_len,
+                                                pad_token_id,
+                                                model_input_size)
+
          with torch.no_grad():
              outputs = model(
+                 input_ids = input_data_minibatch.to("cuda"),
+                 attention_mask = gen_attention_mask(minibatch)
              )

          embs_i = outputs.hidden_states[layer_to_quant]
+
          if emb_mode == "cell":
+             mean_embs = mean_nonpadding_embs(embs_i, original_lens)
              if summary_stat is None:
+                 embs_list += [mean_embs]
              elif summary_stat is not None:
                  # update tdigests with current batch for each emb dim
+                 # note: tdigest batch update known to be slow so updating serially
+                 # (the comprehension's i ranges over rows of mean_embs, shadowing
+                 # the batch loop's i only within the comprehension's own scope)
+                 [embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
+
          del outputs
          del minibatch
          del input_data_minibatch
          del embs_i
+         del mean_embs
+         torch.cuda.empty_cache()
+
      if summary_stat is None:
+         embs_stack = torch.cat(embs_list)
      # calculate summary stat embs from approximated tdigests
      elif summary_stat is not None:
+         if summary_stat == "mean":
+             summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
+         elif summary_stat == "median":
+             summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
+         embs_stack = torch.tensor(summary_emb_list)

      return embs_stack

+ def test_emb(model, example, layer_to_quant):
+     with torch.no_grad():
+         outputs = model(
+             input_ids = example.to("cuda")
+         )
+
+     embs_test = outputs.hidden_states[layer_to_quant]
+     return embs_test.size()[2]
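The tdigest-based path above is the memory-efficient alternative to stacking all embeddings in embs_list. As a minimal standalone sketch of the same idea (the values are made up): one t-digest per embedding dimension absorbs values serially and is then queried for an approximate summary statistic, exactly as done with trimmed_mean(0, 100) and percentile(50) above.

.. code-block:: python

    from tdigest import TDigest

    digest = TDigest()  # stands in for one entry of embs_tdigests
    for x in [0.1, 0.4, -0.2, 0.9, 0.3]:  # one embedding dimension's values, streamed
        digest.update(x)

    approx_median = digest.percentile(50)      # summary_stat="median"
    approx_mean = digest.trimmed_mean(0, 100)  # summary_stat="mean" (untrimmed)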
+ def label_embs(embs, downsampled_data, emb_labels):
+     embs_df = pd.DataFrame(embs.cpu())
      if emb_labels is not None:
          for label in emb_labels:
              emb_label = downsampled_data[label]
              embs_df[label] = emb_label
      return embs_df
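mean_nonpadding_embs is imported from in_silico_perturber rather than defined in this file; as a sketch of what the "mean_pool" cell embedding style amounts to (an illustrative reimplementation, not the imported function), each cell's gene embeddings are averaged over only its true, non-padded positions:

.. code-block:: python

    import torch

    def mean_nonpadding_embs_sketch(embs, original_lens):
        # embs: (batch, max_len, emb_dim); original_lens: (batch,) true sequence lengths
        mask = (torch.arange(embs.size(1), device=embs.device)[None, :]
                < original_lens[:, None])                # True at real gene positions
        summed = (embs * mask[:, :, None]).sum(dim=1)    # zero out padding, sum over genes
        return summed / original_lens[:, None].float()   # divide by each cell's true length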
+ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
+     only_embs_df = embs_df.iloc[:,:emb_dims]
      only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
+     only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(str)
      vars_dict = {"embs": only_embs_df.columns}
+     obs_dict = {"cell_id": list(only_embs_df.index),
+                 f"{label}": list(embs_df[label])}
      adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
+     sc.tl.pca(adata, svd_solver='arpack')
+     sc.pp.neighbors(adata)
+     sc.tl.umap(adata)
+     sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
      sns.set_style("white")
+     default_kwargs_dict = {"palette":"Set2", "size":200}
      if kwargs_dict is not None:
          default_kwargs_dict.update(kwargs_dict)
+
+     # note: scanpy appends the string passed to `save` to its default figure
+     # filename, so a filename suffix (rather than a full path) is expected here
+     sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)

  def gen_heatmap_class_colors(labels, df):
+     pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
      lut = dict(zip(map(str, Counter(labels).keys()), pal))
      colors = pd.Series(labels, index=df.index).map(lut)
      return colors
+
  def gen_heatmap_class_dict(classes, label_colors_series):
+     class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series})
      class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
+     return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
+
  def make_colorbar(embs_df, label):
+     labels = list(embs_df[label])
+
      cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
      label_colors = pd.DataFrame(cell_type_colors, columns=[label])

+     # sanity check: print any malformed (non-RGB or NaN) color entries
+     for i, row in label_colors.iterrows():
+         colors = row[0]
+         if len(colors) != 3 or any(np.isnan(colors)):
+             print(i, colors)
+
      # create dictionary for colors and classes
      label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
      return label_colors, label_color_dict
+
  def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
      sns.set_style("white")
      sns.set(font_scale=2)
      plt.figure(figsize=(15, 15), dpi=150)
      label_colors, label_color_dict = make_colorbar(embs_df, label)
+
+     default_kwargs_dict = {"row_cluster": True,
+                            "col_cluster": True,
+                            "row_colors": label_colors,
+                            "standard_scale": 1,
+                            "linewidths": 0,
+                            "xticklabels": False,
+                            "yticklabels": False,
+                            "figsize": (15,15),
+                            "center": 0,
+                            "cmap": "magma"}
+
      if kwargs_dict is not None:
          default_kwargs_dict.update(kwargs_dict)
+     g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)

      plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

      for label_color in list(label_color_dict.keys()):
+         g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0)

+     g.ax_col_dendrogram.legend(title=f"{label}",
+                                loc="lower center",
+                                ncol=4,
+                                bbox_to_anchor=(0.5, 1),
+                                facecolor="white")

+     plt.savefig(output_file, bbox_inches='tight')
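Because kwargs_dict is merged over default_kwargs_dict, any seaborn clustermap option can be overridden per call. A hypothetical invocation (the emb_dims value and paths are placeholders; 256 matches the hidden size of the 6-layer Geneformer models):

.. code-block:: python

    plot_heatmap(embs_df,
                 emb_dims=256,
                 label="cell_type",
                 output_file="path/to/output_directory/output_prefix_heatmap_cell_type.pdf",
                 kwargs_dict={"cmap": "vlag", "standard_scale": 0})  # overrides the "magma" default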
  class EmbExtractor:
      valid_option_dict = {
+         "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
          "num_classes": {int},
+         "emb_mode": {"cell","gene"},
          "cell_emb_style": {"mean_pool"},
          "filter_data": {None, dict},
          "max_ncells": {None, int},
          "emb_layer": {-1, 0},
          "emb_label": {None, list},
          "labels_to_plot": {None, list},
          "forward_batch_size": {int},
          "nproc": {int},
+         "summary_stat": {None, "mean", "median"},
      }

      def __init__(
          self,
          model_type="Pretrained",
          num_classes=0,
+         emb_mode="cell",
          cell_emb_style="mean_pool",
          filter_data=None,
          max_ncells=1000,
          emb_layer=-1,
          emb_label=None,
          labels_to_plot=None,
          forward_batch_size=100,
          nproc=4,
          summary_stat=None,
+         token_dictionary_file=TOKEN_DICTIONARY_FILE,
      ):
          """
          Initialize embedding extractor.

+         Parameters
+         ----------
+         model_type : {"Pretrained","GeneClassifier","CellClassifier"}
+             Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
          num_classes : int
+             If model is a gene or cell classifier, specify number of classes it was trained to classify.
+             For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
+         emb_mode : {"cell","gene"}
+             Whether to output cell or gene embeddings.
+         cell_emb_style : "mean_pool"
+             Method for summarizing cell embeddings.
+             Currently the only option is mean pooling of gene embeddings for a given cell.
          filter_data : None, dict
+             Default is to extract embeddings from all input data.
+             Otherwise, a dictionary specifying a .dataset column name and a list of values to filter by.
          max_ncells : None, int
+             Maximum number of cells to extract embeddings from.
+             Default is 1000 cells randomly sampled from input data.
+             If None, will extract embeddings from all cells.
          emb_layer : {-1, 0}
+             Embedding layer to extract.
+             The last layer is most specifically weighted to optimize the given learning objective.
+             Generally, it is best to extract the 2nd to last layer to get a more general representation.
+             -1: 2nd to last layer
+             0: last layer
          emb_label : None, list
+             List of column name(s) in .dataset to add as labels to embedding output.
          labels_to_plot : None, list
+             Cell labels to plot.
+             Shown as color bar in heatmap.
+             Shown as cell color in umap.
+             Plotting umap requires labels to plot.
          forward_batch_size : int
+             Batch size for forward pass.
          nproc : int
+             Number of CPU processes to use.
+         summary_stat : {None, "mean", "median"}
+             If not None, outputs only the approximated mean or median embedding of the input data.
+             Recommended if encountering memory constraints while generating goal embedding positions.
+             Slower but more memory-efficient.
          token_dictionary_file : Path
+             Path to pickle file containing token dictionary (Ensembl ID:token).
          """

          self.model_type = model_type
          self.num_classes = num_classes
          self.emb_mode = emb_mode
          self.cell_emb_style = cell_emb_style
          self.filter_data = filter_data
          self.max_ncells = max_ncells
          self.emb_layer = emb_layer
          self.emb_label = emb_label
          self.labels_to_plot = labels_to_plot
          self.forward_batch_size = forward_batch_size
          self.nproc = nproc
+         self.summary_stat = summary_stat

          self.validate_options()

          # load token dictionary (Ensembl IDs:token)
          with open(token_dictionary_file, "rb") as f:
              self.gene_token_dict = pickle.load(f)

          self.pad_token_id = self.gene_token_dict.get("<pad>")
+
      def validate_options(self):
+         # first disallow options under development
+         if self.emb_mode == "gene":
+             logger.error(
+                 "Extraction and plotting of gene-level embeddings currently under development. "
+                 "Current valid option for 'emb_mode': 'cell'"
+             )
+             raise
+
          # confirm arguments are within valid options and compatible with each other
+         for attr_name, valid_options in self.valid_option_dict.items():
              attr_value = self.__dict__[attr_name]
+             if type(attr_value) not in {list, dict}:
                  if attr_value in valid_options:
                      continue
              valid_type = False
              for option in valid_options:
+                 if (option in [int, list, dict]) and isinstance(attr_value, option):
                      valid_type = True
                      break
              if valid_type:
                  continue
              logger.error(
+                 f"Invalid option for {attr_name}. "
                  f"Valid options for {attr_name}: {valid_options}"
              )
              raise
+
          if self.filter_data is not None:
+             for key, value in self.filter_data.items():
+                 if type(value) != list:
                      self.filter_data[key] = [value]
                      logger.warning(
+                         "Values in filter_data dict must be lists. "
+                         f"Changing {key} value to list ([{value}]).")
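To make the coercion above concrete: a bare (non-list) value passed in filter_data is wrapped into a single-item list during construction, since validate_options() runs inside __init__. A hypothetical example:

.. code-block:: python

    embex = EmbExtractor(filter_data={"cell_type": "cardiomyocyte"})  # value is not a list
    # validate_options() logs a warning and rewrites the entry:
    assert embex.filter_data == {"cell_type": ["cardiomyocyte"]}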
+     def extract_embs(self,
+                      model_directory,
+                      input_data_file,
+                      output_directory,
+                      output_prefix):
          """
          Extract embeddings from input data and save results in output_directory.

+         Parameters
+         ----------
          model_directory : Path
+             Path to directory containing model
          input_data_file : Path
+             Path to directory containing .dataset inputs
          output_directory : Path
+             Path to directory where embedding data will be saved as csv
          output_prefix : str
+             Prefix for output file
          """

+         filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
+         downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
+         model = load_model(self.model_type, self.num_classes, model_directory)
+         # hidden_states holds one entry per encoder layer plus the input embeddings,
+         # so quant_layers(model) + emb_layer indexes the last (0) or 2nd-to-last (-1) layer
+         layer_to_quant = quant_layers(model)+self.emb_layer
+         embs = get_embs(model,
+                         downsampled_data,
+                         self.emb_mode,
+                         layer_to_quant,
+                         self.pad_token_id,
+                         self.forward_batch_size,
+                         self.summary_stat)
+
+         if self.summary_stat is None:
+             embs_df = label_embs(embs, downsampled_data, self.emb_label)
+         elif self.summary_stat is not None:
+             embs_df = pd.DataFrame(embs.cpu()).T

          # save embeddings to output_path
+         output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
+         embs_df.to_csv(output_path)
+
+         return embs_df
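End to end, the returned dataframe has one row per cell, emb_dims numeric columns, and any emb_label columns appended on the right. A usage sketch with placeholder paths, mirroring the module docstring:

.. code-block:: python

    embex = EmbExtractor(model_type="Pretrained",
                         num_classes=0,
                         max_ncells=1000,
                         emb_label=["disease"],
                         nproc=8)
    embs_df = embex.extract_embs("path/to/model",
                                 "path/to/input_data",
                                 "path/to/output_directory",
                                 "output_prefix")
    emb_dims = embs_df.shape[1] - 1  # the final column holds the "disease" label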
+     def plot_embs(self,
+                   embs,
+                   plot_style,
+                   output_directory,
+                   output_prefix,
+                   max_ncells_to_plot=1000,
+                   kwargs_dict=None):
          """
          Plot embeddings, coloring by provided labels.

+         Parameters
+         ----------
          embs : pandas.core.frame.DataFrame
+             Pandas dataframe containing embeddings output from extract_embs
          plot_style : str
+             Style of plot: "heatmap" or "umap"
          output_directory : Path
+             Path to directory where plots will be saved as pdf
          output_prefix : str
+             Prefix for output file
          max_ncells_to_plot : None, int
+             Maximum number of cells to plot.
+             Default is 1000 cells randomly sampled from embeddings.
+             If None, will plot embeddings from all cells.
          kwargs_dict : dict
+             Dictionary of kwargs to pass to plotting function.
          """
+
+         if plot_style not in ["heatmap", "umap"]:
              logger.error(
+                 "Invalid option for 'plot_style'. "
+                 "Valid options: {'heatmap','umap'}"
              )
              raise
+
          if (plot_style == "umap") and (self.labels_to_plot is None):
+             logger.error(
+                 "Plotting UMAP requires 'labels_to_plot'. "
+             )
              raise
+
+         # max_ncells_to_plot may be None, so guard before comparing
+         if (max_ncells_to_plot is not None) \
+             and (max_ncells_to_plot > self.max_ncells):
+             max_ncells_to_plot = self.max_ncells
+             logger.warning(
+                 "max_ncells_to_plot must be <= max_ncells. "
+                 f"Changing max_ncells_to_plot to {self.max_ncells}.")
+
+         if (max_ncells_to_plot is not None) \
+             and (max_ncells_to_plot < self.max_ncells):
+             embs = embs.sample(max_ncells_to_plot, axis=0)
+
          if self.emb_label is None:
              label_len = 0
          else:
              label_len = len(self.emb_label)
+
          emb_dims = embs.shape[1] - label_len
+
          if self.emb_label is None:
              emb_labels = None
          else:
              emb_labels = embs.columns[emb_dims:]
+
          if plot_style == "umap":
              for label in self.labels_to_plot:
                  if label not in emb_labels:
                      logger.warning(
+                         f"Label {label} from labels_to_plot "
+                         f"not present in provided embeddings dataframe.")
                      continue
+                 output_prefix_label = "_" + output_prefix + f"_umap_{label}"
+                 output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
+                 # the prefix (not output_file) is passed through to scanpy's `save`
+                 # argument, which expects a filename suffix rather than a full path
+                 plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
+
          if plot_style == "heatmap":
              for label in self.labels_to_plot:
                  if label not in emb_labels:
                      logger.warning(
+                         f"Label {label} from labels_to_plot "
+                         f"not present in provided embeddings dataframe.")
                      continue
                  output_prefix_label = output_prefix + f"_heatmap_{label}"
+                 output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
+                 plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
geneformer/ensembl_mapping_dict_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
- size 3957652
geneformer/evaluation_utils.py DELETED
@@ -1,287 +0,0 @@
- import logging
- import math
- import pickle
- from pathlib import Path
-
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- import seaborn as sns
- import torch
- from datasets.utils.logging import disable_progress_bar, enable_progress_bar
- from sklearn import preprocessing
- from sklearn.metrics import (
-     ConfusionMatrixDisplay,
-     accuracy_score,
-     auc,
-     confusion_matrix,
-     f1_score,
-     roc_curve,
- )
- from tqdm.auto import trange
-
- from . import TOKEN_DICTIONARY_FILE
- from .emb_extractor import make_colorbar
-
- logger = logging.getLogger(__name__)
-
-
- def preprocess_classifier_batch(cell_batch, max_len, label_name):
-     if max_len is None:
-         max_len = max([len(i) for i in cell_batch["input_ids"]])
-
-     # load token dictionary (Ensembl IDs:token)
-     with open(TOKEN_DICTIONARY_FILE, "rb") as f:
-         gene_token_dict = pickle.load(f)
-
-     def pad_label_example(example):
-         example[label_name] = np.pad(
-             example[label_name],
-             (0, max_len - len(example["input_ids"])),
-             mode="constant",
-             constant_values=-100,
-         )
-         example["input_ids"] = np.pad(
-             example["input_ids"],
-             (0, max_len - len(example["input_ids"])),
-             mode="constant",
-             constant_values=gene_token_dict.get("<pad>"),
-         )
-         example["attention_mask"] = (
-             example["input_ids"] != gene_token_dict.get("<pad>")
-         ).astype(int)
-         return example
-
-     padded_batch = cell_batch.map(pad_label_example)
-     return padded_batch
-
-
- # Function to find the largest number smaller
- # than or equal to N that is divisible by K
- def find_largest_div(N, K):
-     rem = N % K
-     if rem == 0:
-         return N
-     else:
-         return N - rem
-
-
- def vote(logit_list):
-     m = max(logit_list)
-     indices = [i for i, x in enumerate(logit_list) if x == m]
-     if len(indices) > 1:
-         return "tie"
-     else:
-         return indices[0]
-
-
- def py_softmax(vector):
-     e = np.exp(vector)
-     return e / e.sum()
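A quick worked example of the two helpers above, with made-up logits: vote returns the argmax index when it is unique and the string "tie" otherwise, while py_softmax normalizes a logit vector into probabilities.

.. code-block:: python

    import numpy as np

    vote([0.1, 2.3, 0.7])             # -> 1 (unique maximum at index 1)
    vote([2.3, 2.3, 0.7])             # -> "tie" (maximum occurs twice)
    py_softmax(np.array([1.0, 1.0]))  # -> array([0.5, 0.5])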
-
-
- def classifier_predict(model, classifier_type, evalset, forward_batch_size):
-     if classifier_type == "gene":
-         label_name = "labels"
-     elif classifier_type == "cell":
-         label_name = "label"
-
-     predict_logits = []
-     predict_labels = []
-     model.eval()
-
-     # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
-     evalset_len = len(evalset)
-     max_divisible = find_largest_div(evalset_len, forward_batch_size)
-     if len(evalset) - max_divisible == 1:
-         evalset_len = max_divisible
-
-     max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
-
-     disable_progress_bar()  # disable progress bar for preprocess_classifier_batch mapping
-     for i in trange(0, evalset_len, forward_batch_size):
-         max_range = min(i + forward_batch_size, evalset_len)
-         batch_evalset = evalset.select([i for i in range(i, max_range)])
-         padded_batch = preprocess_classifier_batch(
-             batch_evalset, max_evalset_len, label_name
-         )
-         padded_batch.set_format(type="torch")
-
-         input_data_batch = padded_batch["input_ids"]
-         attn_msk_batch = padded_batch["attention_mask"]
-         label_batch = padded_batch[label_name]
-         with torch.no_grad():
-             outputs = model(
-                 input_ids=input_data_batch.to("cuda"),
-                 attention_mask=attn_msk_batch.to("cuda"),
-                 labels=label_batch.to("cuda"),
-             )
-             predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
-             predict_labels += [torch.squeeze(label_batch.to("cpu"))]
-
-     enable_progress_bar()
-     logits_by_cell = torch.cat(predict_logits)
-     last_dim = len(logits_by_cell.shape) - 1
-     all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
-     labels_by_cell = torch.cat(predict_labels)
-     all_labels = torch.flatten(labels_by_cell)
-     logit_label_paired = [
-         item
-         for item in list(zip(all_logits.tolist(), all_labels.tolist()))
-         if item[1] != -100
-     ]
-     y_pred = [vote(item[0]) for item in logit_label_paired]
-     y_true = [item[1] for item in logit_label_paired]
-     logits_list = [item[0] for item in logit_label_paired]
-     return y_pred, y_true, logits_list
-
-
- def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
-     conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
-     macro_f1 = f1_score(y_true, y_pred, average="macro")
-     acc = accuracy_score(y_true, y_pred)
-     roc_metrics = None  # roc metrics not reported for multiclass
-     if num_classes == 2:
-         y_score = [py_softmax(item)[1] for item in logits_list]
-         fpr, tpr, _ = roc_curve(y_true, y_score)
-         mean_fpr = np.linspace(0, 1, 100)
-         interp_tpr = np.interp(mean_fpr, fpr, tpr)
-         interp_tpr[0] = 0.0
-         tpr_wt = len(tpr)
-         roc_auc = auc(fpr, tpr)
-         roc_metrics = {
-             "fpr": fpr,
-             "tpr": tpr,
-             "interp_tpr": interp_tpr,
-             "auc": roc_auc,
-             "tpr_wt": tpr_wt,
-         }
-     return conf_mat, macro_f1, acc, roc_metrics
-
-
- # get cross-validated mean and sd metrics
- def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
-     wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
-     all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
-     mean_tpr = np.sum(all_weighted_tpr, axis=0)
-     mean_tpr[-1] = 1.0
-     all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
-     roc_auc = np.sum(all_weighted_roc_auc)
-     roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
-     return mean_tpr, roc_auc, roc_auc_sd
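The cross-validation helper above weights each fold by its ROC point count (tpr_wt). With two made-up folds, the weighted AUC differs from the unweighted average:

.. code-block:: python

    import numpy as np

    all_roc_auc = np.array([0.80, 0.90])
    all_tpr_wt = [100, 300]                                  # second fold carries 3x the weight
    wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]  # [0.25, 0.75]
    roc_auc = np.sum([a * b for a, b in zip(all_roc_auc, wts)])
    # roc_auc == 0.875, versus 0.85 for the unweighted mean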
-
-
- # plot ROC curve
- def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
-     fig = plt.figure()
-     fig.set_size_inches(10, 8)
-     sns.set(font_scale=2)
-     sns.set_style("white")
-     lw = 3
-     for model_name in roc_metric_dict.keys():
-         mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
-         mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
-         roc_auc = roc_metric_dict[model_name]["roc_auc"]
-         roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
-         color = model_style_dict[model_name]["color"]
-         linestyle = model_style_dict[model_name]["linestyle"]
-         if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
-             label = rf"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
-         else:
-             label = f"{model_name} (AUC {roc_auc:0.2f})"
-         plt.plot(
-             mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
-         )
-
-     plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
-     plt.xlim([0.0, 1.0])
-     plt.ylim([0.0, 1.05])
-     plt.xlabel("False Positive Rate")
-     plt.ylabel("True Positive Rate")
-     plt.title(title)
-     plt.legend(loc="lower right")
-
-     output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
-     plt.savefig(output_file, bbox_inches="tight")
-     plt.show()
-
-
- # plot confusion matrix
- def plot_confusion_matrix(
-     conf_mat_df, title, output_dir, output_prefix, custom_class_order
- ):
-     fig = plt.figure()
-     fig.set_size_inches(10, 10)
-     sns.set(font_scale=1)
-     sns.set_style("whitegrid", {"axes.grid": False})
-     if custom_class_order is not None:
-         conf_mat_df = conf_mat_df.reindex(
-             index=custom_class_order, columns=custom_class_order
-         )
-     display_labels = generate_display_labels(conf_mat_df)
-     conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
-     display = ConfusionMatrixDisplay(
-         confusion_matrix=conf_mat, display_labels=display_labels
-     )
-     display.plot(cmap="Blues", values_format=".2g")
-     plt.title(title)
-     plt.show()
-
-     output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
-     display.figure_.savefig(output_file, bbox_inches="tight")
-
-
- def generate_display_labels(conf_mat_df):
-     display_labels = []
-     i = 0
-     for label in conf_mat_df.index:
-         display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
-         i = i + 1
-     return display_labels
-
-
- def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
-     sns.set(font_scale=2)
-     plt.figure(figsize=(10, 10), dpi=150)
-     label_colors, label_color_dict = make_colorbar(predictions_df, "true")
-     predictions_df = predictions_df.drop(columns=["true"])
-     predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
-     predict_label_list = [label for label in predictions_df.columns]
-     predict_colors = pd.DataFrame(
-         pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
-     )
-
-     default_kwargs_dict = {
-         "row_cluster": False,
-         "col_cluster": False,
-         "row_colors": label_colors,
-         "col_colors": predict_colors,
-         "linewidths": 0,
-         "xticklabels": False,
-         "yticklabels": False,
-         "center": 0,
-         "cmap": "vlag",
-     }
-
-     if kwargs_dict is not None:
-         default_kwargs_dict.update(kwargs_dict)
-     g = sns.clustermap(predictions_df, **default_kwargs_dict)
-
-     plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
-
-     for label_color in list(label_color_dict.keys()):
-         g.ax_col_dendrogram.bar(
-             0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
-         )
-
-     g.ax_col_dendrogram.legend(
-         title=f"{title}",
-         loc="lower center",
-         ncol=4,
-         bbox_to_anchor=(0.5, 1),
-         facecolor="white",
-     )
-
-     output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
-     plt.savefig(output_file, bbox_inches="tight")
geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:eac0fb0b3007267871b6305ac0003ceba19d4f28d85686cb9067ecf142787869
- size 584125