gf_v1: filled in isp single gene code
#237
by
davidjwen
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .gitattributes +2 -2
- .pre-commit-config.yaml +0 -26
- .readthedocs.yaml +0 -19
- MANIFEST.in +3 -4
- README.md +11 -38
- config.json +8 -9
- docs/Makefile +0 -20
- docs/make.bat +0 -35
- docs/requirements.txt +0 -3
- docs/source/_static/css/custom.css +0 -40
- docs/source/_static/gf_logo.png +0 -0
- docs/source/about.rst +0 -49
- docs/source/api.rst +0 -51
- docs/source/conf.py +0 -80
- docs/source/geneformer.classifier.rst +0 -10
- docs/source/geneformer.emb_extractor.rst +0 -26
- docs/source/geneformer.in_silico_perturber.rst +0 -8
- docs/source/geneformer.in_silico_perturber_stats.rst +0 -25
- docs/source/geneformer.mtl_classifier.rst +0 -11
- docs/source/geneformer.tokenizer.rst +0 -15
- docs/source/getstarted.rst +0 -36
- docs/source/index.rst +0 -16
- examples/cell_classification.ipynb +0 -0
- examples/extract_and_plot_cell_embeddings.ipynb +4 -8
- examples/gene_classification.ipynb +0 -0
- examples/hyperparam_optimiz_for_disease_classifier.py +226 -0
- examples/in_silico_perturbation.ipynb +17 -66
- examples/multitask_cell_classification.ipynb +0 -420
- examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +1 -3
- examples/tokenizing_scRNAseq_data.ipynb +8 -27
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
- fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +0 -24
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +0 -3
- {gf-12L-30M-i2048 → geneformer-12L-30M}/config.json +0 -0
- {gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin +0 -0
- {gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin +0 -0
- geneformer/__init__.py +11 -33
- geneformer/classifier.py +0 -1563
- geneformer/classifier_utils.py +0 -648
- geneformer/collator_for_classification.py +74 -139
- geneformer/emb_extractor.py +279 -649
- geneformer/ensembl_mapping_dict_gc95M.pkl +0 -3
- geneformer/evaluation_utils.py +0 -287
- geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl +0 -3
.gitattributes
CHANGED
@@ -14,11 +14,10 @@
|
|
14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
18 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
20 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
21 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
22 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
24 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
@@ -26,4 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
26 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
27 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
28 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
29 |
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
17 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
geneformer/gene_name_id_dict.pkl filter=lfs diff=lfs merge=lfs -text
|
29 |
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
.pre-commit-config.yaml
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
# See https://pre-commit.com for more information
|
2 |
-
# See https://pre-commit.com/hooks.html for more hooks
|
3 |
-
repos:
|
4 |
-
- repo: https://github.com/pre-commit/pre-commit-hooks
|
5 |
-
rev: v3.2.0
|
6 |
-
hooks:
|
7 |
-
- id: trailing-whitespace
|
8 |
-
- id: end-of-file-fixer
|
9 |
-
- id: check-yaml
|
10 |
-
- id: check-added-large-files
|
11 |
-
- id: check-merge-conflict
|
12 |
-
- id: mixed-line-ending
|
13 |
-
- id: check-docstring-first
|
14 |
-
- repo: https://github.com/pycqa/isort
|
15 |
-
rev: 5.12.0
|
16 |
-
hooks:
|
17 |
-
- id: isort
|
18 |
-
args: ["--profile", "black"]
|
19 |
-
- repo: https://github.com/astral-sh/ruff-pre-commit
|
20 |
-
# Ruff version.
|
21 |
-
rev: v0.1.4
|
22 |
-
hooks:
|
23 |
-
# Run the Ruff linter.
|
24 |
-
- id: ruff
|
25 |
-
# Run the Ruff formatter.
|
26 |
-
- id: ruff-format
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.readthedocs.yaml
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
# Read the Docs configuration file
|
2 |
-
|
3 |
-
# Required
|
4 |
-
version: 2
|
5 |
-
|
6 |
-
# Set the OS, Python version and other tools you might need
|
7 |
-
build:
|
8 |
-
os: ubuntu-22.04
|
9 |
-
tools:
|
10 |
-
python: "3.10"
|
11 |
-
|
12 |
-
# Build documentation in the "docs/" directory with Sphinx
|
13 |
-
sphinx:
|
14 |
-
configuration: docs/source/conf.py
|
15 |
-
|
16 |
-
# Python requirements required build your documentation
|
17 |
-
python:
|
18 |
-
install:
|
19 |
-
- requirements: docs/requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MANIFEST.in
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
include geneformer/
|
2 |
-
include geneformer/
|
3 |
-
include geneformer/
|
4 |
-
include geneformer/token_dictionary_gc95M.pkl
|
|
|
1 |
+
include geneformer/gene_median_dictionary.pkl
|
2 |
+
include geneformer/token_dictionary.pkl
|
3 |
+
include geneformer/gene_name_id_dict.pkl
|
|
README.md
CHANGED
@@ -1,43 +1,22 @@
|
|
1 |
---
|
2 |
datasets: ctheodoris/Genecorpus-30M
|
3 |
license: apache-2.0
|
4 |
-
tags:
|
5 |
-
- single-cell
|
6 |
-
- genomics
|
7 |
---
|
8 |
# Geneformer
|
9 |
-
Geneformer is a
|
10 |
|
11 |
-
|
12 |
-
- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
13 |
-
- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
|
14 |
|
15 |
# Model Description
|
16 |
-
Geneformer is a
|
17 |
|
18 |
-
|
19 |
|
20 |
-
|
21 |
|
22 |
-
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
The repository includes the following pretrained models:
|
27 |
-
|
28 |
-
L=layers\
|
29 |
-
M=millions of cells used for pretraining\
|
30 |
-
i=input size\
|
31 |
-
(pretraining date)
|
32 |
-
|
33 |
-
- GF-6L-30M-i2048 (June 2021)
|
34 |
-
- GF-12L-30M-i2048 (June 2021)
|
35 |
-
- GF-12L-95M-i4096 (April 2024)
|
36 |
-
- GF-20L-95M-i4096 (April 2024)
|
37 |
-
|
38 |
-
The current default model in the main directory of the repository is GF-12L-95M-i4096.
|
39 |
-
|
40 |
-
The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
|
41 |
|
42 |
# Application
|
43 |
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
@@ -45,7 +24,7 @@ The pretrained Geneformer model can be used directly for zero-shot learning, for
|
|
45 |
Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:
|
46 |
|
47 |
*Fine-tuning*:
|
48 |
-
- transcription factor dosage sensitivity
|
49 |
- chromatin dynamics (bivalently marked promoters)
|
50 |
- transcription factor regulatory range
|
51 |
- gene network centrality
|
@@ -67,11 +46,9 @@ Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) inc
|
|
67 |
- in silico perturbation to determine transcription factor cooperativity
|
68 |
|
69 |
# Installation
|
70 |
-
In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install
|
71 |
|
72 |
```bash
|
73 |
-
# Make sure you have git-lfs installed (https://git-lfs.com)
|
74 |
-
git lfs install
|
75 |
git clone https://huggingface.co/ctheodoris/Geneformer
|
76 |
cd Geneformer
|
77 |
pip install .
|
@@ -85,10 +62,6 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
|
|
85 |
- extracting and plotting cell embeddings
|
86 |
- in silico perturbation
|
87 |
|
88 |
-
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
89 |
-
|
90 |
-
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
91 |
|
92 |
-
|
93 |
-
- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
|
94 |
-
- H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
|
|
|
1 |
---
|
2 |
datasets: ctheodoris/Genecorpus-30M
|
3 |
license: apache-2.0
|
|
|
|
|
|
|
4 |
---
|
5 |
# Geneformer
|
6 |
+
Geneformer is a foundation transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
7 |
|
8 |
+
See [our manuscript](https://rdcu.be/ddrx0) for details.
|
|
|
|
|
9 |
|
10 |
# Model Description
|
11 |
+
Geneformer is a foundation transformer model pretrained on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
|
12 |
|
13 |
+
The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
14 |
|
15 |
+
We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
|
16 |
|
17 |
+
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents a pretrained deep learning model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
|
18 |
|
19 |
+
In [our manuscript](https://rdcu.be/ddrx0), we report results for the 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within this repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# Application
|
22 |
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
|
|
24 |
Example applications demonstrated in [our manuscript](https://rdcu.be/ddrx0) include:
|
25 |
|
26 |
*Fine-tuning*:
|
27 |
+
- transcription factor dosage sensitivity
|
28 |
- chromatin dynamics (bivalently marked promoters)
|
29 |
- transcription factor regulatory range
|
30 |
- gene network centrality
|
|
|
46 |
- in silico perturbation to determine transcription factor cooperativity
|
47 |
|
48 |
# Installation
|
49 |
+
In addition to the pretrained model, contained herein are functions for tokenizing and collating data specific to single cell transcriptomics, pretraining the model, fine-tuning the model, extracting and plotting cell embeddings, and performing in silico pertrubation with either the pretrained or fine-tuned models. To install:
|
50 |
|
51 |
```bash
|
|
|
|
|
52 |
git clone https://huggingface.co/ctheodoris/Geneformer
|
53 |
cd Geneformer
|
54 |
pip install .
|
|
|
62 |
- extracting and plotting cell embeddings
|
63 |
- in silico perturbation
|
64 |
|
65 |
+
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
|
|
|
|
66 |
|
67 |
+
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
|
|
|
config.json
CHANGED
@@ -3,22 +3,21 @@
|
|
3 |
"BertForMaskedLM"
|
4 |
],
|
5 |
"attention_probs_dropout_prob": 0.02,
|
6 |
-
"
|
7 |
"hidden_act": "relu",
|
8 |
"hidden_dropout_prob": 0.02,
|
9 |
-
"hidden_size":
|
10 |
"initializer_range": 0.02,
|
11 |
-
"intermediate_size":
|
12 |
"layer_norm_eps": 1e-12,
|
13 |
-
"max_position_embeddings":
|
14 |
"model_type": "bert",
|
15 |
-
"num_attention_heads":
|
16 |
-
"num_hidden_layers":
|
17 |
"pad_token_id": 0,
|
18 |
"position_embedding_type": "absolute",
|
19 |
-
"
|
20 |
-
"transformers_version": "4.37.1",
|
21 |
"type_vocab_size": 2,
|
22 |
"use_cache": true,
|
23 |
-
"vocab_size":
|
24 |
}
|
|
|
3 |
"BertForMaskedLM"
|
4 |
],
|
5 |
"attention_probs_dropout_prob": 0.02,
|
6 |
+
"gradient_checkpointing": false,
|
7 |
"hidden_act": "relu",
|
8 |
"hidden_dropout_prob": 0.02,
|
9 |
+
"hidden_size": 256,
|
10 |
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 512,
|
12 |
"layer_norm_eps": 1e-12,
|
13 |
+
"max_position_embeddings": 2048,
|
14 |
"model_type": "bert",
|
15 |
+
"num_attention_heads": 4,
|
16 |
+
"num_hidden_layers": 6,
|
17 |
"pad_token_id": 0,
|
18 |
"position_embedding_type": "absolute",
|
19 |
+
"transformers_version": "4.6.0",
|
|
|
20 |
"type_vocab_size": 2,
|
21 |
"use_cache": true,
|
22 |
+
"vocab_size": 25426
|
23 |
}
|
docs/Makefile
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
# Minimal makefile for Sphinx documentation
|
2 |
-
#
|
3 |
-
|
4 |
-
# You can set these variables from the command line, and also
|
5 |
-
# from the environment for the first two.
|
6 |
-
SPHINXOPTS ?=
|
7 |
-
SPHINXBUILD ?= sphinx-build
|
8 |
-
SOURCEDIR = source
|
9 |
-
BUILDDIR = build
|
10 |
-
|
11 |
-
# Put it first so that "make" without argument is like "make help".
|
12 |
-
help:
|
13 |
-
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
14 |
-
|
15 |
-
.PHONY: help Makefile
|
16 |
-
|
17 |
-
# Catch-all target: route all unknown targets to Sphinx using the new
|
18 |
-
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
19 |
-
%: Makefile
|
20 |
-
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/make.bat
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
@ECHO OFF
|
2 |
-
|
3 |
-
pushd %~dp0
|
4 |
-
|
5 |
-
REM Command file for Sphinx documentation
|
6 |
-
|
7 |
-
if "%SPHINXBUILD%" == "" (
|
8 |
-
set SPHINXBUILD=sphinx-build
|
9 |
-
)
|
10 |
-
set SOURCEDIR=source
|
11 |
-
set BUILDDIR=build
|
12 |
-
|
13 |
-
%SPHINXBUILD% >NUL 2>NUL
|
14 |
-
if errorlevel 9009 (
|
15 |
-
echo.
|
16 |
-
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
17 |
-
echo.installed, then set the SPHINXBUILD environment variable to point
|
18 |
-
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
19 |
-
echo.may add the Sphinx directory to PATH.
|
20 |
-
echo.
|
21 |
-
echo.If you don't have Sphinx installed, grab it from
|
22 |
-
echo.https://www.sphinx-doc.org/
|
23 |
-
exit /b 1
|
24 |
-
)
|
25 |
-
|
26 |
-
if "%1" == "" goto help
|
27 |
-
|
28 |
-
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
29 |
-
goto end
|
30 |
-
|
31 |
-
:help
|
32 |
-
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
33 |
-
|
34 |
-
:end
|
35 |
-
popd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/requirements.txt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
.
|
2 |
-
sphinx_rtd_theme==2.0.0
|
3 |
-
nbsphinx==0.9.3
|
|
|
|
|
|
|
|
docs/source/_static/css/custom.css
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
/* top left logo */
|
2 |
-
.wy-side-nav-search, .wy-nav-top {
|
3 |
-
background: linear-gradient(15deg, #13547a 0%, #80d0c7 100%);
|
4 |
-
}
|
5 |
-
|
6 |
-
|
7 |
-
/* unvisited link */
|
8 |
-
.wy-nav-content a:link {
|
9 |
-
color: #067abd;
|
10 |
-
}
|
11 |
-
|
12 |
-
/* visited link */
|
13 |
-
.wy-nav-content a:visited {
|
14 |
-
color: #4b827c;
|
15 |
-
}
|
16 |
-
|
17 |
-
/* mouse over link */
|
18 |
-
.wy-nav-content a:hover {
|
19 |
-
color: #80d0c7;
|
20 |
-
}
|
21 |
-
|
22 |
-
/* selected link */
|
23 |
-
.wy-nav-content a:active {
|
24 |
-
color: #4b827c;
|
25 |
-
}
|
26 |
-
|
27 |
-
/* class object */
|
28 |
-
.sig.sig-object {
|
29 |
-
padding: 5px 5px 5px 5px;
|
30 |
-
background-color: #ececec;
|
31 |
-
border-style: solid;
|
32 |
-
border-color: black;
|
33 |
-
border-width: 1px 0;
|
34 |
-
}
|
35 |
-
|
36 |
-
/* parameter object */
|
37 |
-
dt {
|
38 |
-
padding: 5px 5px 5px 5px;
|
39 |
-
background-color: #ececec;
|
40 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/_static/gf_logo.png
DELETED
Binary file (48.2 kB)
|
|
docs/source/about.rst
DELETED
@@ -1,49 +0,0 @@
|
|
1 |
-
About
|
2 |
-
=====
|
3 |
-
|
4 |
-
Model Description
|
5 |
-
-----------------
|
6 |
-
|
7 |
-
**Geneformer** is a context-aware, attention-based deep learning model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an iPSC model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
|
8 |
-
|
9 |
-
In `our manuscript <https://rdcu.be/ddrx0>`_, we report results for the original 6 layer Geneformer model pretrained on Genecorpus-30M. We additionally provide within the repository a 12 layer Geneformer model, scaled up with retained width:depth aspect ratio, also pretrained on Genecorpus-30M.
|
10 |
-
|
11 |
-
Both the `6 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors>`_ and `12 <https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-12L-30M-i2048/pytorch_model.bin>`_ layer Geneformer models were pretrained in June 2021.
|
12 |
-
|
13 |
-
Also see `our 2024 manuscript <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_, for details of the `expanded model <https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors>`_ trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
14 |
-
|
15 |
-
Application
|
16 |
-
-----------
|
17 |
-
|
18 |
-
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
19 |
-
|
20 |
-
Example applications demonstrated in `our manuscript <https://rdcu.be/ddrx0>`_ include:
|
21 |
-
|
22 |
-
| *Fine-tuning*:
|
23 |
-
| - transcription factor dosage sensitivity
|
24 |
-
| - chromatin dynamics (bivalently marked promoters)
|
25 |
-
| - transcription factor regulatory range
|
26 |
-
| - gene network centrality
|
27 |
-
| - transcription factor targets
|
28 |
-
| - cell type annotation
|
29 |
-
| - batch integration
|
30 |
-
| - cell state classification across differentiation
|
31 |
-
| - disease classification
|
32 |
-
| - in silico perturbation to determine disease-driving genes
|
33 |
-
| - in silico treatment to determine candidate therapeutic targets
|
34 |
-
|
35 |
-
| *Zero-shot learning*:
|
36 |
-
| - batch integration
|
37 |
-
| - gene context specificity
|
38 |
-
| - in silico reprogramming
|
39 |
-
| - in silico differentiation
|
40 |
-
| - in silico perturbation to determine impact on cell state
|
41 |
-
| - in silico perturbation to determine transcription factor targets
|
42 |
-
| - in silico perturbation to determine transcription factor cooperativity
|
43 |
-
|
44 |
-
Citations
|
45 |
-
---------
|
46 |
-
|
47 |
-
| C V Theodoris #, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor #. `Transfer learning enables predictions in network biology. <https://rdcu.be/ddrx0>`_ *Nature*, 31 May 2023. (# co-corresponding authors)
|
48 |
-
|
49 |
-
| H Chen \*, M S Venkatesh \*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka †, C V Theodoris † #. `Quantized multi-task learning for context-specific representations of gene network dynamics. <https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf>`_ *bioRxiv*, 19 Aug 2024. (\* co-first authors, † co-senior authors, # corresponding author)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/api.rst
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
API
|
2 |
-
===
|
3 |
-
|
4 |
-
Tokenizer
|
5 |
-
---------
|
6 |
-
|
7 |
-
.. toctree::
|
8 |
-
:maxdepth: 1
|
9 |
-
|
10 |
-
geneformer.tokenizer
|
11 |
-
|
12 |
-
Classifier
|
13 |
-
----------
|
14 |
-
|
15 |
-
.. toctree::
|
16 |
-
:maxdepth: 1
|
17 |
-
|
18 |
-
geneformer.classifier
|
19 |
-
|
20 |
-
Multitask Classifier
|
21 |
-
--------------------
|
22 |
-
|
23 |
-
.. toctree::
|
24 |
-
:maxdepth: 1
|
25 |
-
|
26 |
-
geneformer.mtl_classifier
|
27 |
-
|
28 |
-
Embedding Extractor
|
29 |
-
-------------------
|
30 |
-
|
31 |
-
.. toctree::
|
32 |
-
:maxdepth: 1
|
33 |
-
|
34 |
-
geneformer.emb_extractor
|
35 |
-
|
36 |
-
In Silico Perturber
|
37 |
-
-------------------
|
38 |
-
|
39 |
-
.. toctree::
|
40 |
-
:maxdepth: 1
|
41 |
-
|
42 |
-
geneformer.in_silico_perturber
|
43 |
-
|
44 |
-
|
45 |
-
In Silico Perturber Stats
|
46 |
-
-------------------------
|
47 |
-
|
48 |
-
.. toctree::
|
49 |
-
:maxdepth: 1
|
50 |
-
|
51 |
-
geneformer.in_silico_perturber_stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/conf.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
# Configuration file for the Sphinx documentation builder.
|
2 |
-
#
|
3 |
-
# For the full list of built-in configuration values, see the documentation:
|
4 |
-
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
5 |
-
|
6 |
-
import pathlib
|
7 |
-
import re
|
8 |
-
import sys
|
9 |
-
|
10 |
-
from sphinx.ext import autodoc
|
11 |
-
|
12 |
-
sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix())
|
13 |
-
|
14 |
-
|
15 |
-
# -- Project information -----------------------------------------------------
|
16 |
-
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
17 |
-
|
18 |
-
project = "geneformer"
|
19 |
-
copyright = "2024, Christina Theodoris"
|
20 |
-
author = "Christina Theodoris"
|
21 |
-
release = "0.1.0"
|
22 |
-
repository_url = "https://huggingface.co/ctheodoris/Geneformer"
|
23 |
-
|
24 |
-
# -- General configuration ---------------------------------------------------
|
25 |
-
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
26 |
-
|
27 |
-
extensions = [
|
28 |
-
"sphinx.ext.autodoc",
|
29 |
-
"sphinx.ext.autosummary",
|
30 |
-
"nbsphinx",
|
31 |
-
"sphinx.ext.viewcode",
|
32 |
-
"sphinx.ext.doctest",
|
33 |
-
]
|
34 |
-
|
35 |
-
templates_path = ["_templates"]
|
36 |
-
exclude_patterns = [
|
37 |
-
"**.ipynb_checkpoints",
|
38 |
-
]
|
39 |
-
autoclass_content = "both"
|
40 |
-
|
41 |
-
|
42 |
-
class MockedClassDocumenter(autodoc.ClassDocumenter):
|
43 |
-
def add_line(self, line: str, source: str, *lineno: int) -> None:
|
44 |
-
if line == " Bases: :py:class:`object`":
|
45 |
-
return
|
46 |
-
super().add_line(line, source, *lineno)
|
47 |
-
|
48 |
-
|
49 |
-
autodoc.ClassDocumenter = MockedClassDocumenter
|
50 |
-
add_module_names = False
|
51 |
-
|
52 |
-
|
53 |
-
def process_signature(app, what, name, obj, options, signature, return_annotation):
|
54 |
-
# loop through each line in the docstring and replace path with
|
55 |
-
# the generic path text
|
56 |
-
signature = re.sub(r"PosixPath\(.*?\)", "FILEPATH", signature)
|
57 |
-
return (signature, None)
|
58 |
-
|
59 |
-
|
60 |
-
def setup(app):
|
61 |
-
app.connect("autodoc-process-signature", process_signature)
|
62 |
-
|
63 |
-
|
64 |
-
# -- Options for HTML output -------------------------------------------------
|
65 |
-
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
66 |
-
|
67 |
-
html_theme = "sphinx_rtd_theme"
|
68 |
-
html_show_sphinx = False
|
69 |
-
html_static_path = ["_static"]
|
70 |
-
html_logo = "_static/gf_logo.png"
|
71 |
-
html_theme_options = {
|
72 |
-
"collapse_navigation": False,
|
73 |
-
"sticky_navigation": True,
|
74 |
-
"navigation_depth": 3,
|
75 |
-
"logo_only": True,
|
76 |
-
}
|
77 |
-
html_css_files = [
|
78 |
-
"css/custom.css",
|
79 |
-
]
|
80 |
-
html_show_sourcelink = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.classifier.rst
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
geneformer.classifier
|
2 |
-
=====================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.classifier
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members:
|
9 |
-
valid_option_dict,
|
10 |
-
validate_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.emb_extractor.rst
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
geneformer.emb\_extractor
|
2 |
-
=========================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.emb_extractor
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members:
|
9 |
-
accumulate_tdigests,
|
10 |
-
gen_heatmap_class_colors,
|
11 |
-
gen_heatmap_class_dict,
|
12 |
-
get_embs,
|
13 |
-
label_cell_embs,
|
14 |
-
label_gene_embs,
|
15 |
-
make_colorbar,
|
16 |
-
plot_heatmap,
|
17 |
-
plot_umap,
|
18 |
-
summarize_gene_embs,
|
19 |
-
tdigest_mean,
|
20 |
-
tdigest_median,
|
21 |
-
test_emb,
|
22 |
-
update_tdigest_dict,
|
23 |
-
update_tdigest_dict_mean,
|
24 |
-
update_tdigest_dict_median,
|
25 |
-
valid_option_dict,
|
26 |
-
validate_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.in_silico_perturber.rst
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
geneformer.in\_silico\_perturber
|
2 |
-
=======================================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.in_silico_perturber
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.in_silico_perturber_stats.rst
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
geneformer.in\_silico\_perturber\_stats
|
2 |
-
==============================================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.in_silico_perturber_stats
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members:
|
9 |
-
find,
|
10 |
-
get_fdr,
|
11 |
-
get_gene_list,
|
12 |
-
get_impact_component,
|
13 |
-
invert_dict,
|
14 |
-
isp_aggregate_gene_shifts,
|
15 |
-
isp_aggregate_grouped_perturb,
|
16 |
-
isp_stats_mixture_model,
|
17 |
-
isp_stats_to_goal_state,
|
18 |
-
isp_stats_vs_null,
|
19 |
-
n_detections,
|
20 |
-
read_dict,
|
21 |
-
read_dictionaries,
|
22 |
-
token_to_gene_name,
|
23 |
-
token_tuple_to_ensembl_ids,
|
24 |
-
valid_option_dict,
|
25 |
-
validate_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.mtl_classifier.rst
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
geneformer.mtl\_classifier
|
2 |
-
==========================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.mtl_classifier
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members:
|
9 |
-
valid_option_dict,
|
10 |
-
validate_options,
|
11 |
-
validate_additional_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/geneformer.tokenizer.rst
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
geneformer.tokenizer
|
2 |
-
====================
|
3 |
-
|
4 |
-
.. automodule:: geneformer.tokenizer
|
5 |
-
:members:
|
6 |
-
:undoc-members:
|
7 |
-
:show-inheritance:
|
8 |
-
:exclude-members:
|
9 |
-
create_dataset,
|
10 |
-
tokenize_anndata,
|
11 |
-
tokenize_files,
|
12 |
-
tokenize_loom,
|
13 |
-
rank_genes,
|
14 |
-
tokenize_cell,
|
15 |
-
sum_ensembl_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/getstarted.rst
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
Getting Started
|
2 |
-
===============
|
3 |
-
|
4 |
-
Installation
|
5 |
-
------------
|
6 |
-
|
7 |
-
Geneformer installation instructions.
|
8 |
-
|
9 |
-
Make sure you have git-lfs installed (https://git-lfs.com).
|
10 |
-
|
11 |
-
.. code-block:: bash
|
12 |
-
|
13 |
-
git lfs install
|
14 |
-
git clone https://huggingface.co/ctheodoris/Geneformer
|
15 |
-
cd Geneformer
|
16 |
-
pip install .
|
17 |
-
|
18 |
-
|
19 |
-
Tutorials
|
20 |
-
---------
|
21 |
-
|
22 |
-
| See `examples <https://huggingface.co/ctheodoris/Geneformer/tree/main/examples>`_ for:
|
23 |
-
| - tokenizing transcriptomes
|
24 |
-
| - pretraining
|
25 |
-
| - hyperparameter tuning
|
26 |
-
| - fine-tuning
|
27 |
-
| - extracting and plotting cell embeddings
|
28 |
-
| - in silico perturbation
|
29 |
-
|
30 |
-
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the `example_input_files directory <https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files>`_ in the dataset repository, but these only represent a few example fine-tuning applications.
|
31 |
-
|
32 |
-
|
33 |
-
Tips
|
34 |
-
----
|
35 |
-
|
36 |
-
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/source/index.rst
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
Geneformer
|
2 |
-
==========
|
3 |
-
|
4 |
-
Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in network biology.
|
5 |
-
|
6 |
-
See `our manuscript <https://rdcu.be/ddrx0>`_ for details.
|
7 |
-
|
8 |
-
Table of Contents
|
9 |
-
-----------------
|
10 |
-
|
11 |
-
.. toctree::
|
12 |
-
:maxdepth: 2
|
13 |
-
|
14 |
-
about
|
15 |
-
getstarted
|
16 |
-
api
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/cell_classification.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
examples/extract_and_plot_cell_embeddings.ipynb
CHANGED
@@ -18,8 +18,6 @@
|
|
18 |
"outputs": [],
|
19 |
"source": [
|
20 |
"# initiate EmbExtractor\n",
|
21 |
-
"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
|
22 |
-
"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
|
23 |
"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
|
24 |
" num_classes=3,\n",
|
25 |
" filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
|
@@ -28,13 +26,11 @@
|
|
28 |
" emb_label=[\"disease\",\"cell_type\"],\n",
|
29 |
" labels_to_plot=[\"disease\"],\n",
|
30 |
" forward_batch_size=200,\n",
|
31 |
-
" nproc=16
|
32 |
-
" token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n",
|
33 |
"\n",
|
34 |
"# extracts embedding from input data\n",
|
35 |
-
"#
|
36 |
-
"
|
37 |
-
"embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
|
38 |
" \"path/to/input_data/\",\n",
|
39 |
" \"path/to/output_directory/\",\n",
|
40 |
" \"output_prefix\")\n"
|
@@ -132,7 +128,7 @@
|
|
132 |
"name": "python",
|
133 |
"nbconvert_exporter": "python",
|
134 |
"pygments_lexer": "ipython3",
|
135 |
-
"version": "3.10.
|
136 |
}
|
137 |
},
|
138 |
"nbformat": 4,
|
|
|
18 |
"outputs": [],
|
19 |
"source": [
|
20 |
"# initiate EmbExtractor\n",
|
|
|
|
|
21 |
"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
|
22 |
" num_classes=3,\n",
|
23 |
" filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
|
|
|
26 |
" emb_label=[\"disease\",\"cell_type\"],\n",
|
27 |
" labels_to_plot=[\"disease\"],\n",
|
28 |
" forward_batch_size=200,\n",
|
29 |
+
" nproc=16)\n",
|
|
|
30 |
"\n",
|
31 |
"# extracts embedding from input data\n",
|
32 |
+
"# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
|
33 |
+
"embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
|
|
|
34 |
" \"path/to/input_data/\",\n",
|
35 |
" \"path/to/output_directory/\",\n",
|
36 |
" \"output_prefix\")\n"
|
|
|
128 |
"name": "python",
|
129 |
"nbconvert_exporter": "python",
|
130 |
"pygments_lexer": "ipython3",
|
131 |
+
"version": "3.10.11"
|
132 |
}
|
133 |
},
|
134 |
"nbformat": 4,
|
examples/gene_classification.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
examples/hyperparam_optimiz_for_disease_classifier.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# hyperparameter optimization with raytune for disease classification
|
5 |
+
|
6 |
+
# imports
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
GPU_NUMBER = [0,1,2,3]
|
10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
11 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
12 |
+
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
|
13 |
+
os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
|
14 |
+
|
15 |
+
# initiate runtime environment for raytune
|
16 |
+
import pyarrow # must occur prior to ray import
|
17 |
+
import ray
|
18 |
+
from ray import tune
|
19 |
+
from ray.tune import ExperimentAnalysis
|
20 |
+
from ray.tune.suggest.hyperopt import HyperOptSearch
|
21 |
+
ray.shutdown() #engage new ray session
|
22 |
+
runtime_env = {"conda": "base",
|
23 |
+
"env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
|
24 |
+
ray.init(runtime_env=runtime_env)
|
25 |
+
|
26 |
+
def initialize_ray_with_check(ip_address):
|
27 |
+
"""
|
28 |
+
Initialize Ray with a specified IP address and check its status and accessibility.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
- ip_address (str): The IP address (with port) to initialize Ray.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
- bool: True if initialization was successful and dashboard is accessible, False otherwise.
|
35 |
+
"""
|
36 |
+
try:
|
37 |
+
ray.init(address=ip_address)
|
38 |
+
print(ray.nodes())
|
39 |
+
|
40 |
+
services = ray.get_webui_url()
|
41 |
+
if not services:
|
42 |
+
raise RuntimeError("Ray dashboard is not accessible.")
|
43 |
+
else:
|
44 |
+
print(f"Ray dashboard is accessible at: {services}")
|
45 |
+
return True
|
46 |
+
except Exception as e:
|
47 |
+
print(f"Error initializing Ray: {e}")
|
48 |
+
return False
|
49 |
+
|
50 |
+
# Usage:
|
51 |
+
ip = 'your_ip:xxxx' # Replace with your actual IP address and port
|
52 |
+
if initialize_ray_with_check(ip):
|
53 |
+
print("Ray initialized successfully.")
|
54 |
+
else:
|
55 |
+
print("Error during Ray initialization.")
|
56 |
+
|
57 |
+
import datetime
|
58 |
+
import numpy as np
|
59 |
+
import pandas as pd
|
60 |
+
import random
|
61 |
+
import seaborn as sns; sns.set()
|
62 |
+
from collections import Counter
|
63 |
+
from datasets import load_from_disk
|
64 |
+
from scipy.stats import ranksums
|
65 |
+
from sklearn.metrics import accuracy_score
|
66 |
+
from transformers import BertForSequenceClassification
|
67 |
+
from transformers import Trainer
|
68 |
+
from transformers.training_args import TrainingArguments
|
69 |
+
|
70 |
+
from geneformer import DataCollatorForCellClassification
|
71 |
+
|
72 |
+
# number of CPU cores
|
73 |
+
num_proc=30
|
74 |
+
|
75 |
+
# load train dataset with columns:
|
76 |
+
# cell_type (annotation of each cell's type)
|
77 |
+
# disease (healthy or disease state)
|
78 |
+
# individual (unique ID for each patient)
|
79 |
+
# length (length of that cell's rank value encoding)
|
80 |
+
train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
|
81 |
+
|
82 |
+
# filter dataset for given cell_type
|
83 |
+
def if_cell_type(example):
|
84 |
+
return example["cell_type"].startswith("Cardiomyocyte")
|
85 |
+
|
86 |
+
trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
|
87 |
+
|
88 |
+
# create dictionary of disease states : label ids
|
89 |
+
target_names = ["healthy", "disease1", "disease2"]
|
90 |
+
target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
|
91 |
+
|
92 |
+
trainset_v3 = trainset_v2.rename_column("disease","label")
|
93 |
+
|
94 |
+
# change labels to numerical ids
|
95 |
+
def classes_to_ids(example):
|
96 |
+
example["label"] = target_name_id_dict[example["label"]]
|
97 |
+
return example
|
98 |
+
|
99 |
+
trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
|
100 |
+
|
101 |
+
# separate into train, validation, test sets
|
102 |
+
indiv_set = set(trainset_v4["individual"])
|
103 |
+
random.seed(42)
|
104 |
+
train_indiv = random.sample(indiv_set,round(0.7*len(indiv_set)))
|
105 |
+
eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
|
106 |
+
valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
|
107 |
+
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
|
108 |
+
|
109 |
+
def if_train(example):
|
110 |
+
return example["individual"] in train_indiv
|
111 |
+
|
112 |
+
classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
|
113 |
+
|
114 |
+
def if_valid(example):
|
115 |
+
return example["individual"] in valid_indiv
|
116 |
+
|
117 |
+
classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
|
118 |
+
|
119 |
+
# define output directory path
|
120 |
+
current_date = datetime.datetime.now()
|
121 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
122 |
+
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
|
123 |
+
|
124 |
+
# ensure not overwriting previously saved model
|
125 |
+
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
|
126 |
+
if os.path.isfile(saved_model_test) == True:
|
127 |
+
raise Exception("Model already saved to this directory.")
|
128 |
+
|
129 |
+
# make output directory
|
130 |
+
subprocess.call(f'mkdir {output_dir}', shell=True)
|
131 |
+
|
132 |
+
# set training parameters
|
133 |
+
# how many pretrained layers to freeze
|
134 |
+
freeze_layers = 2
|
135 |
+
# batch size for training and eval
|
136 |
+
geneformer_batch_size = 12
|
137 |
+
# number of epochs
|
138 |
+
epochs = 1
|
139 |
+
# logging steps
|
140 |
+
logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
|
141 |
+
|
142 |
+
# define function to initiate model
|
143 |
+
def model_init():
|
144 |
+
model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
|
145 |
+
num_labels=len(target_names),
|
146 |
+
output_attentions = False,
|
147 |
+
output_hidden_states = False)
|
148 |
+
if freeze_layers is not None:
|
149 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
150 |
+
for module in modules_to_freeze:
|
151 |
+
for param in module.parameters():
|
152 |
+
param.requires_grad = False
|
153 |
+
|
154 |
+
model = model.to("cuda:0")
|
155 |
+
return model
|
156 |
+
|
157 |
+
# define metrics
|
158 |
+
# note: macro f1 score recommended for imbalanced multiclass classifiers
|
159 |
+
def compute_metrics(pred):
|
160 |
+
labels = pred.label_ids
|
161 |
+
preds = pred.predictions.argmax(-1)
|
162 |
+
# calculate accuracy using sklearn's function
|
163 |
+
acc = accuracy_score(labels, preds)
|
164 |
+
return {
|
165 |
+
'accuracy': acc,
|
166 |
+
}
|
167 |
+
|
168 |
+
# set training arguments
|
169 |
+
training_args = {
|
170 |
+
"do_train": True,
|
171 |
+
"do_eval": True,
|
172 |
+
"evaluation_strategy": "steps",
|
173 |
+
"eval_steps": logging_steps,
|
174 |
+
"logging_steps": logging_steps,
|
175 |
+
"group_by_length": True,
|
176 |
+
"length_column_name": "length",
|
177 |
+
"disable_tqdm": True,
|
178 |
+
"skip_memory_metrics": True, # memory tracker causes errors in raytune
|
179 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
180 |
+
"per_device_eval_batch_size": geneformer_batch_size,
|
181 |
+
"num_train_epochs": epochs,
|
182 |
+
"load_best_model_at_end": True,
|
183 |
+
"output_dir": output_dir,
|
184 |
+
}
|
185 |
+
|
186 |
+
training_args_init = TrainingArguments(**training_args)
|
187 |
+
|
188 |
+
# create the trainer
|
189 |
+
trainer = Trainer(
|
190 |
+
model_init=model_init,
|
191 |
+
args=training_args_init,
|
192 |
+
data_collator=DataCollatorForCellClassification(),
|
193 |
+
train_dataset=classifier_trainset,
|
194 |
+
eval_dataset=classifier_validset,
|
195 |
+
compute_metrics=compute_metrics,
|
196 |
+
)
|
197 |
+
|
198 |
+
# specify raytune hyperparameter search space
|
199 |
+
ray_config = {
|
200 |
+
"num_train_epochs": tune.choice([epochs]),
|
201 |
+
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
202 |
+
"weight_decay": tune.uniform(0.0, 0.3),
|
203 |
+
"lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
|
204 |
+
"warmup_steps": tune.uniform(100, 2000),
|
205 |
+
"seed": tune.uniform(0,100),
|
206 |
+
"per_device_train_batch_size": tune.choice([geneformer_batch_size])
|
207 |
+
}
|
208 |
+
|
209 |
+
hyperopt_search = HyperOptSearch(
|
210 |
+
metric="eval_accuracy", mode="max")
|
211 |
+
|
212 |
+
# optimize hyperparameters
|
213 |
+
trainer.hyperparameter_search(
|
214 |
+
direction="maximize",
|
215 |
+
backend="ray",
|
216 |
+
resources_per_trial={"cpu":8,"gpu":1},
|
217 |
+
hp_space=lambda _: ray_config,
|
218 |
+
search_alg=hyperopt_search,
|
219 |
+
n_trials=100, # number of trials
|
220 |
+
progress_reporter=tune.CLIReporter(max_report_frequency=600,
|
221 |
+
sort_by_metric=True,
|
222 |
+
max_progress_rows=100,
|
223 |
+
mode="max",
|
224 |
+
metric="eval_accuracy",
|
225 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy"])
|
226 |
+
)
|
examples/in_silico_perturbation.ipynb
CHANGED
@@ -8,80 +8,35 @@
|
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
"from geneformer import InSilicoPerturber\n",
|
11 |
-
"from geneformer import InSilicoPerturberStats
|
12 |
-
"from geneformer import EmbExtractor"
|
13 |
-
]
|
14 |
-
},
|
15 |
-
{
|
16 |
-
"cell_type": "markdown",
|
17 |
-
"id": "cbd6851c-060e-4967-b816-e605ffe58b23",
|
18 |
-
"metadata": {
|
19 |
-
"tags": []
|
20 |
-
},
|
21 |
-
"source": [
|
22 |
-
"### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
|
23 |
-
]
|
24 |
-
},
|
25 |
-
{
|
26 |
-
"cell_type": "code",
|
27 |
-
"execution_count": null,
|
28 |
-
"id": "c53e98cd-c603-4878-82ba-db471181bb55",
|
29 |
-
"metadata": {},
|
30 |
-
"outputs": [],
|
31 |
-
"source": [
|
32 |
-
"# first obtain start, goal, and alt embedding positions\n",
|
33 |
-
"# this function was changed to be separate from perturb_data\n",
|
34 |
-
"# to avoid repeating calcuations when parallelizing perturb_data\n",
|
35 |
-
"cell_states_to_model={\"state_key\": \"disease\", \n",
|
36 |
-
" \"start_state\": \"dcm\", \n",
|
37 |
-
" \"goal_state\": \"nf\", \n",
|
38 |
-
" \"alt_states\": [\"hcm\"]}\n",
|
39 |
-
"\n",
|
40 |
-
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
|
41 |
-
"\n",
|
42 |
-
"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
|
43 |
-
"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
|
44 |
-
"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
|
45 |
-
"embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
|
46 |
-
" num_classes=3,\n",
|
47 |
-
" filter_data=filter_data_dict,\n",
|
48 |
-
" max_ncells=1000,\n",
|
49 |
-
" emb_layer=0,\n",
|
50 |
-
" summary_stat=\"exact_mean\",\n",
|
51 |
-
" forward_batch_size=256,\n",
|
52 |
-
" nproc=16)\n",
|
53 |
-
"\n",
|
54 |
-
"state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
|
55 |
-
" \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
|
56 |
-
" \"path/to/input_data\",\n",
|
57 |
-
" \"path/to/output_directory\",\n",
|
58 |
-
" \"output_prefix\")"
|
59 |
]
|
60 |
},
|
61 |
{
|
62 |
"cell_type": "code",
|
63 |
"execution_count": null,
|
64 |
-
"id": "
|
65 |
"metadata": {
|
66 |
"tags": []
|
67 |
},
|
68 |
"outputs": [],
|
69 |
"source": [
|
70 |
-
"#
|
71 |
-
"#
|
72 |
-
"#
|
73 |
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
74 |
" perturb_rank_shift=None,\n",
|
75 |
" genes_to_perturb=\"all\",\n",
|
76 |
" combos=0,\n",
|
77 |
" anchor_gene=None,\n",
|
78 |
-
" model_type=\"CellClassifier\"
|
79 |
" num_classes=3,\n",
|
80 |
" emb_mode=\"cell\",\n",
|
81 |
" cell_emb_style=\"mean_pool\",\n",
|
82 |
-
" filter_data=
|
83 |
-
" cell_states_to_model=
|
84 |
-
"
|
|
|
|
|
85 |
" max_ncells=2000,\n",
|
86 |
" emb_layer=0,\n",
|
87 |
" forward_batch_size=400,\n",
|
@@ -96,10 +51,9 @@
|
|
96 |
"outputs": [],
|
97 |
"source": [
|
98 |
"# outputs intermediate files from in silico perturbation\n",
|
99 |
-
"\n",
|
100 |
-
"isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
|
101 |
" \"path/to/input_data\",\n",
|
102 |
-
" \"path/to/
|
103 |
" \"output_prefix\")"
|
104 |
]
|
105 |
},
|
@@ -110,14 +64,11 @@
|
|
110 |
"metadata": {},
|
111 |
"outputs": [],
|
112 |
"source": [
|
113 |
-
"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
|
114 |
-
"# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
|
115 |
-
"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
|
116 |
"ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
|
117 |
" genes_perturbed=\"all\",\n",
|
118 |
" combos=0,\n",
|
119 |
" anchor_gene=None,\n",
|
120 |
-
" cell_states_to_model=
|
121 |
]
|
122 |
},
|
123 |
{
|
@@ -128,9 +79,9 @@
|
|
128 |
"outputs": [],
|
129 |
"source": [
|
130 |
"# extracts data from intermediate files and processes stats to output in final .csv\n",
|
131 |
-
"ispstats.get_stats(\"path/to/
|
132 |
" None,\n",
|
133 |
-
" \"path/to/
|
134 |
" \"output_prefix\")"
|
135 |
]
|
136 |
}
|
@@ -151,7 +102,7 @@
|
|
151 |
"name": "python",
|
152 |
"nbconvert_exporter": "python",
|
153 |
"pygments_lexer": "ipython3",
|
154 |
-
"version": "3.10.
|
155 |
}
|
156 |
},
|
157 |
"nbformat": 4,
|
|
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
"from geneformer import InSilicoPerturber\n",
|
11 |
+
"from geneformer import InSilicoPerturberStats"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
]
|
13 |
},
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": null,
|
17 |
+
"id": "67b44366-f255-4415-a865-6a27a8ffcce7",
|
18 |
"metadata": {
|
19 |
"tags": []
|
20 |
},
|
21 |
"outputs": [],
|
22 |
"source": [
|
23 |
+
"# in silico perturbation in deletion mode to determine genes whose \n",
|
24 |
+
"# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
|
25 |
+
"# the embedding towards non-failing (nf) state\n",
|
26 |
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
27 |
" perturb_rank_shift=None,\n",
|
28 |
" genes_to_perturb=\"all\",\n",
|
29 |
" combos=0,\n",
|
30 |
" anchor_gene=None,\n",
|
31 |
+
" model_type=\"CellClassifier\",\n",
|
32 |
" num_classes=3,\n",
|
33 |
" emb_mode=\"cell\",\n",
|
34 |
" cell_emb_style=\"mean_pool\",\n",
|
35 |
+
" filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
|
36 |
+
" cell_states_to_model={'state_key': 'disease', \n",
|
37 |
+
" 'start_state': 'dcm', \n",
|
38 |
+
" 'goal_state': 'nf', \n",
|
39 |
+
" 'alt_states': ['hcm']},\n",
|
40 |
" max_ncells=2000,\n",
|
41 |
" emb_layer=0,\n",
|
42 |
" forward_batch_size=400,\n",
|
|
|
51 |
"outputs": [],
|
52 |
"source": [
|
53 |
"# outputs intermediate files from in silico perturbation\n",
|
54 |
+
"isp.perturb_data(\"path/to/model\",\n",
|
|
|
55 |
" \"path/to/input_data\",\n",
|
56 |
+
" \"path/to/output_directory\",\n",
|
57 |
" \"output_prefix\")"
|
58 |
]
|
59 |
},
|
|
|
64 |
"metadata": {},
|
65 |
"outputs": [],
|
66 |
"source": [
|
|
|
|
|
|
|
67 |
"ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
|
68 |
" genes_perturbed=\"all\",\n",
|
69 |
" combos=0,\n",
|
70 |
" anchor_gene=None,\n",
|
71 |
+
" cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
|
72 |
]
|
73 |
},
|
74 |
{
|
|
|
79 |
"outputs": [],
|
80 |
"source": [
|
81 |
"# extracts data from intermediate files and processes stats to output in final .csv\n",
|
82 |
+
"ispstats.get_stats(\"path/to/input_data\",\n",
|
83 |
" None,\n",
|
84 |
+
" \"path/to/output_directory\",\n",
|
85 |
" \"output_prefix\")"
|
86 |
]
|
87 |
}
|
|
|
102 |
"name": "python",
|
103 |
"nbconvert_exporter": "python",
|
104 |
"pygments_lexer": "ipython3",
|
105 |
+
"version": "3.10.11"
|
106 |
}
|
107 |
},
|
108 |
"nbformat": 4,
|
examples/multitask_cell_classification.ipynb
DELETED
@@ -1,420 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "866f100c-e11a-4e7b-a37c-831775d845a7",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# Geneformer Multi-Task Cell Classifier Tutorial\n",
|
9 |
-
"\n",
|
10 |
-
"This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
|
11 |
-
]
|
12 |
-
},
|
13 |
-
{
|
14 |
-
"cell_type": "markdown",
|
15 |
-
"id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
|
16 |
-
"metadata": {},
|
17 |
-
"source": [
|
18 |
-
"## 1. Installation and Imports\n",
|
19 |
-
"\n",
|
20 |
-
"First import the necessary modules."
|
21 |
-
]
|
22 |
-
},
|
23 |
-
{
|
24 |
-
"cell_type": "code",
|
25 |
-
"execution_count": 3,
|
26 |
-
"id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
|
27 |
-
"metadata": {},
|
28 |
-
"outputs": [],
|
29 |
-
"source": [
|
30 |
-
"from geneformer import MTLClassifier"
|
31 |
-
]
|
32 |
-
},
|
33 |
-
{
|
34 |
-
"cell_type": "markdown",
|
35 |
-
"id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
|
36 |
-
"metadata": {},
|
37 |
-
"source": [
|
38 |
-
"## 2. Set up Paths and Parameters\n",
|
39 |
-
"\n",
|
40 |
-
"Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
|
41 |
-
]
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"cell_type": "code",
|
45 |
-
"execution_count": null,
|
46 |
-
"id": "04a04197-8e45-47f8-a86f-202209ea10ae",
|
47 |
-
"metadata": {},
|
48 |
-
"outputs": [],
|
49 |
-
"source": [
|
50 |
-
"# Define paths\n",
|
51 |
-
"pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
|
52 |
-
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
53 |
-
"train_path = \"/path/to/train/data.dataset\"\n",
|
54 |
-
"val_path = \"/path/to/val/data.dataset\"\n",
|
55 |
-
"test_path = \"/path/to/test/data.dataset\"\n",
|
56 |
-
"results_dir = \"/path/to/results/directory\"\n",
|
57 |
-
"model_save_path = \"/path/to/model/save/path\"\n",
|
58 |
-
"tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
|
59 |
-
"\n",
|
60 |
-
"# Define tasks and hyperparameters\n",
|
61 |
-
"# task_columns should be a list of column names from your dataset\n",
|
62 |
-
"# Each column represents a specific classification task (e.g. cell type, disease state)\n",
|
63 |
-
"task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
|
64 |
-
"\n",
|
65 |
-
"hyperparameters = {\n",
|
66 |
-
" \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
|
67 |
-
" \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
|
68 |
-
" \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
|
69 |
-
" \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
|
70 |
-
" \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
|
71 |
-
" \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
|
72 |
-
"}"
|
73 |
-
]
|
74 |
-
},
|
75 |
-
{
|
76 |
-
"cell_type": "markdown",
|
77 |
-
"id": "31857690-a739-435a-aefd-f171fafc1b78",
|
78 |
-
"metadata": {},
|
79 |
-
"source": [
|
80 |
-
"In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
|
81 |
-
"1. Identifying the cell type\n",
|
82 |
-
"2. Determining the disease state\n",
|
83 |
-
"3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
|
84 |
-
"\n",
|
85 |
-
"These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
|
86 |
-
"\n",
|
87 |
-
"For example, your dataset might look something like this:\n",
|
88 |
-
"\n",
|
89 |
-
" | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
|
90 |
-
" |----------------|-----------|-----|-----------|---------------|\n",
|
91 |
-
" | cell1 | ... | ... | neuron | healthy |\n",
|
92 |
-
" | cell2 | ... | ... | astrocyte | diseased |\n",
|
93 |
-
" | ... | ... | ... | ... | ... |\n",
|
94 |
-
"The model will learn to predict classes within 'cell_type' and 'disease_state' "
|
95 |
-
]
|
96 |
-
},
|
97 |
-
{
|
98 |
-
"cell_type": "markdown",
|
99 |
-
"id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
|
100 |
-
"metadata": {},
|
101 |
-
"source": [
|
102 |
-
"## 3. Initialize the MTLClassifier\n",
|
103 |
-
"\n",
|
104 |
-
"Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
|
105 |
-
]
|
106 |
-
},
|
107 |
-
{
|
108 |
-
"cell_type": "code",
|
109 |
-
"execution_count": null,
|
110 |
-
"id": "e27caac9-670c-409d-9313-50201c665cb9",
|
111 |
-
"metadata": {},
|
112 |
-
"outputs": [],
|
113 |
-
"source": [
|
114 |
-
"mc = MTLClassifier(\n",
|
115 |
-
" task_columns=task_columns, # Our defined classification tasks\n",
|
116 |
-
" study_name=\"MTLClassifier_example\",\n",
|
117 |
-
" pretrained_path=pretrained_path,\n",
|
118 |
-
" train_path=train_path,\n",
|
119 |
-
" val_path=val_path,\n",
|
120 |
-
" test_path=test_path,\n",
|
121 |
-
" model_save_path=model_save_path,\n",
|
122 |
-
" results_dir=results_dir,\n",
|
123 |
-
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
124 |
-
" hyperparameters=hyperparameters,\n",
|
125 |
-
" n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
|
126 |
-
" epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
|
127 |
-
" batch_size=8, # Adjust based on available GPU memory\n",
|
128 |
-
" seed=42\n",
|
129 |
-
")"
|
130 |
-
]
|
131 |
-
},
|
132 |
-
{
|
133 |
-
"cell_type": "markdown",
|
134 |
-
"id": "0d729444-e3ad-4584-9659-0c464ac97462",
|
135 |
-
"metadata": {},
|
136 |
-
"source": [
|
137 |
-
"## 4. Run Hyperparameter Optimization\n",
|
138 |
-
"\n",
|
139 |
-
"Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
|
140 |
-
]
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "code",
|
144 |
-
"execution_count": null,
|
145 |
-
"id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
|
146 |
-
"metadata": {},
|
147 |
-
"outputs": [],
|
148 |
-
"source": [
|
149 |
-
"mc.run_optuna_study()"
|
150 |
-
]
|
151 |
-
},
|
152 |
-
{
|
153 |
-
"cell_type": "markdown",
|
154 |
-
"id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
|
155 |
-
"metadata": {},
|
156 |
-
"source": [
|
157 |
-
"## 5. Evaluate the Model on Test Data\n",
|
158 |
-
"\n",
|
159 |
-
"After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
|
160 |
-
]
|
161 |
-
},
|
162 |
-
{
|
163 |
-
"cell_type": "code",
|
164 |
-
"execution_count": null,
|
165 |
-
"id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
|
166 |
-
"metadata": {},
|
167 |
-
"outputs": [],
|
168 |
-
"source": [
|
169 |
-
"mc.load_and_evaluate_test_model()"
|
170 |
-
]
|
171 |
-
},
|
172 |
-
{
|
173 |
-
"cell_type": "markdown",
|
174 |
-
"id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
|
175 |
-
"metadata": {},
|
176 |
-
"source": [
|
177 |
-
"## 6. (Optional) Manual Hyperparameter Tuning\n",
|
178 |
-
"\n",
|
179 |
-
"If you prefer to set hyperparameters manually, you can use the following approach:"
|
180 |
-
]
|
181 |
-
},
|
182 |
-
{
|
183 |
-
"cell_type": "code",
|
184 |
-
"execution_count": null,
|
185 |
-
"id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
|
186 |
-
"metadata": {},
|
187 |
-
"outputs": [],
|
188 |
-
"source": [
|
189 |
-
"manual_hyperparameters = {\n",
|
190 |
-
" \"learning_rate\": 0.001,\n",
|
191 |
-
" \"warmup_ratio\": 0.01,\n",
|
192 |
-
" \"weight_decay\": 0.1,\n",
|
193 |
-
" \"dropout_rate\": 0.1,\n",
|
194 |
-
" \"lr_scheduler_type\": \"cosine\",\n",
|
195 |
-
" \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
|
196 |
-
" \"max_layers_to_freeze\": 2\n",
|
197 |
-
"}\n",
|
198 |
-
"\n",
|
199 |
-
"mc_manual = MTLClassifier(\n",
|
200 |
-
" task_columns=task_columns,\n",
|
201 |
-
" study_name=\"mtl_manual\",\n",
|
202 |
-
" pretrained_path=pretrained_path,\n",
|
203 |
-
" train_path=train_path,\n",
|
204 |
-
" val_path=val_path,\n",
|
205 |
-
" test_path=test_path,\n",
|
206 |
-
" model_save_path=model_save_path,\n",
|
207 |
-
" results_dir=results_dir,\n",
|
208 |
-
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
209 |
-
" manual_hyperparameters=manual_hyperparameters,\n",
|
210 |
-
" use_manual_hyperparameters=True,\n",
|
211 |
-
" epochs=10,\n",
|
212 |
-
" batch_size=32,\n",
|
213 |
-
" seed=42\n",
|
214 |
-
")\n",
|
215 |
-
"\n",
|
216 |
-
"mc_manual.run_manual_tuning()"
|
217 |
-
]
|
218 |
-
},
|
219 |
-
{
|
220 |
-
"cell_type": "markdown",
|
221 |
-
"id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
|
222 |
-
"metadata": {},
|
223 |
-
"source": [
|
224 |
-
"# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
|
225 |
-
"This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
|
226 |
-
]
|
227 |
-
},
|
228 |
-
{
|
229 |
-
"cell_type": "code",
|
230 |
-
"execution_count": null,
|
231 |
-
"id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
|
232 |
-
"metadata": {},
|
233 |
-
"outputs": [],
|
234 |
-
"source": [
|
235 |
-
"from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
|
236 |
-
]
|
237 |
-
},
|
238 |
-
{
|
239 |
-
"cell_type": "code",
|
240 |
-
"execution_count": null,
|
241 |
-
"id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
|
242 |
-
"metadata": {},
|
243 |
-
"outputs": [],
|
244 |
-
"source": [
|
245 |
-
"# Define paths\n",
|
246 |
-
"model_directory = \"/path/to/model/save/path\"\n",
|
247 |
-
"input_data_file = \"/path/to/input/data.dataset\"\n",
|
248 |
-
"output_directory = \"/path/to/output/directory\"\n",
|
249 |
-
"output_prefix = \"mtl_quantized_perturbation\"\n",
|
250 |
-
"\n",
|
251 |
-
"# Define parameters\n",
|
252 |
-
"perturb_type = \"delete\" # or \"overexpress\"\n",
|
253 |
-
"\n",
|
254 |
-
"# Define cell states to model\n",
|
255 |
-
"cell_states_to_model = {\n",
|
256 |
-
" \"state_key\": \"disease_state\", \n",
|
257 |
-
" \"start_state\": \"disease\", \n",
|
258 |
-
" \"goal_state\": \"control\"\n",
|
259 |
-
"}\n",
|
260 |
-
"\n",
|
261 |
-
"# Define filter data\n",
|
262 |
-
"filter_data_dict = {\n",
|
263 |
-
" \"cell_type\": [\"Fibroblast\"]\n",
|
264 |
-
"}"
|
265 |
-
]
|
266 |
-
},
|
267 |
-
{
|
268 |
-
"cell_type": "markdown",
|
269 |
-
"id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
|
270 |
-
"metadata": {},
|
271 |
-
"source": [
|
272 |
-
"## 3. Extract State Embeddings\n",
|
273 |
-
"\n",
|
274 |
-
"Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
|
275 |
-
]
|
276 |
-
},
|
277 |
-
{
|
278 |
-
"cell_type": "code",
|
279 |
-
"execution_count": null,
|
280 |
-
"id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
|
281 |
-
"metadata": {},
|
282 |
-
"outputs": [],
|
283 |
-
"source": [
|
284 |
-
"# Initialize EmbExtractor\n",
|
285 |
-
"embex = EmbExtractor(\n",
|
286 |
-
" filter_data_dict=filter_data_dict,\n",
|
287 |
-
" max_ncells=1000, # Number of cells to extract embeddings for\n",
|
288 |
-
" emb_layer=0, # Use the second to last layer\n",
|
289 |
-
" emb_mode = \"cls\",\n",
|
290 |
-
" summary_stat=\"exact_mean\",\n",
|
291 |
-
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
292 |
-
" nproc=4\n",
|
293 |
-
")\n",
|
294 |
-
"\n",
|
295 |
-
"# Extract state embeddings\n",
|
296 |
-
"state_embs_dict = embex.get_state_embs(\n",
|
297 |
-
" cell_states_to_model,\n",
|
298 |
-
" model_directory=model_directory,\n",
|
299 |
-
" input_data_file=input_data_file,\n",
|
300 |
-
" output_directory=output_directory,\n",
|
301 |
-
" output_prefix=output_prefix\n",
|
302 |
-
")"
|
303 |
-
]
|
304 |
-
},
|
305 |
-
{
|
306 |
-
"cell_type": "markdown",
|
307 |
-
"id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
|
308 |
-
"metadata": {},
|
309 |
-
"source": [
|
310 |
-
"## 4. Initialize the InSilicoPerturber\n",
|
311 |
-
"\n",
|
312 |
-
"Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
|
313 |
-
]
|
314 |
-
},
|
315 |
-
{
|
316 |
-
"cell_type": "code",
|
317 |
-
"execution_count": null,
|
318 |
-
"id": "09f985a1-91bc-4e8d-8001-a3663531b570",
|
319 |
-
"metadata": {},
|
320 |
-
"outputs": [],
|
321 |
-
"source": [
|
322 |
-
"# Initialize InSilicoPerturber\n",
|
323 |
-
"isp = InSilicoPerturber(\n",
|
324 |
-
" perturb_type=perturb_type,\n",
|
325 |
-
" genes_to_perturb=\"all\", # Perturb all genes\n",
|
326 |
-
" model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
|
327 |
-
" emb_mode=\"cls\", # Use CLS token embedding\n",
|
328 |
-
" cell_states_to_model=cell_states_to_model,\n",
|
329 |
-
" state_embs_dict=state_embs_dict,\n",
|
330 |
-
" max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
|
331 |
-
" emb_layer=0, \n",
|
332 |
-
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
333 |
-
" nproc=1\n",
|
334 |
-
")"
|
335 |
-
]
|
336 |
-
},
|
337 |
-
{
|
338 |
-
"cell_type": "markdown",
|
339 |
-
"id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
|
340 |
-
"metadata": {},
|
341 |
-
"source": [
|
342 |
-
"## 5. Run In Silico Perturbation\n",
|
343 |
-
"\n",
|
344 |
-
"Run the in silico perturbation on the dataset."
|
345 |
-
]
|
346 |
-
},
|
347 |
-
{
|
348 |
-
"cell_type": "code",
|
349 |
-
"execution_count": null,
|
350 |
-
"id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
|
351 |
-
"metadata": {},
|
352 |
-
"outputs": [],
|
353 |
-
"source": [
|
354 |
-
"# Run perturbation and output intermediate files\n",
|
355 |
-
"isp.perturb_data(\n",
|
356 |
-
" model_directory=model_directory,\n",
|
357 |
-
" input_data_file=input_data_file,\n",
|
358 |
-
" output_directory=output_directory,\n",
|
359 |
-
" output_prefix=output_prefix\n",
|
360 |
-
")"
|
361 |
-
]
|
362 |
-
},
|
363 |
-
{
|
364 |
-
"cell_type": "markdown",
|
365 |
-
"id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
|
366 |
-
"metadata": {},
|
367 |
-
"source": [
|
368 |
-
"## 6. Process Results with InSilicoPerturberStats\n",
|
369 |
-
"\n",
|
370 |
-
"After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
|
371 |
-
]
|
372 |
-
},
|
373 |
-
{
|
374 |
-
"cell_type": "code",
|
375 |
-
"execution_count": null,
|
376 |
-
"id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
|
377 |
-
"metadata": {},
|
378 |
-
"outputs": [],
|
379 |
-
"source": [
|
380 |
-
"# Initialize InSilicoPerturberStats\n",
|
381 |
-
"ispstats = InSilicoPerturberStats(\n",
|
382 |
-
" mode=\"goal_state_shift\",\n",
|
383 |
-
" genes_perturbed=\"all\",\n",
|
384 |
-
" combos=0,\n",
|
385 |
-
" anchor_gene=None,\n",
|
386 |
-
" cell_states_to_model=cell_states_to_model\n",
|
387 |
-
")\n",
|
388 |
-
"\n",
|
389 |
-
"# Process stats and output final .csv\n",
|
390 |
-
"ispstats.get_stats(\n",
|
391 |
-
" input_data_file,\n",
|
392 |
-
" None,\n",
|
393 |
-
" output_directory,\n",
|
394 |
-
" output_prefix\n",
|
395 |
-
")"
|
396 |
-
]
|
397 |
-
}
|
398 |
-
],
|
399 |
-
"metadata": {
|
400 |
-
"kernelspec": {
|
401 |
-
"display_name": "Python 3 (ipykernel)",
|
402 |
-
"language": "python",
|
403 |
-
"name": "python3"
|
404 |
-
},
|
405 |
-
"language_info": {
|
406 |
-
"codemirror_mode": {
|
407 |
-
"name": "ipython",
|
408 |
-
"version": 3
|
409 |
-
},
|
410 |
-
"file_extension": ".py",
|
411 |
-
"mimetype": "text/x-python",
|
412 |
-
"name": "python",
|
413 |
-
"nbconvert_exporter": "python",
|
414 |
-
"pygments_lexer": "ipython3",
|
415 |
-
"version": "3.11.5"
|
416 |
-
}
|
417 |
-
},
|
418 |
-
"nbformat": 4,
|
419 |
-
"nbformat_minor": 5
|
420 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py
CHANGED
@@ -138,9 +138,7 @@ training_args = {
|
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
"save_strategy": "steps",
|
141 |
-
"save_steps": np.floor(
|
142 |
-
num_examples / geneformer_batch_size / 8
|
143 |
-
), # 8 saves per epoch
|
144 |
"logging_steps": 1000,
|
145 |
"output_dir": training_output_dir,
|
146 |
"logging_dir": logging_dir,
|
|
|
138 |
"per_device_train_batch_size": geneformer_batch_size,
|
139 |
"num_train_epochs": epochs,
|
140 |
"save_strategy": "steps",
|
141 |
+
"save_steps": np.floor(num_examples / geneformer_batch_size / 8), # 8 saves per epoch
|
|
|
|
|
142 |
"logging_steps": 1000,
|
143 |
"output_dir": training_output_dir,
|
144 |
"logging_dir": logging_dir,
|
examples/tokenizing_scRNAseq_data.ipynb
CHANGED
@@ -7,39 +7,23 @@
|
|
7 |
"tags": []
|
8 |
},
|
9 |
"source": [
|
10 |
-
"## Tokenizing .loom
|
11 |
]
|
12 |
},
|
13 |
{
|
14 |
"cell_type": "markdown",
|
15 |
-
"id": "
|
16 |
"metadata": {},
|
17 |
"source": [
|
18 |
-
"#### Input data is a directory with .loom
|
19 |
"\n",
|
20 |
-
"####
|
21 |
-
"\n",
|
22 |
-
"#### Genes should be labeled with Ensembl IDs (loom row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute \"n_counts\") to be used for normalization.\n",
|
23 |
"\n",
|
24 |
"#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
|
25 |
"\n",
|
26 |
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
27 |
"\n",
|
28 |
-
"#### If one's data is in other formats besides .loom
|
29 |
-
]
|
30 |
-
},
|
31 |
-
{
|
32 |
-
"cell_type": "markdown",
|
33 |
-
"id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
|
34 |
-
"metadata": {},
|
35 |
-
"source": [
|
36 |
-
"**********************************************************************************************************\n",
|
37 |
-
"#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
|
38 |
-
"#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
|
39 |
-
"\n",
|
40 |
-
"#### ADDITIONALLY:\n",
|
41 |
-
"#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
|
42 |
-
"#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
|
43 |
]
|
44 |
},
|
45 |
{
|
@@ -59,11 +43,8 @@
|
|
59 |
"metadata": {},
|
60 |
"outputs": [],
|
61 |
"source": [
|
62 |
-
"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"
|
63 |
-
"tk.tokenize_data(\"loom_data_directory\", \
|
64 |
-
" \"output_directory\", \n",
|
65 |
-
" \"output_prefix\", \n",
|
66 |
-
" file_format=\"loom\")"
|
67 |
]
|
68 |
}
|
69 |
],
|
@@ -83,7 +64,7 @@
|
|
83 |
"name": "python",
|
84 |
"nbconvert_exporter": "python",
|
85 |
"pygments_lexer": "ipython3",
|
86 |
-
"version": "3.10.
|
87 |
}
|
88 |
},
|
89 |
"nbformat": 4,
|
|
|
7 |
"tags": []
|
8 |
},
|
9 |
"source": [
|
10 |
+
"## Tokenizing .loom single cell RNA-seq data to rank value encoding .dataset format"
|
11 |
]
|
12 |
},
|
13 |
{
|
14 |
"cell_type": "markdown",
|
15 |
+
"id": "350e6252-b783-494b-9767-f087eb868a15",
|
16 |
"metadata": {},
|
17 |
"source": [
|
18 |
+
"#### Input data is a directory with .loom files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. \n",
|
19 |
"\n",
|
20 |
+
"#### Genes should be labeled with Ensembl IDs (row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (column attribute \"n_counts\") to be used for normalization.\n",
|
|
|
|
|
21 |
"\n",
|
22 |
"#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
|
23 |
"\n",
|
24 |
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
25 |
"\n",
|
26 |
+
"#### If one's data is in other formats besides .loom, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom format prior to running the transcriptome tokenizer."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
]
|
28 |
},
|
29 |
{
|
|
|
43 |
"metadata": {},
|
44 |
"outputs": [],
|
45 |
"source": [
|
46 |
+
"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ_major\"}, nproc=4)\n",
|
47 |
+
"tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\")"
|
|
|
|
|
|
|
48 |
]
|
49 |
}
|
50 |
],
|
|
|
64 |
"name": "python",
|
65 |
"nbconvert_exporter": "python",
|
66 |
"pygments_lexer": "ipython3",
|
67 |
+
"version": "3.10.11"
|
68 |
}
|
69 |
},
|
70 |
"nbformat": 4,
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/config.json
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/optimizer.pt
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/rng_state.pth
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/scheduler.pt
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/trainer_state.json
RENAMED
File without changes
|
fine_tuned_models/{gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224 → geneformer-6L-30M_CellClassifier_cardiomyopathies_220224}/training_args.bin
RENAMED
File without changes
|
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"architectures": [
|
3 |
-
"BertForMaskedLM"
|
4 |
-
],
|
5 |
-
"attention_probs_dropout_prob": 0.02,
|
6 |
-
"classifier_dropout": null,
|
7 |
-
"hidden_act": "relu",
|
8 |
-
"hidden_dropout_prob": 0.02,
|
9 |
-
"hidden_size": 512,
|
10 |
-
"initializer_range": 0.02,
|
11 |
-
"intermediate_size": 1024,
|
12 |
-
"layer_norm_eps": 1e-12,
|
13 |
-
"max_position_embeddings": 4096,
|
14 |
-
"model_type": "bert",
|
15 |
-
"num_attention_heads": 8,
|
16 |
-
"num_hidden_layers": 12,
|
17 |
-
"pad_token_id": 0,
|
18 |
-
"position_embedding_type": "absolute",
|
19 |
-
"torch_dtype": "float32",
|
20 |
-
"transformers_version": "4.37.2",
|
21 |
-
"type_vocab_size": 2,
|
22 |
-
"use_cache": true,
|
23 |
-
"vocab_size": 20275
|
24 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
|
3 |
-
size 152363342
|
|
|
|
|
|
|
|
{gf-12L-30M-i2048 → geneformer-12L-30M}/config.json
RENAMED
File without changes
|
{gf-12L-30M-i2048 → geneformer-12L-30M}/pytorch_model.bin
RENAMED
File without changes
|
{gf-12L-30M-i2048 → geneformer-12L-30M}/training_args.bin
RENAMED
File without changes
|
geneformer/__init__.py
CHANGED
@@ -1,34 +1,12 @@
|
|
1 |
-
|
2 |
-
import
|
3 |
-
from
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
|
8 |
-
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
|
9 |
-
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
|
10 |
-
ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
|
11 |
-
|
12 |
-
from . import (
|
13 |
-
collator_for_classification,
|
14 |
-
emb_extractor,
|
15 |
-
in_silico_perturber,
|
16 |
-
in_silico_perturber_stats,
|
17 |
-
pretrainer,
|
18 |
-
tokenizer,
|
19 |
-
)
|
20 |
-
from .collator_for_classification import (
|
21 |
-
DataCollatorForCellClassification,
|
22 |
-
DataCollatorForGeneClassification,
|
23 |
-
)
|
24 |
-
from .emb_extractor import EmbExtractor, get_embs
|
25 |
-
from .in_silico_perturber import InSilicoPerturber
|
26 |
-
from .in_silico_perturber_stats import InSilicoPerturberStats
|
27 |
-
from .pretrainer import GeneformerPretrainer
|
28 |
from .tokenizer import TranscriptomeTokenizer
|
29 |
-
|
30 |
-
from . import
|
31 |
-
from .
|
32 |
-
|
33 |
-
from . import
|
34 |
-
from .
|
|
|
1 |
+
from . import tokenizer
|
2 |
+
from . import pretrainer
|
3 |
+
from . import collator_for_classification
|
4 |
+
from . import in_silico_perturber
|
5 |
+
from . import in_silico_perturber_stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from .tokenizer import TranscriptomeTokenizer
|
7 |
+
from .pretrainer import GeneformerPretrainer
|
8 |
+
from .collator_for_classification import DataCollatorForGeneClassification
|
9 |
+
from .collator_for_classification import DataCollatorForCellClassification
|
10 |
+
from .emb_extractor import EmbExtractor
|
11 |
+
from .in_silico_perturber import InSilicoPerturber
|
12 |
+
from .in_silico_perturber_stats import InSilicoPerturberStats
|
geneformer/classifier.py
DELETED
@@ -1,1563 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Geneformer classifier.
|
3 |
-
|
4 |
-
**Input data:**
|
5 |
-
|
6 |
-
| Cell state classifier:
|
7 |
-
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
|
8 |
-
|
9 |
-
| Gene classifier:
|
10 |
-
| Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
|
11 |
-
|
12 |
-
**Usage:**
|
13 |
-
|
14 |
-
.. code-block :: python
|
15 |
-
|
16 |
-
>>> from geneformer import Classifier
|
17 |
-
>>> cc = Classifier(classifier="cell", # example of cell state classifier
|
18 |
-
... cell_state_dict={"state_key": "disease", "states": "all"},
|
19 |
-
... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
|
20 |
-
... training_args=training_args,
|
21 |
-
... freeze_layers = 2,
|
22 |
-
... num_crossval_splits = 1,
|
23 |
-
... forward_batch_size=200,
|
24 |
-
... nproc=16)
|
25 |
-
>>> cc.prepare_data(input_data_file="path/to/input_data",
|
26 |
-
... output_directory="path/to/output_directory",
|
27 |
-
... output_prefix="output_prefix")
|
28 |
-
>>> all_metrics = cc.validate(model_directory="path/to/model",
|
29 |
-
... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
|
30 |
-
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
31 |
-
... output_directory="path/to/output_directory",
|
32 |
-
... output_prefix="output_prefix",
|
33 |
-
... predict_eval=True)
|
34 |
-
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
|
35 |
-
... output_directory="path/to/output_directory",
|
36 |
-
... output_prefix="output_prefix",
|
37 |
-
... custom_class_order=["healthy","disease1","disease2"])
|
38 |
-
>>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
|
39 |
-
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
40 |
-
... title="disease",
|
41 |
-
... output_directory="path/to/output_directory",
|
42 |
-
... output_prefix="output_prefix",
|
43 |
-
... custom_class_order=["healthy","disease1","disease2"])
|
44 |
-
"""
|
45 |
-
|
46 |
-
import datetime
|
47 |
-
import logging
|
48 |
-
import os
|
49 |
-
import pickle
|
50 |
-
import subprocess
|
51 |
-
from pathlib import Path
|
52 |
-
|
53 |
-
import numpy as np
|
54 |
-
import pandas as pd
|
55 |
-
import seaborn as sns
|
56 |
-
from tqdm.auto import tqdm, trange
|
57 |
-
from transformers import Trainer
|
58 |
-
from transformers.training_args import TrainingArguments
|
59 |
-
|
60 |
-
from . import (
|
61 |
-
TOKEN_DICTIONARY_FILE,
|
62 |
-
DataCollatorForCellClassification,
|
63 |
-
DataCollatorForGeneClassification,
|
64 |
-
)
|
65 |
-
from . import classifier_utils as cu
|
66 |
-
from . import evaluation_utils as eu
|
67 |
-
from . import perturber_utils as pu
|
68 |
-
|
69 |
-
sns.set()
|
70 |
-
|
71 |
-
|
72 |
-
logger = logging.getLogger(__name__)
|
73 |
-
|
74 |
-
|
75 |
-
class Classifier:
|
76 |
-
valid_option_dict = {
|
77 |
-
"classifier": {"cell", "gene"},
|
78 |
-
"quantize": {bool, dict},
|
79 |
-
"cell_state_dict": {None, dict},
|
80 |
-
"gene_class_dict": {None, dict},
|
81 |
-
"filter_data": {None, dict},
|
82 |
-
"rare_threshold": {int, float},
|
83 |
-
"max_ncells": {None, int},
|
84 |
-
"max_ncells_per_class": {None, int},
|
85 |
-
"training_args": {None, dict},
|
86 |
-
"freeze_layers": {int},
|
87 |
-
"num_crossval_splits": {0, 1, 5},
|
88 |
-
"split_sizes": {None, dict},
|
89 |
-
"no_eval": {bool},
|
90 |
-
"stratify_splits_col": {None, str},
|
91 |
-
"forward_batch_size": {int},
|
92 |
-
"token_dictionary_file": {None, str},
|
93 |
-
"nproc": {int},
|
94 |
-
"ngpu": {int},
|
95 |
-
}
|
96 |
-
|
97 |
-
def __init__(
|
98 |
-
self,
|
99 |
-
classifier=None,
|
100 |
-
quantize=False,
|
101 |
-
cell_state_dict=None,
|
102 |
-
gene_class_dict=None,
|
103 |
-
filter_data=None,
|
104 |
-
rare_threshold=0,
|
105 |
-
max_ncells=None,
|
106 |
-
max_ncells_per_class=None,
|
107 |
-
training_args=None,
|
108 |
-
ray_config=None,
|
109 |
-
freeze_layers=0,
|
110 |
-
num_crossval_splits=1,
|
111 |
-
split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
|
112 |
-
stratify_splits_col=None,
|
113 |
-
no_eval=False,
|
114 |
-
forward_batch_size=100,
|
115 |
-
token_dictionary_file=None,
|
116 |
-
nproc=4,
|
117 |
-
ngpu=1,
|
118 |
-
):
|
119 |
-
"""
|
120 |
-
Initialize Geneformer classifier.
|
121 |
-
|
122 |
-
**Parameters:**
|
123 |
-
|
124 |
-
classifier : {"cell", "gene"}
|
125 |
-
| Whether to fine-tune a cell state or gene classifier.
|
126 |
-
quantize : bool, dict
|
127 |
-
| Whether to fine-tune a quantized model.
|
128 |
-
| If True and no config provided, will use default.
|
129 |
-
| Will use custom config if provided.
|
130 |
-
| Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
|
131 |
-
| For example: {"bnb_config": BitsAndBytesConfig(...),
|
132 |
-
| "peft_config": LoraConfig(...)}
|
133 |
-
cell_state_dict : None, dict
|
134 |
-
| Cell states to fine-tune model to distinguish.
|
135 |
-
| Two-item dictionary with keys: state_key and states
|
136 |
-
| state_key: key specifying name of column in .dataset that defines the states to model
|
137 |
-
| states: list of values in the state_key column that specifies the states to model
|
138 |
-
| Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
|
139 |
-
| Of note, if using "all", states will be defined after data is filtered.
|
140 |
-
| Must have at least 2 states to model.
|
141 |
-
| For example: {"state_key": "disease",
|
142 |
-
| "states": ["nf", "hcm", "dcm"]}
|
143 |
-
| or
|
144 |
-
| {"state_key": "disease",
|
145 |
-
| "states": "all"}
|
146 |
-
gene_class_dict : None, dict
|
147 |
-
| Gene classes to fine-tune model to distinguish.
|
148 |
-
| Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
|
149 |
-
| Gene_label_B: list(geneB1, geneB2, ...)}
|
150 |
-
| Gene values should be Ensembl IDs.
|
151 |
-
filter_data : None, dict
|
152 |
-
| Default is to fine-tune with all input data.
|
153 |
-
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
154 |
-
rare_threshold : float
|
155 |
-
| Threshold below which rare cell states should be removed.
|
156 |
-
| For example, setting to 0.05 will remove cell states representing
|
157 |
-
| < 5% of the total cells from the cell state classifier's possible classes.
|
158 |
-
max_ncells : None, int
|
159 |
-
| Maximum number of cells to use for fine-tuning.
|
160 |
-
| Default is to fine-tune with all input data.
|
161 |
-
max_ncells_per_class : None, int
|
162 |
-
| Maximum number of cells per cell class to use for fine-tuning.
|
163 |
-
| Of note, will be applied after max_ncells above.
|
164 |
-
| (Only valid for cell classification.)
|
165 |
-
training_args : None, dict
|
166 |
-
| Training arguments for fine-tuning.
|
167 |
-
| If None, defaults will be inferred for 6 layer Geneformer.
|
168 |
-
| Otherwise, will use the Hugging Face defaults:
|
169 |
-
| https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
|
170 |
-
| Note: Hyperparameter tuning is highly recommended, rather than using defaults.
|
171 |
-
ray_config : None, dict
|
172 |
-
| Training argument ranges for tuning hyperparameters with Ray.
|
173 |
-
freeze_layers : int
|
174 |
-
| Number of layers to freeze from fine-tuning.
|
175 |
-
| 0: no layers will be frozen; 2: first two layers will be frozen; etc.
|
176 |
-
num_crossval_splits : {0, 1, 5}
|
177 |
-
| 0: train on all data without splitting
|
178 |
-
| 1: split data into train and eval sets by designated split_sizes["valid"]
|
179 |
-
| 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
|
180 |
-
split_sizes : None, dict
|
181 |
-
| Dictionary of proportion of data to hold out for train, validation, and test sets
|
182 |
-
| {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
|
183 |
-
stratify_splits_col : None, str
|
184 |
-
| Name of column in .dataset to be used for stratified splitting.
|
185 |
-
| Proportion of each class in this column will be the same in the splits as in the original dataset.
|
186 |
-
no_eval : bool
|
187 |
-
| If True, will skip eval step and use all data for training.
|
188 |
-
| Otherwise, will perform eval during training.
|
189 |
-
forward_batch_size : int
|
190 |
-
| Batch size for forward pass (for evaluation, not training).
|
191 |
-
token_dictionary_file : None, str
|
192 |
-
| Default is to use token dictionary file from Geneformer
|
193 |
-
| Otherwise, will load custom gene token dictionary.
|
194 |
-
nproc : int
|
195 |
-
| Number of CPU processes to use.
|
196 |
-
ngpu : int
|
197 |
-
| Number of GPUs available.
|
198 |
-
|
199 |
-
"""
|
200 |
-
|
201 |
-
self.classifier = classifier
|
202 |
-
if self.classifier == "cell":
|
203 |
-
self.model_type = "CellClassifier"
|
204 |
-
elif self.classifier == "gene":
|
205 |
-
self.model_type = "GeneClassifier"
|
206 |
-
self.quantize = quantize
|
207 |
-
self.cell_state_dict = cell_state_dict
|
208 |
-
self.gene_class_dict = gene_class_dict
|
209 |
-
self.filter_data = filter_data
|
210 |
-
self.rare_threshold = rare_threshold
|
211 |
-
self.max_ncells = max_ncells
|
212 |
-
self.max_ncells_per_class = max_ncells_per_class
|
213 |
-
self.training_args = training_args
|
214 |
-
self.ray_config = ray_config
|
215 |
-
self.freeze_layers = freeze_layers
|
216 |
-
self.num_crossval_splits = num_crossval_splits
|
217 |
-
self.split_sizes = split_sizes
|
218 |
-
self.train_size = self.split_sizes["train"]
|
219 |
-
self.valid_size = self.split_sizes["valid"]
|
220 |
-
self.oos_test_size = self.split_sizes["test"]
|
221 |
-
self.eval_size = self.valid_size / (self.train_size + self.valid_size)
|
222 |
-
self.stratify_splits_col = stratify_splits_col
|
223 |
-
self.no_eval = no_eval
|
224 |
-
self.forward_batch_size = forward_batch_size
|
225 |
-
self.token_dictionary_file = token_dictionary_file
|
226 |
-
self.nproc = nproc
|
227 |
-
self.ngpu = ngpu
|
228 |
-
|
229 |
-
if self.training_args is None:
|
230 |
-
logger.warning(
|
231 |
-
"Hyperparameter tuning is highly recommended for optimal results. "
|
232 |
-
"No training_args provided; using default hyperparameters."
|
233 |
-
)
|
234 |
-
|
235 |
-
self.validate_options()
|
236 |
-
|
237 |
-
if self.filter_data is None:
|
238 |
-
self.filter_data = dict()
|
239 |
-
|
240 |
-
if self.classifier == "cell":
|
241 |
-
if self.cell_state_dict["states"] != "all":
|
242 |
-
self.filter_data[
|
243 |
-
self.cell_state_dict["state_key"]
|
244 |
-
] = self.cell_state_dict["states"]
|
245 |
-
|
246 |
-
# load token dictionary (Ensembl IDs:token)
|
247 |
-
if self.token_dictionary_file is None:
|
248 |
-
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
249 |
-
with open(self.token_dictionary_file, "rb") as f:
|
250 |
-
self.gene_token_dict = pickle.load(f)
|
251 |
-
|
252 |
-
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
253 |
-
|
254 |
-
# filter genes for gene classification for those in token dictionary
|
255 |
-
if self.classifier == "gene":
|
256 |
-
all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
|
257 |
-
missing_genes = [
|
258 |
-
gene
|
259 |
-
for gene in all_gene_class_values
|
260 |
-
if gene not in self.gene_token_dict.keys()
|
261 |
-
]
|
262 |
-
if len(missing_genes) == len(all_gene_class_values):
|
263 |
-
logger.error(
|
264 |
-
"None of the provided genes to classify are in token dictionary."
|
265 |
-
)
|
266 |
-
raise
|
267 |
-
elif len(missing_genes) > 0:
|
268 |
-
logger.warning(
|
269 |
-
f"Genes to classify {missing_genes} are not in token dictionary."
|
270 |
-
)
|
271 |
-
self.gene_class_dict = {
|
272 |
-
k: list(set([self.gene_token_dict.get(gene) for gene in v]))
|
273 |
-
for k, v in self.gene_class_dict.items()
|
274 |
-
}
|
275 |
-
empty_classes = []
|
276 |
-
for k, v in self.gene_class_dict.items():
|
277 |
-
if len(v) == 0:
|
278 |
-
empty_classes += [k]
|
279 |
-
if len(empty_classes) > 0:
|
280 |
-
logger.error(
|
281 |
-
f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
|
282 |
-
)
|
283 |
-
raise
|
284 |
-
|
285 |
-
def validate_options(self):
|
286 |
-
# confirm arguments are within valid options and compatible with each other
|
287 |
-
for attr_name, valid_options in self.valid_option_dict.items():
|
288 |
-
attr_value = self.__dict__[attr_name]
|
289 |
-
if not isinstance(attr_value, (list, dict)):
|
290 |
-
if attr_value in valid_options:
|
291 |
-
continue
|
292 |
-
valid_type = False
|
293 |
-
for option in valid_options:
|
294 |
-
if (option in [int, float, list, dict, bool, str]) and isinstance(
|
295 |
-
attr_value, option
|
296 |
-
):
|
297 |
-
valid_type = True
|
298 |
-
break
|
299 |
-
if valid_type:
|
300 |
-
continue
|
301 |
-
logger.error(
|
302 |
-
f"Invalid option for {attr_name}. "
|
303 |
-
f"Valid options for {attr_name}: {valid_options}"
|
304 |
-
)
|
305 |
-
raise
|
306 |
-
|
307 |
-
if self.filter_data is not None:
|
308 |
-
for key, value in self.filter_data.items():
|
309 |
-
if not isinstance(value, list):
|
310 |
-
self.filter_data[key] = [value]
|
311 |
-
logger.warning(
|
312 |
-
"Values in filter_data dict must be lists. "
|
313 |
-
f"Changing {key} value to list ([{value}])."
|
314 |
-
)
|
315 |
-
|
316 |
-
if self.classifier == "cell":
|
317 |
-
if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
|
318 |
-
logger.error(
|
319 |
-
"Invalid keys for cell_state_dict. "
|
320 |
-
"The cell_state_dict should have only 2 keys: state_key and states"
|
321 |
-
)
|
322 |
-
raise
|
323 |
-
|
324 |
-
if self.cell_state_dict["states"] != "all":
|
325 |
-
if not isinstance(self.cell_state_dict["states"], list):
|
326 |
-
logger.error(
|
327 |
-
"States in cell_state_dict should be list of states to model."
|
328 |
-
)
|
329 |
-
raise
|
330 |
-
if len(self.cell_state_dict["states"]) < 2:
|
331 |
-
logger.error(
|
332 |
-
"States in cell_state_dict should contain at least 2 states to classify."
|
333 |
-
)
|
334 |
-
raise
|
335 |
-
|
336 |
-
if self.classifier == "gene":
|
337 |
-
if len(self.gene_class_dict.keys()) < 2:
|
338 |
-
logger.error(
|
339 |
-
"Gene_class_dict should contain at least 2 gene classes to classify."
|
340 |
-
)
|
341 |
-
raise
|
342 |
-
if sum(self.split_sizes.values()) != 1:
|
343 |
-
logger.error("Train, validation, and test proportions should sum to 1.")
|
344 |
-
raise
|
345 |
-
|
346 |
-
def prepare_data(
|
347 |
-
self,
|
348 |
-
input_data_file,
|
349 |
-
output_directory,
|
350 |
-
output_prefix,
|
351 |
-
split_id_dict=None,
|
352 |
-
test_size=None,
|
353 |
-
attr_to_split=None,
|
354 |
-
attr_to_balance=None,
|
355 |
-
max_trials=100,
|
356 |
-
pval_threshold=0.1,
|
357 |
-
):
|
358 |
-
"""
|
359 |
-
Prepare data for cell state or gene classification.
|
360 |
-
|
361 |
-
**Parameters**
|
362 |
-
|
363 |
-
input_data_file : Path
|
364 |
-
| Path to directory containing .dataset input
|
365 |
-
output_directory : Path
|
366 |
-
| Path to directory where prepared data will be saved
|
367 |
-
output_prefix : str
|
368 |
-
| Prefix for output file
|
369 |
-
split_id_dict : None, dict
|
370 |
-
| Dictionary of IDs for train and test splits
|
371 |
-
| Three-item dictionary with keys: attr_key, train, test
|
372 |
-
| attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
|
373 |
-
| train: list of IDs in the attr_key column to include in the train split
|
374 |
-
| test: list of IDs in the attr_key column to include in the test split
|
375 |
-
| For example: {"attr_key": "individual",
|
376 |
-
| "train": ["patient1", "patient2", "patient3", "patient4"],
|
377 |
-
| "test": ["patient5", "patient6"]}
|
378 |
-
test_size : None, float
|
379 |
-
| Proportion of data to be saved separately and held out for test set
|
380 |
-
| (e.g. 0.2 if intending hold out 20%)
|
381 |
-
| If None, will inherit from split_sizes["test"] from Classifier
|
382 |
-
| The training set will be further split to train / validation in self.validate
|
383 |
-
| Note: only available for CellClassifiers
|
384 |
-
attr_to_split : None, str
|
385 |
-
| Key for attribute on which to split data while balancing potential confounders
|
386 |
-
| e.g. "patient_id" for splitting by patient while balancing other characteristics
|
387 |
-
| Note: only available for CellClassifiers
|
388 |
-
attr_to_balance : None, list
|
389 |
-
| List of attribute keys on which to balance data while splitting on attr_to_split
|
390 |
-
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
391 |
-
| Note: only available for CellClassifiers
|
392 |
-
max_trials : None, int
|
393 |
-
| Maximum number of trials of random splitting to try to achieve balanced other attributes
|
394 |
-
| If no split is found without significant (p<0.05) differences in other attributes, will select best
|
395 |
-
| Note: only available for CellClassifiers
|
396 |
-
pval_threshold : None, float
|
397 |
-
| P-value threshold to use for attribute balancing across splits
|
398 |
-
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
399 |
-
"""
|
400 |
-
|
401 |
-
if test_size is None:
|
402 |
-
test_size = self.oos_test_size
|
403 |
-
|
404 |
-
# prepare data and labels for classification
|
405 |
-
data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
|
406 |
-
|
407 |
-
if self.classifier == "cell":
|
408 |
-
if "label" in data.features:
|
409 |
-
logger.error(
|
410 |
-
"Column name 'label' must be reserved for class IDs. Please rename column."
|
411 |
-
)
|
412 |
-
raise
|
413 |
-
elif self.classifier == "gene":
|
414 |
-
if "labels" in data.features:
|
415 |
-
logger.error(
|
416 |
-
"Column name 'labels' must be reserved for class IDs. Please rename column."
|
417 |
-
)
|
418 |
-
raise
|
419 |
-
|
420 |
-
if (attr_to_split is not None) and (attr_to_balance is None):
|
421 |
-
logger.error(
|
422 |
-
"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
|
423 |
-
)
|
424 |
-
raise
|
425 |
-
|
426 |
-
if not isinstance(attr_to_balance, list):
|
427 |
-
attr_to_balance = [attr_to_balance]
|
428 |
-
|
429 |
-
if self.classifier == "cell":
|
430 |
-
# remove cell states representing < rare_threshold of cells
|
431 |
-
data = cu.remove_rare(
|
432 |
-
data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
|
433 |
-
)
|
434 |
-
# downsample max cells and max per class
|
435 |
-
data = cu.downsample_and_shuffle(
|
436 |
-
data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
|
437 |
-
)
|
438 |
-
# rename cell state column to "label"
|
439 |
-
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
440 |
-
|
441 |
-
# convert classes to numerical labels and save as id_class_dict
|
442 |
-
# of note, will label all genes in gene_class_dict
|
443 |
-
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
444 |
-
# at the time of training with Classifier.validate
|
445 |
-
data, id_class_dict = cu.label_classes(
|
446 |
-
self.classifier, data, self.gene_class_dict, self.nproc
|
447 |
-
)
|
448 |
-
|
449 |
-
# save id_class_dict for future reference
|
450 |
-
id_class_output_path = (
|
451 |
-
Path(output_directory) / f"{output_prefix}_id_class_dict"
|
452 |
-
).with_suffix(".pkl")
|
453 |
-
with open(id_class_output_path, "wb") as f:
|
454 |
-
pickle.dump(id_class_dict, f)
|
455 |
-
|
456 |
-
if split_id_dict is not None:
|
457 |
-
data_dict = dict()
|
458 |
-
data_dict["train"] = pu.filter_by_dict(
|
459 |
-
data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
|
460 |
-
)
|
461 |
-
data_dict["test"] = pu.filter_by_dict(
|
462 |
-
data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
|
463 |
-
)
|
464 |
-
train_data_output_path = (
|
465 |
-
Path(output_directory) / f"{output_prefix}_labeled_train"
|
466 |
-
).with_suffix(".dataset")
|
467 |
-
test_data_output_path = (
|
468 |
-
Path(output_directory) / f"{output_prefix}_labeled_test"
|
469 |
-
).with_suffix(".dataset")
|
470 |
-
data_dict["train"].save_to_disk(str(train_data_output_path))
|
471 |
-
data_dict["test"].save_to_disk(str(test_data_output_path))
|
472 |
-
elif (test_size is not None) and (self.classifier == "cell"):
|
473 |
-
if 1 > test_size > 0:
|
474 |
-
if attr_to_split is None:
|
475 |
-
data_dict = data.train_test_split(
|
476 |
-
test_size=test_size,
|
477 |
-
stratify_by_column=self.stratify_splits_col,
|
478 |
-
seed=42,
|
479 |
-
)
|
480 |
-
train_data_output_path = (
|
481 |
-
Path(output_directory) / f"{output_prefix}_labeled_train"
|
482 |
-
).with_suffix(".dataset")
|
483 |
-
test_data_output_path = (
|
484 |
-
Path(output_directory) / f"{output_prefix}_labeled_test"
|
485 |
-
).with_suffix(".dataset")
|
486 |
-
data_dict["train"].save_to_disk(str(train_data_output_path))
|
487 |
-
data_dict["test"].save_to_disk(str(test_data_output_path))
|
488 |
-
else:
|
489 |
-
data_dict, balance_df = cu.balance_attr_splits(
|
490 |
-
data,
|
491 |
-
attr_to_split,
|
492 |
-
attr_to_balance,
|
493 |
-
test_size,
|
494 |
-
max_trials,
|
495 |
-
pval_threshold,
|
496 |
-
self.cell_state_dict["state_key"],
|
497 |
-
self.nproc,
|
498 |
-
)
|
499 |
-
balance_df.to_csv(
|
500 |
-
f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
|
501 |
-
)
|
502 |
-
train_data_output_path = (
|
503 |
-
Path(output_directory) / f"{output_prefix}_labeled_train"
|
504 |
-
).with_suffix(".dataset")
|
505 |
-
test_data_output_path = (
|
506 |
-
Path(output_directory) / f"{output_prefix}_labeled_test"
|
507 |
-
).with_suffix(".dataset")
|
508 |
-
data_dict["train"].save_to_disk(str(train_data_output_path))
|
509 |
-
data_dict["test"].save_to_disk(str(test_data_output_path))
|
510 |
-
else:
|
511 |
-
data_output_path = (
|
512 |
-
Path(output_directory) / f"{output_prefix}_labeled"
|
513 |
-
).with_suffix(".dataset")
|
514 |
-
data.save_to_disk(str(data_output_path))
|
515 |
-
print(data_output_path)
|
516 |
-
else:
|
517 |
-
data_output_path = (
|
518 |
-
Path(output_directory) / f"{output_prefix}_labeled"
|
519 |
-
).with_suffix(".dataset")
|
520 |
-
data.save_to_disk(str(data_output_path))
|
521 |
-
|
522 |
-
def train_all_data(
|
523 |
-
self,
|
524 |
-
model_directory,
|
525 |
-
prepared_input_data_file,
|
526 |
-
id_class_dict_file,
|
527 |
-
output_directory,
|
528 |
-
output_prefix,
|
529 |
-
save_eval_output=True,
|
530 |
-
gene_balance=False,
|
531 |
-
):
|
532 |
-
"""
|
533 |
-
Train cell state or gene classifier using all data.
|
534 |
-
|
535 |
-
**Parameters**
|
536 |
-
|
537 |
-
model_directory : Path
|
538 |
-
| Path to directory containing model
|
539 |
-
prepared_input_data_file : Path
|
540 |
-
| Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
|
541 |
-
id_class_dict_file : Path
|
542 |
-
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
543 |
-
| (dictionary of format: numerical IDs: class_labels)
|
544 |
-
output_directory : Path
|
545 |
-
| Path to directory where model and eval data will be saved
|
546 |
-
output_prefix : str
|
547 |
-
| Prefix for output files
|
548 |
-
save_eval_output : bool
|
549 |
-
| Whether to save cross-fold eval output
|
550 |
-
| Saves as pickle file of dictionary of eval metrics
|
551 |
-
gene_balance : None, bool
|
552 |
-
| Whether to automatically balance genes in training set.
|
553 |
-
| Only available for binary gene classifications.
|
554 |
-
|
555 |
-
**Output**
|
556 |
-
|
557 |
-
Returns trainer after fine-tuning with all data.
|
558 |
-
|
559 |
-
"""
|
560 |
-
|
561 |
-
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
562 |
-
logger.error(
|
563 |
-
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
564 |
-
)
|
565 |
-
raise
|
566 |
-
|
567 |
-
##### Load data and prepare output directory #####
|
568 |
-
# load numerical id to class dictionary (id:class)
|
569 |
-
with open(id_class_dict_file, "rb") as f:
|
570 |
-
id_class_dict = pickle.load(f)
|
571 |
-
class_id_dict = {v: k for k, v in id_class_dict.items()}
|
572 |
-
|
573 |
-
# load previously filtered and prepared data
|
574 |
-
data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
|
575 |
-
data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
|
576 |
-
|
577 |
-
# define output directory path
|
578 |
-
current_date = datetime.datetime.now()
|
579 |
-
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
580 |
-
if output_directory[-1:] != "/": # add slash for dir if not present
|
581 |
-
output_directory = output_directory + "/"
|
582 |
-
output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
|
583 |
-
subprocess.call(f"mkdir {output_dir}", shell=True)
|
584 |
-
|
585 |
-
# get number of classes for classifier
|
586 |
-
num_classes = cu.get_num_classes(id_class_dict)
|
587 |
-
|
588 |
-
if self.classifier == "gene":
|
589 |
-
targets = pu.flatten_list(self.gene_class_dict.values())
|
590 |
-
labels = pu.flatten_list(
|
591 |
-
[
|
592 |
-
[class_id_dict[label]] * len(targets)
|
593 |
-
for label, targets in self.gene_class_dict.items()
|
594 |
-
]
|
595 |
-
)
|
596 |
-
assert len(targets) == len(labels)
|
597 |
-
data = cu.prep_gene_classifier_all_data(
|
598 |
-
data, targets, labels, self.max_ncells, self.nproc, gene_balance
|
599 |
-
)
|
600 |
-
|
601 |
-
trainer = self.train_classifier(
|
602 |
-
model_directory, num_classes, data, None, output_dir
|
603 |
-
)
|
604 |
-
|
605 |
-
return trainer
|
606 |
-
|
607 |
-
def validate(
|
608 |
-
self,
|
609 |
-
model_directory,
|
610 |
-
prepared_input_data_file,
|
611 |
-
id_class_dict_file,
|
612 |
-
output_directory,
|
613 |
-
output_prefix,
|
614 |
-
split_id_dict=None,
|
615 |
-
attr_to_split=None,
|
616 |
-
attr_to_balance=None,
|
617 |
-
gene_balance=False,
|
618 |
-
max_trials=100,
|
619 |
-
pval_threshold=0.1,
|
620 |
-
save_eval_output=True,
|
621 |
-
predict_eval=True,
|
622 |
-
predict_trainer=False,
|
623 |
-
n_hyperopt_trials=0,
|
624 |
-
save_gene_split_datasets=True,
|
625 |
-
debug_gene_split_datasets=False,
|
626 |
-
):
|
627 |
-
"""
|
628 |
-
(Cross-)validate cell state or gene classifier.
|
629 |
-
|
630 |
-
**Parameters**
|
631 |
-
|
632 |
-
model_directory : Path
|
633 |
-
| Path to directory containing model
|
634 |
-
prepared_input_data_file : Path
|
635 |
-
| Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
|
636 |
-
id_class_dict_file : Path
|
637 |
-
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
638 |
-
| (dictionary of format: numerical IDs: class_labels)
|
639 |
-
output_directory : Path
|
640 |
-
| Path to directory where model and eval data will be saved
|
641 |
-
output_prefix : str
|
642 |
-
| Prefix for output files
|
643 |
-
split_id_dict : None, dict
|
644 |
-
| Dictionary of IDs for train and eval splits
|
645 |
-
| Three-item dictionary with keys: attr_key, train, eval
|
646 |
-
| attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
|
647 |
-
| train: list of IDs in the attr_key column to include in the train split
|
648 |
-
| eval: list of IDs in the attr_key column to include in the eval split
|
649 |
-
| For example: {"attr_key": "individual",
|
650 |
-
| "train": ["patient1", "patient2", "patient3", "patient4"],
|
651 |
-
| "eval": ["patient5", "patient6"]}
|
652 |
-
| Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
|
653 |
-
attr_to_split : None, str
|
654 |
-
| Key for attribute on which to split data while balancing potential confounders
|
655 |
-
| e.g. "patient_id" for splitting by patient while balancing other characteristics
|
656 |
-
| Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
|
657 |
-
attr_to_balance : None, list
|
658 |
-
| List of attribute keys on which to balance data while splitting on attr_to_split
|
659 |
-
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
660 |
-
gene_balance : None, bool
|
661 |
-
| Whether to automatically balance genes in training set.
|
662 |
-
| Only available for binary gene classifications.
|
663 |
-
max_trials : None, int
|
664 |
-
| Maximum number of trials of random splitting to try to achieve balanced other attribute
|
665 |
-
| If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
|
666 |
-
pval_threshold : None, float
|
667 |
-
| P-value threshold to use for attribute balancing across splits
|
668 |
-
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
669 |
-
save_eval_output : bool
|
670 |
-
| Whether to save cross-fold eval output
|
671 |
-
| Saves as pickle file of dictionary of eval metrics
|
672 |
-
predict_eval : bool
|
673 |
-
| Whether or not to save eval predictions
|
674 |
-
| Saves as a pickle file of self.evaluate predictions
|
675 |
-
predict_trainer : bool
|
676 |
-
| Whether or not to save eval predictions from trainer
|
677 |
-
| Saves as a pickle file of trainer predictions
|
678 |
-
n_hyperopt_trials : int
|
679 |
-
| Number of trials to run for hyperparameter optimization
|
680 |
-
| If 0, will not optimize hyperparameters
|
681 |
-
save_gene_split_datasets : bool
|
682 |
-
| Whether or not to save train, valid, and test gene-labeled datasets
|
683 |
-
"""
|
684 |
-
if self.num_crossval_splits == 0:
|
685 |
-
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
686 |
-
raise
|
687 |
-
|
688 |
-
if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
|
689 |
-
logger.error(
|
690 |
-
"Automatically balancing gene sets for training is only available for binary gene classifications."
|
691 |
-
)
|
692 |
-
raise
|
693 |
-
|
694 |
-
# ensure number of genes in each class is > 5 if validating model
|
695 |
-
if self.classifier == "gene":
|
696 |
-
insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
|
697 |
-
if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
|
698 |
-
logger.error(
|
699 |
-
f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
|
700 |
-
)
|
701 |
-
raise
|
702 |
-
|
703 |
-
##### Load data and prepare output directory #####
|
704 |
-
# load numerical id to class dictionary (id:class)
|
705 |
-
with open(id_class_dict_file, "rb") as f:
|
706 |
-
id_class_dict = pickle.load(f)
|
707 |
-
class_id_dict = {v: k for k, v in id_class_dict.items()}
|
708 |
-
|
709 |
-
# load previously filtered and prepared data
|
710 |
-
data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
|
711 |
-
data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
|
712 |
-
|
713 |
-
# define output directory path
|
714 |
-
current_date = datetime.datetime.now()
|
715 |
-
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
716 |
-
if output_directory[-1:] != "/": # add slash for dir if not present
|
717 |
-
output_directory = output_directory + "/"
|
718 |
-
output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
|
719 |
-
subprocess.call(f"mkdir {output_dir}", shell=True)
|
720 |
-
|
721 |
-
# get number of classes for classifier
|
722 |
-
num_classes = cu.get_num_classes(id_class_dict)
|
723 |
-
|
724 |
-
##### (Cross-)validate the model #####
|
725 |
-
results = []
|
726 |
-
all_conf_mat = np.zeros((num_classes, num_classes))
|
727 |
-
iteration_num = 1
|
728 |
-
if self.classifier == "cell":
|
729 |
-
for i in trange(self.num_crossval_splits):
|
730 |
-
print(
|
731 |
-
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
732 |
-
)
|
733 |
-
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
734 |
-
if self.num_crossval_splits == 1:
|
735 |
-
# single 1-eval_size:eval_size split
|
736 |
-
if split_id_dict is not None:
|
737 |
-
data_dict = dict()
|
738 |
-
data_dict["train"] = pu.filter_by_dict(
|
739 |
-
data,
|
740 |
-
{split_id_dict["attr_key"]: split_id_dict["train"]},
|
741 |
-
self.nproc,
|
742 |
-
)
|
743 |
-
data_dict["test"] = pu.filter_by_dict(
|
744 |
-
data,
|
745 |
-
{split_id_dict["attr_key"]: split_id_dict["eval"]},
|
746 |
-
self.nproc,
|
747 |
-
)
|
748 |
-
elif attr_to_split is not None:
|
749 |
-
data_dict, balance_df = cu.balance_attr_splits(
|
750 |
-
data,
|
751 |
-
attr_to_split,
|
752 |
-
attr_to_balance,
|
753 |
-
self.eval_size,
|
754 |
-
max_trials,
|
755 |
-
pval_threshold,
|
756 |
-
self.cell_state_dict["state_key"],
|
757 |
-
self.nproc,
|
758 |
-
)
|
759 |
-
|
760 |
-
balance_df.to_csv(
|
761 |
-
f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
|
762 |
-
)
|
763 |
-
else:
|
764 |
-
data_dict = data.train_test_split(
|
765 |
-
test_size=self.eval_size,
|
766 |
-
stratify_by_column=self.stratify_splits_col,
|
767 |
-
seed=42,
|
768 |
-
)
|
769 |
-
train_data = data_dict["train"]
|
770 |
-
eval_data = data_dict["test"]
|
771 |
-
else:
|
772 |
-
# 5-fold cross-validate
|
773 |
-
num_cells = len(data)
|
774 |
-
fifth_cells = int(np.floor(num_cells * 0.2))
|
775 |
-
num_eval = min((self.eval_size * num_cells), fifth_cells)
|
776 |
-
start = i * fifth_cells
|
777 |
-
end = start + num_eval
|
778 |
-
eval_indices = [j for j in range(start, end)]
|
779 |
-
train_indices = [
|
780 |
-
j for j in range(num_cells) if j not in eval_indices
|
781 |
-
]
|
782 |
-
eval_data = data.select(eval_indices)
|
783 |
-
train_data = data.select(train_indices)
|
784 |
-
if n_hyperopt_trials == 0:
|
785 |
-
trainer = self.train_classifier(
|
786 |
-
model_directory,
|
787 |
-
num_classes,
|
788 |
-
train_data,
|
789 |
-
eval_data,
|
790 |
-
ksplit_output_dir,
|
791 |
-
predict_trainer,
|
792 |
-
)
|
793 |
-
else:
|
794 |
-
trainer = self.hyperopt_classifier(
|
795 |
-
model_directory,
|
796 |
-
num_classes,
|
797 |
-
train_data,
|
798 |
-
eval_data,
|
799 |
-
ksplit_output_dir,
|
800 |
-
n_trials=n_hyperopt_trials,
|
801 |
-
)
|
802 |
-
if iteration_num == self.num_crossval_splits:
|
803 |
-
return
|
804 |
-
else:
|
805 |
-
iteration_num = iteration_num + 1
|
806 |
-
continue
|
807 |
-
|
808 |
-
result = self.evaluate_model(
|
809 |
-
trainer.model,
|
810 |
-
num_classes,
|
811 |
-
id_class_dict,
|
812 |
-
eval_data,
|
813 |
-
predict_eval,
|
814 |
-
ksplit_output_dir,
|
815 |
-
output_prefix,
|
816 |
-
)
|
817 |
-
results += [result]
|
818 |
-
all_conf_mat = all_conf_mat + result["conf_mat"]
|
819 |
-
iteration_num = iteration_num + 1
|
820 |
-
|
821 |
-
elif self.classifier == "gene":
|
822 |
-
# set up (cross-)validation splits
|
823 |
-
targets = pu.flatten_list(self.gene_class_dict.values())
|
824 |
-
labels = pu.flatten_list(
|
825 |
-
[
|
826 |
-
[class_id_dict[label]] * len(targets)
|
827 |
-
for label, targets in self.gene_class_dict.items()
|
828 |
-
]
|
829 |
-
)
|
830 |
-
assert len(targets) == len(labels)
|
831 |
-
n_splits = int(1 / (1 - self.train_size))
|
832 |
-
skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
|
833 |
-
# (Cross-)validate
|
834 |
-
test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
|
835 |
-
for train_index, eval_index, test_index in tqdm(
|
836 |
-
skf.split(targets, labels, test_ratio)
|
837 |
-
):
|
838 |
-
print(
|
839 |
-
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
840 |
-
)
|
841 |
-
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
842 |
-
# filter data for examples containing classes for this split
|
843 |
-
# subsample to max_ncells and relabel data in column "labels"
|
844 |
-
train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
|
845 |
-
data,
|
846 |
-
targets,
|
847 |
-
labels,
|
848 |
-
train_index,
|
849 |
-
eval_index,
|
850 |
-
self.max_ncells,
|
851 |
-
iteration_num,
|
852 |
-
self.nproc,
|
853 |
-
gene_balance,
|
854 |
-
)
|
855 |
-
|
856 |
-
if save_gene_split_datasets is True:
|
857 |
-
for split_name in ["train", "valid"]:
|
858 |
-
labeled_dataset_output_path = (
|
859 |
-
Path(output_dir)
|
860 |
-
/ f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
|
861 |
-
).with_suffix(".dataset")
|
862 |
-
if split_name == "train":
|
863 |
-
train_data.save_to_disk(str(labeled_dataset_output_path))
|
864 |
-
elif split_name == "valid":
|
865 |
-
eval_data.save_to_disk(str(labeled_dataset_output_path))
|
866 |
-
|
867 |
-
if self.oos_test_size > 0:
|
868 |
-
test_data = cu.prep_gene_classifier_split(
|
869 |
-
data,
|
870 |
-
targets,
|
871 |
-
labels,
|
872 |
-
test_index,
|
873 |
-
"test",
|
874 |
-
self.max_ncells,
|
875 |
-
iteration_num,
|
876 |
-
self.nproc,
|
877 |
-
)
|
878 |
-
if save_gene_split_datasets is True:
|
879 |
-
test_labeled_dataset_output_path = (
|
880 |
-
Path(output_dir)
|
881 |
-
/ f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
|
882 |
-
).with_suffix(".dataset")
|
883 |
-
test_data.save_to_disk(str(test_labeled_dataset_output_path))
|
884 |
-
if debug_gene_split_datasets is True:
|
885 |
-
logger.error(
|
886 |
-
"Exiting after saving gene split datasets given debug_gene_split_datasets = True."
|
887 |
-
)
|
888 |
-
raise
|
889 |
-
if n_hyperopt_trials == 0:
|
890 |
-
trainer = self.train_classifier(
|
891 |
-
model_directory,
|
892 |
-
num_classes,
|
893 |
-
train_data,
|
894 |
-
eval_data,
|
895 |
-
ksplit_output_dir,
|
896 |
-
predict_trainer,
|
897 |
-
)
|
898 |
-
result = self.evaluate_model(
|
899 |
-
trainer.model,
|
900 |
-
num_classes,
|
901 |
-
id_class_dict,
|
902 |
-
eval_data,
|
903 |
-
predict_eval,
|
904 |
-
ksplit_output_dir,
|
905 |
-
output_prefix,
|
906 |
-
)
|
907 |
-
else:
|
908 |
-
trainer = self.hyperopt_classifier(
|
909 |
-
model_directory,
|
910 |
-
num_classes,
|
911 |
-
train_data,
|
912 |
-
eval_data,
|
913 |
-
ksplit_output_dir,
|
914 |
-
n_trials=n_hyperopt_trials,
|
915 |
-
)
|
916 |
-
|
917 |
-
model = cu.load_best_model(
|
918 |
-
ksplit_output_dir, self.model_type, num_classes
|
919 |
-
)
|
920 |
-
|
921 |
-
if self.oos_test_size > 0:
|
922 |
-
result = self.evaluate_model(
|
923 |
-
model,
|
924 |
-
num_classes,
|
925 |
-
id_class_dict,
|
926 |
-
test_data,
|
927 |
-
predict_eval,
|
928 |
-
ksplit_output_dir,
|
929 |
-
output_prefix,
|
930 |
-
)
|
931 |
-
else:
|
932 |
-
if iteration_num == self.num_crossval_splits:
|
933 |
-
return
|
934 |
-
else:
|
935 |
-
iteration_num = iteration_num + 1
|
936 |
-
continue
|
937 |
-
results += [result]
|
938 |
-
all_conf_mat = all_conf_mat + result["conf_mat"]
|
939 |
-
# break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
|
940 |
-
if iteration_num == self.num_crossval_splits:
|
941 |
-
break
|
942 |
-
iteration_num = iteration_num + 1
|
943 |
-
|
944 |
-
all_conf_mat_df = pd.DataFrame(
|
945 |
-
all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
|
946 |
-
)
|
947 |
-
all_metrics = {
|
948 |
-
"conf_matrix": all_conf_mat_df,
|
949 |
-
"macro_f1": [result["macro_f1"] for result in results],
|
950 |
-
"acc": [result["acc"] for result in results],
|
951 |
-
}
|
952 |
-
all_roc_metrics = None # roc metrics not reported for multiclass
|
953 |
-
if num_classes == 2:
|
954 |
-
mean_fpr = np.linspace(0, 1, 100)
|
955 |
-
all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
|
956 |
-
all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
|
957 |
-
all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
|
958 |
-
mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
|
959 |
-
all_tpr, all_roc_auc, all_tpr_wt
|
960 |
-
)
|
961 |
-
all_roc_metrics = {
|
962 |
-
"mean_tpr": mean_tpr,
|
963 |
-
"mean_fpr": mean_fpr,
|
964 |
-
"all_roc_auc": all_roc_auc,
|
965 |
-
"roc_auc": roc_auc,
|
966 |
-
"roc_auc_sd": roc_auc_sd,
|
967 |
-
}
|
968 |
-
all_metrics["all_roc_metrics"] = all_roc_metrics
|
969 |
-
if save_eval_output is True:
|
970 |
-
eval_metrics_output_path = (
|
971 |
-
Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
|
972 |
-
).with_suffix(".pkl")
|
973 |
-
with open(eval_metrics_output_path, "wb") as f:
|
974 |
-
pickle.dump(all_metrics, f)
|
975 |
-
|
976 |
-
return all_metrics
|
977 |
-
|
978 |
-
def hyperopt_classifier(
|
979 |
-
self,
|
980 |
-
model_directory,
|
981 |
-
num_classes,
|
982 |
-
train_data,
|
983 |
-
eval_data,
|
984 |
-
output_directory,
|
985 |
-
n_trials=100,
|
986 |
-
):
|
987 |
-
"""
|
988 |
-
Fine-tune model for cell state or gene classification.
|
989 |
-
|
990 |
-
**Parameters**
|
991 |
-
|
992 |
-
model_directory : Path
|
993 |
-
| Path to directory containing model
|
994 |
-
num_classes : int
|
995 |
-
| Number of classes for classifier
|
996 |
-
train_data : Dataset
|
997 |
-
| Loaded training .dataset input
|
998 |
-
| For cell classifier, labels in column "label".
|
999 |
-
| For gene classifier, labels in column "labels".
|
1000 |
-
eval_data : None, Dataset
|
1001 |
-
| (Optional) Loaded evaluation .dataset input
|
1002 |
-
| For cell classifier, labels in column "label".
|
1003 |
-
| For gene classifier, labels in column "labels".
|
1004 |
-
output_directory : Path
|
1005 |
-
| Path to directory where fine-tuned model will be saved
|
1006 |
-
n_trials : int
|
1007 |
-
| Number of trials to run for hyperparameter optimization
|
1008 |
-
"""
|
1009 |
-
|
1010 |
-
# initiate runtime environment for raytune
|
1011 |
-
import ray
|
1012 |
-
from ray import tune
|
1013 |
-
from ray.tune.search.hyperopt import HyperOptSearch
|
1014 |
-
|
1015 |
-
ray.shutdown() # engage new ray session
|
1016 |
-
ray.init()
|
1017 |
-
|
1018 |
-
##### Validate and prepare data #####
|
1019 |
-
train_data, eval_data = cu.validate_and_clean_cols(
|
1020 |
-
train_data, eval_data, self.classifier
|
1021 |
-
)
|
1022 |
-
|
1023 |
-
if (self.no_eval is True) and (eval_data is not None):
|
1024 |
-
logger.warning(
|
1025 |
-
"no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
|
1026 |
-
)
|
1027 |
-
|
1028 |
-
# ensure not overwriting previously saved model
|
1029 |
-
saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
|
1030 |
-
if os.path.isfile(saved_model_test) is True:
|
1031 |
-
logger.error("Model already saved to this designated output directory.")
|
1032 |
-
raise
|
1033 |
-
# make output directory
|
1034 |
-
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1035 |
-
|
1036 |
-
##### Load model and training args #####
|
1037 |
-
model = pu.load_model(
|
1038 |
-
self.model_type,
|
1039 |
-
num_classes,
|
1040 |
-
model_directory,
|
1041 |
-
"train",
|
1042 |
-
quantize=self.quantize,
|
1043 |
-
)
|
1044 |
-
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1045 |
-
model, self.classifier, train_data, output_directory
|
1046 |
-
)
|
1047 |
-
del model
|
1048 |
-
|
1049 |
-
if self.training_args is not None:
|
1050 |
-
def_training_args.update(self.training_args)
|
1051 |
-
logging_steps = round(
|
1052 |
-
len(train_data) / def_training_args["per_device_train_batch_size"] / 10
|
1053 |
-
)
|
1054 |
-
def_training_args["logging_steps"] = logging_steps
|
1055 |
-
def_training_args["output_dir"] = output_directory
|
1056 |
-
if eval_data is None:
|
1057 |
-
def_training_args["evaluation_strategy"] = "no"
|
1058 |
-
def_training_args["load_best_model_at_end"] = False
|
1059 |
-
def_training_args.update(
|
1060 |
-
{"save_strategy": "epoch", "save_total_limit": 1}
|
1061 |
-
) # only save last model for each run
|
1062 |
-
training_args_init = TrainingArguments(**def_training_args)
|
1063 |
-
|
1064 |
-
##### Fine-tune the model #####
|
1065 |
-
# define the data collator
|
1066 |
-
if self.classifier == "cell":
|
1067 |
-
data_collator = DataCollatorForCellClassification(
|
1068 |
-
token_dictionary=self.gene_token_dict
|
1069 |
-
)
|
1070 |
-
elif self.classifier == "gene":
|
1071 |
-
data_collator = DataCollatorForGeneClassification(
|
1072 |
-
token_dictionary=self.gene_token_dict
|
1073 |
-
)
|
1074 |
-
|
1075 |
-
# define function to initiate model
|
1076 |
-
def model_init():
|
1077 |
-
model = pu.load_model(
|
1078 |
-
self.model_type,
|
1079 |
-
num_classes,
|
1080 |
-
model_directory,
|
1081 |
-
"train",
|
1082 |
-
quantize=self.quantize,
|
1083 |
-
)
|
1084 |
-
|
1085 |
-
if self.freeze_layers is not None:
|
1086 |
-
def_freeze_layers = self.freeze_layers
|
1087 |
-
|
1088 |
-
if def_freeze_layers > 0:
|
1089 |
-
modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
|
1090 |
-
for module in modules_to_freeze:
|
1091 |
-
for param in module.parameters():
|
1092 |
-
param.requires_grad = False
|
1093 |
-
|
1094 |
-
if self.quantize is False:
|
1095 |
-
model = model.to("cuda:0")
|
1096 |
-
return model
|
1097 |
-
|
1098 |
-
# create the trainer
|
1099 |
-
trainer = Trainer(
|
1100 |
-
model_init=model_init,
|
1101 |
-
args=training_args_init,
|
1102 |
-
data_collator=data_collator,
|
1103 |
-
train_dataset=train_data,
|
1104 |
-
eval_dataset=eval_data,
|
1105 |
-
compute_metrics=cu.compute_metrics,
|
1106 |
-
)
|
1107 |
-
|
1108 |
-
# specify raytune hyperparameter search space
|
1109 |
-
if self.ray_config is None:
|
1110 |
-
logger.warning(
|
1111 |
-
"No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
|
1112 |
-
)
|
1113 |
-
def_ray_config = {
|
1114 |
-
"num_train_epochs": tune.choice([1]),
|
1115 |
-
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
1116 |
-
"weight_decay": tune.uniform(0.0, 0.3),
|
1117 |
-
"lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
|
1118 |
-
"warmup_steps": tune.uniform(100, 2000),
|
1119 |
-
"seed": tune.uniform(0, 100),
|
1120 |
-
"per_device_train_batch_size": tune.choice(
|
1121 |
-
[def_training_args["per_device_train_batch_size"]]
|
1122 |
-
),
|
1123 |
-
}
|
1124 |
-
|
1125 |
-
hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
|
1126 |
-
|
1127 |
-
# optimize hyperparameters
|
1128 |
-
trainer.hyperparameter_search(
|
1129 |
-
direction="maximize",
|
1130 |
-
backend="ray",
|
1131 |
-
resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
|
1132 |
-
hp_space=lambda _: def_ray_config
|
1133 |
-
if self.ray_config is None
|
1134 |
-
else self.ray_config,
|
1135 |
-
search_alg=hyperopt_search,
|
1136 |
-
n_trials=n_trials, # number of trials
|
1137 |
-
progress_reporter=tune.CLIReporter(
|
1138 |
-
max_report_frequency=600,
|
1139 |
-
sort_by_metric=True,
|
1140 |
-
max_progress_rows=n_trials,
|
1141 |
-
mode="max",
|
1142 |
-
metric="eval_macro_f1",
|
1143 |
-
metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
|
1144 |
-
),
|
1145 |
-
storage_path=output_directory,
|
1146 |
-
)
|
1147 |
-
|
1148 |
-
return trainer
|
1149 |
-
|
1150 |
-
def train_classifier(
|
1151 |
-
self,
|
1152 |
-
model_directory,
|
1153 |
-
num_classes,
|
1154 |
-
train_data,
|
1155 |
-
eval_data,
|
1156 |
-
output_directory,
|
1157 |
-
predict=False,
|
1158 |
-
):
|
1159 |
-
"""
|
1160 |
-
Fine-tune model for cell state or gene classification.
|
1161 |
-
|
1162 |
-
**Parameters**
|
1163 |
-
|
1164 |
-
model_directory : Path
|
1165 |
-
| Path to directory containing model
|
1166 |
-
num_classes : int
|
1167 |
-
| Number of classes for classifier
|
1168 |
-
train_data : Dataset
|
1169 |
-
| Loaded training .dataset input
|
1170 |
-
| For cell classifier, labels in column "label".
|
1171 |
-
| For gene classifier, labels in column "labels".
|
1172 |
-
eval_data : None, Dataset
|
1173 |
-
| (Optional) Loaded evaluation .dataset input
|
1174 |
-
| For cell classifier, labels in column "label".
|
1175 |
-
| For gene classifier, labels in column "labels".
|
1176 |
-
output_directory : Path
|
1177 |
-
| Path to directory where fine-tuned model will be saved
|
1178 |
-
predict : bool
|
1179 |
-
| Whether or not to save eval predictions from trainer
|
1180 |
-
"""
|
1181 |
-
|
1182 |
-
##### Validate and prepare data #####
|
1183 |
-
train_data, eval_data = cu.validate_and_clean_cols(
|
1184 |
-
train_data, eval_data, self.classifier
|
1185 |
-
)
|
1186 |
-
|
1187 |
-
if (self.no_eval is True) and (eval_data is not None):
|
1188 |
-
logger.warning(
|
1189 |
-
"no_eval set to True; model will be trained without evaluation."
|
1190 |
-
)
|
1191 |
-
eval_data = None
|
1192 |
-
|
1193 |
-
if (self.classifier == "gene") and (predict is True):
|
1194 |
-
logger.warning(
|
1195 |
-
"Predictions during training not currently available for gene classifiers; setting predict to False."
|
1196 |
-
)
|
1197 |
-
predict = False
|
1198 |
-
|
1199 |
-
# ensure not overwriting previously saved model
|
1200 |
-
saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
|
1201 |
-
if os.path.isfile(saved_model_test) is True:
|
1202 |
-
logger.error("Model already saved to this designated output directory.")
|
1203 |
-
raise
|
1204 |
-
# make output directory
|
1205 |
-
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1206 |
-
|
1207 |
-
##### Load model and training args #####
|
1208 |
-
model = pu.load_model(
|
1209 |
-
self.model_type,
|
1210 |
-
num_classes,
|
1211 |
-
model_directory,
|
1212 |
-
"train",
|
1213 |
-
quantize=self.quantize,
|
1214 |
-
)
|
1215 |
-
|
1216 |
-
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1217 |
-
model, self.classifier, train_data, output_directory
|
1218 |
-
)
|
1219 |
-
|
1220 |
-
if self.training_args is not None:
|
1221 |
-
def_training_args.update(self.training_args)
|
1222 |
-
logging_steps = round(
|
1223 |
-
len(train_data) / def_training_args["per_device_train_batch_size"] / 10
|
1224 |
-
)
|
1225 |
-
def_training_args["logging_steps"] = logging_steps
|
1226 |
-
def_training_args["output_dir"] = output_directory
|
1227 |
-
if eval_data is None:
|
1228 |
-
def_training_args["evaluation_strategy"] = "no"
|
1229 |
-
def_training_args["load_best_model_at_end"] = False
|
1230 |
-
training_args_init = TrainingArguments(**def_training_args)
|
1231 |
-
|
1232 |
-
if self.freeze_layers is not None:
|
1233 |
-
def_freeze_layers = self.freeze_layers
|
1234 |
-
|
1235 |
-
if def_freeze_layers > 0:
|
1236 |
-
modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
|
1237 |
-
for module in modules_to_freeze:
|
1238 |
-
for param in module.parameters():
|
1239 |
-
param.requires_grad = False
|
1240 |
-
|
1241 |
-
##### Fine-tune the model #####
|
1242 |
-
# define the data collator
|
1243 |
-
if self.classifier == "cell":
|
1244 |
-
data_collator = DataCollatorForCellClassification(
|
1245 |
-
token_dictionary=self.gene_token_dict
|
1246 |
-
)
|
1247 |
-
elif self.classifier == "gene":
|
1248 |
-
data_collator = DataCollatorForGeneClassification(
|
1249 |
-
token_dictionary=self.gene_token_dict
|
1250 |
-
)
|
1251 |
-
|
1252 |
-
# create the trainer
|
1253 |
-
trainer = Trainer(
|
1254 |
-
model=model,
|
1255 |
-
args=training_args_init,
|
1256 |
-
data_collator=data_collator,
|
1257 |
-
train_dataset=train_data,
|
1258 |
-
eval_dataset=eval_data,
|
1259 |
-
compute_metrics=cu.compute_metrics,
|
1260 |
-
)
|
1261 |
-
|
1262 |
-
# train the classifier
|
1263 |
-
trainer.train()
|
1264 |
-
trainer.save_model(output_directory)
|
1265 |
-
if predict is True:
|
1266 |
-
# make eval predictions and save predictions and metrics
|
1267 |
-
predictions = trainer.predict(eval_data)
|
1268 |
-
prediction_output_path = f"{output_directory}/predictions.pkl"
|
1269 |
-
with open(prediction_output_path, "wb") as f:
|
1270 |
-
pickle.dump(predictions, f)
|
1271 |
-
trainer.save_metrics("eval", predictions.metrics)
|
1272 |
-
return trainer
|
1273 |
-
|
1274 |
-
def evaluate_model(
|
1275 |
-
self,
|
1276 |
-
model,
|
1277 |
-
num_classes,
|
1278 |
-
id_class_dict,
|
1279 |
-
eval_data,
|
1280 |
-
predict=False,
|
1281 |
-
output_directory=None,
|
1282 |
-
output_prefix=None,
|
1283 |
-
):
|
1284 |
-
"""
|
1285 |
-
Evaluate the fine-tuned model.
|
1286 |
-
|
1287 |
-
**Parameters**
|
1288 |
-
|
1289 |
-
model : nn.Module
|
1290 |
-
| Loaded fine-tuned model (e.g. trainer.model)
|
1291 |
-
num_classes : int
|
1292 |
-
| Number of classes for classifier
|
1293 |
-
id_class_dict : dict
|
1294 |
-
| Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1295 |
-
| (dictionary of format: numerical IDs: class_labels)
|
1296 |
-
eval_data : Dataset
|
1297 |
-
| Loaded evaluation .dataset input
|
1298 |
-
predict : bool
|
1299 |
-
| Whether or not to save eval predictions
|
1300 |
-
output_directory : Path
|
1301 |
-
| Path to directory where eval data will be saved
|
1302 |
-
output_prefix : str
|
1303 |
-
| Prefix for output files
|
1304 |
-
"""
|
1305 |
-
|
1306 |
-
##### Evaluate the model #####
|
1307 |
-
labels = id_class_dict.keys()
|
1308 |
-
y_pred, y_true, logits_list = eu.classifier_predict(
|
1309 |
-
model, self.classifier, eval_data, self.forward_batch_size
|
1310 |
-
)
|
1311 |
-
conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
|
1312 |
-
y_pred, y_true, logits_list, num_classes, labels
|
1313 |
-
)
|
1314 |
-
if predict is True:
|
1315 |
-
pred_dict = {
|
1316 |
-
"pred_ids": y_pred,
|
1317 |
-
"label_ids": y_true,
|
1318 |
-
"predictions": logits_list,
|
1319 |
-
}
|
1320 |
-
pred_dict_output_path = (
|
1321 |
-
Path(output_directory) / f"{output_prefix}_pred_dict"
|
1322 |
-
).with_suffix(".pkl")
|
1323 |
-
with open(pred_dict_output_path, "wb") as f:
|
1324 |
-
pickle.dump(pred_dict, f)
|
1325 |
-
return {
|
1326 |
-
"conf_mat": conf_mat,
|
1327 |
-
"macro_f1": macro_f1,
|
1328 |
-
"acc": acc,
|
1329 |
-
"roc_metrics": roc_metrics,
|
1330 |
-
}
|
1331 |
-
|
1332 |
-
def evaluate_saved_model(
|
1333 |
-
self,
|
1334 |
-
model_directory,
|
1335 |
-
id_class_dict_file,
|
1336 |
-
test_data_file,
|
1337 |
-
output_directory,
|
1338 |
-
output_prefix,
|
1339 |
-
predict=True,
|
1340 |
-
):
|
1341 |
-
"""
|
1342 |
-
Evaluate the fine-tuned model.
|
1343 |
-
|
1344 |
-
**Parameters**
|
1345 |
-
|
1346 |
-
model_directory : Path
|
1347 |
-
| Path to directory containing model
|
1348 |
-
id_class_dict_file : Path
|
1349 |
-
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1350 |
-
| (dictionary of format: numerical IDs: class_labels)
|
1351 |
-
test_data_file : Path
|
1352 |
-
| Path to directory containing test .dataset
|
1353 |
-
output_directory : Path
|
1354 |
-
| Path to directory where eval data will be saved
|
1355 |
-
output_prefix : str
|
1356 |
-
| Prefix for output files
|
1357 |
-
predict : bool
|
1358 |
-
| Whether or not to save eval predictions
|
1359 |
-
"""
|
1360 |
-
|
1361 |
-
# load numerical id to class dictionary (id:class)
|
1362 |
-
with open(id_class_dict_file, "rb") as f:
|
1363 |
-
id_class_dict = pickle.load(f)
|
1364 |
-
|
1365 |
-
# get number of classes for classifier
|
1366 |
-
num_classes = cu.get_num_classes(id_class_dict)
|
1367 |
-
|
1368 |
-
# load previously filtered and prepared data
|
1369 |
-
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1370 |
-
|
1371 |
-
# load previously fine-tuned model
|
1372 |
-
model = pu.load_model(
|
1373 |
-
self.model_type,
|
1374 |
-
num_classes,
|
1375 |
-
model_directory,
|
1376 |
-
"eval",
|
1377 |
-
quantize=self.quantize,
|
1378 |
-
)
|
1379 |
-
|
1380 |
-
# evaluate the model
|
1381 |
-
result = self.evaluate_model(
|
1382 |
-
model,
|
1383 |
-
num_classes,
|
1384 |
-
id_class_dict,
|
1385 |
-
test_data,
|
1386 |
-
predict=predict,
|
1387 |
-
output_directory=output_directory,
|
1388 |
-
output_prefix=output_prefix,
|
1389 |
-
)
|
1390 |
-
|
1391 |
-
all_conf_mat_df = pd.DataFrame(
|
1392 |
-
result["conf_mat"],
|
1393 |
-
columns=id_class_dict.values(),
|
1394 |
-
index=id_class_dict.values(),
|
1395 |
-
)
|
1396 |
-
all_metrics = {
|
1397 |
-
"conf_matrix": all_conf_mat_df,
|
1398 |
-
"macro_f1": result["macro_f1"],
|
1399 |
-
"acc": result["acc"],
|
1400 |
-
}
|
1401 |
-
all_roc_metrics = None # roc metrics not reported for multiclass
|
1402 |
-
|
1403 |
-
if num_classes == 2:
|
1404 |
-
mean_fpr = np.linspace(0, 1, 100)
|
1405 |
-
mean_tpr = result["roc_metrics"]["interp_tpr"]
|
1406 |
-
all_roc_auc = result["roc_metrics"]["auc"]
|
1407 |
-
all_roc_metrics = {
|
1408 |
-
"mean_tpr": mean_tpr,
|
1409 |
-
"mean_fpr": mean_fpr,
|
1410 |
-
"all_roc_auc": all_roc_auc,
|
1411 |
-
}
|
1412 |
-
all_metrics["all_roc_metrics"] = all_roc_metrics
|
1413 |
-
test_metrics_output_path = (
|
1414 |
-
Path(output_directory) / f"{output_prefix}_test_metrics_dict"
|
1415 |
-
).with_suffix(".pkl")
|
1416 |
-
with open(test_metrics_output_path, "wb") as f:
|
1417 |
-
pickle.dump(all_metrics, f)
|
1418 |
-
|
1419 |
-
return all_metrics
|
1420 |
-
|
1421 |
-
def plot_conf_mat(
|
1422 |
-
self,
|
1423 |
-
conf_mat_dict,
|
1424 |
-
output_directory,
|
1425 |
-
output_prefix,
|
1426 |
-
custom_class_order=None,
|
1427 |
-
):
|
1428 |
-
"""
|
1429 |
-
Plot confusion matrix results of evaluating the fine-tuned model.
|
1430 |
-
|
1431 |
-
**Parameters**
|
1432 |
-
|
1433 |
-
conf_mat_dict : dict
|
1434 |
-
| Dictionary of model_name : confusion_matrix_DataFrame
|
1435 |
-
| (all_metrics["conf_matrix"] from self.validate)
|
1436 |
-
output_directory : Path
|
1437 |
-
| Path to directory where plots will be saved
|
1438 |
-
output_prefix : str
|
1439 |
-
| Prefix for output file
|
1440 |
-
custom_class_order : None, list
|
1441 |
-
| List of classes in custom order for plots.
|
1442 |
-
| Same order will be used for all models.
|
1443 |
-
"""
|
1444 |
-
|
1445 |
-
for model_name in conf_mat_dict.keys():
|
1446 |
-
eu.plot_confusion_matrix(
|
1447 |
-
conf_mat_dict[model_name],
|
1448 |
-
model_name,
|
1449 |
-
output_directory,
|
1450 |
-
output_prefix,
|
1451 |
-
custom_class_order,
|
1452 |
-
)
|
1453 |
-
|
1454 |
-
def plot_roc(
|
1455 |
-
self,
|
1456 |
-
roc_metric_dict,
|
1457 |
-
model_style_dict,
|
1458 |
-
title,
|
1459 |
-
output_directory,
|
1460 |
-
output_prefix,
|
1461 |
-
):
|
1462 |
-
"""
|
1463 |
-
Plot ROC curve results of evaluating the fine-tuned model.
|
1464 |
-
|
1465 |
-
**Parameters**
|
1466 |
-
|
1467 |
-
roc_metric_dict : dict
|
1468 |
-
| Dictionary of model_name : roc_metrics
|
1469 |
-
| (all_metrics["all_roc_metrics"] from self.validate)
|
1470 |
-
model_style_dict : dict[dict]
|
1471 |
-
| Dictionary of model_name : dictionary of style_attribute : style
|
1472 |
-
| where style includes color and linestyle
|
1473 |
-
| e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
|
1474 |
-
title : str
|
1475 |
-
| Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
|
1476 |
-
output_directory : Path
|
1477 |
-
| Path to directory where plots will be saved
|
1478 |
-
output_prefix : str
|
1479 |
-
| Prefix for output file
|
1480 |
-
"""
|
1481 |
-
|
1482 |
-
eu.plot_ROC(
|
1483 |
-
roc_metric_dict, model_style_dict, title, output_directory, output_prefix
|
1484 |
-
)
|
1485 |
-
|
1486 |
-
def plot_predictions(
|
1487 |
-
self,
|
1488 |
-
predictions_file,
|
1489 |
-
id_class_dict_file,
|
1490 |
-
title,
|
1491 |
-
output_directory,
|
1492 |
-
output_prefix,
|
1493 |
-
custom_class_order=None,
|
1494 |
-
kwargs_dict=None,
|
1495 |
-
):
|
1496 |
-
"""
|
1497 |
-
Plot prediction results of evaluating the fine-tuned model.
|
1498 |
-
|
1499 |
-
**Parameters**
|
1500 |
-
|
1501 |
-
predictions_file : path
|
1502 |
-
| Path of model predictions output to plot
|
1503 |
-
| (saved output from self.validate if predict_eval=True)
|
1504 |
-
| (or saved output from self.evaluate_saved_model)
|
1505 |
-
id_class_dict_file : Path
|
1506 |
-
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1507 |
-
| (dictionary of format: numerical IDs: class_labels)
|
1508 |
-
title : str
|
1509 |
-
| Title for legend containing class labels.
|
1510 |
-
output_directory : Path
|
1511 |
-
| Path to directory where plots will be saved
|
1512 |
-
output_prefix : str
|
1513 |
-
| Prefix for output file
|
1514 |
-
custom_class_order : None, list
|
1515 |
-
| List of classes in custom order for plots.
|
1516 |
-
| Same order will be used for all models.
|
1517 |
-
kwargs_dict : None, dict
|
1518 |
-
| Dictionary of kwargs to pass to plotting function.
|
1519 |
-
"""
|
1520 |
-
# load predictions
|
1521 |
-
with open(predictions_file, "rb") as f:
|
1522 |
-
predictions = pickle.load(f)
|
1523 |
-
|
1524 |
-
# load numerical id to class dictionary (id:class)
|
1525 |
-
with open(id_class_dict_file, "rb") as f:
|
1526 |
-
id_class_dict = pickle.load(f)
|
1527 |
-
|
1528 |
-
if isinstance(predictions, dict):
|
1529 |
-
if all(
|
1530 |
-
[
|
1531 |
-
key in predictions.keys()
|
1532 |
-
for key in ["pred_ids", "label_ids", "predictions"]
|
1533 |
-
]
|
1534 |
-
):
|
1535 |
-
# format is output from self.evaluate_saved_model
|
1536 |
-
predictions_logits = np.array(predictions["predictions"])
|
1537 |
-
true_ids = predictions["label_ids"]
|
1538 |
-
else:
|
1539 |
-
# format is output from self.validate if predict_eval=True
|
1540 |
-
predictions_logits = predictions.predictions
|
1541 |
-
true_ids = predictions.label_ids
|
1542 |
-
|
1543 |
-
num_classes = len(id_class_dict.keys())
|
1544 |
-
num_predict_classes = predictions_logits.shape[1]
|
1545 |
-
assert num_classes == num_predict_classes
|
1546 |
-
classes = id_class_dict.values()
|
1547 |
-
true_labels = [id_class_dict[idx] for idx in true_ids]
|
1548 |
-
predictions_df = pd.DataFrame(predictions_logits, columns=classes)
|
1549 |
-
if custom_class_order is not None:
|
1550 |
-
predictions_df = predictions_df.reindex(columns=custom_class_order)
|
1551 |
-
predictions_df["true"] = true_labels
|
1552 |
-
custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
|
1553 |
-
if custom_class_order is not None:
|
1554 |
-
custom_dict = dict(
|
1555 |
-
zip(custom_class_order, [i for i in range(len(custom_class_order))])
|
1556 |
-
)
|
1557 |
-
predictions_df = predictions_df.sort_values(
|
1558 |
-
by=["true"], key=lambda x: x.map(custom_dict)
|
1559 |
-
)
|
1560 |
-
|
1561 |
-
eu.plot_predictions(
|
1562 |
-
predictions_df, title, output_directory, output_prefix, kwargs_dict
|
1563 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/classifier_utils.py
DELETED
@@ -1,648 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
import random
|
5 |
-
from collections import Counter, defaultdict
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import pandas as pd
|
9 |
-
from scipy.stats import chisquare, ranksums
|
10 |
-
from sklearn.metrics import accuracy_score, f1_score
|
11 |
-
from sklearn.model_selection import StratifiedKFold, train_test_split
|
12 |
-
|
13 |
-
from . import perturber_utils as pu
|
14 |
-
|
15 |
-
logger = logging.getLogger(__name__)
|
16 |
-
|
17 |
-
|
18 |
-
def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
|
19 |
-
data = data.shuffle(seed=42)
|
20 |
-
num_cells = len(data)
|
21 |
-
# if max number of cells is defined, then subsample to this max number
|
22 |
-
if max_ncells is not None:
|
23 |
-
if num_cells > max_ncells:
|
24 |
-
data = data.select([i for i in range(max_ncells)])
|
25 |
-
if max_ncells_per_class is not None:
|
26 |
-
class_labels = data[cell_state_dict["state_key"]]
|
27 |
-
random.seed(42)
|
28 |
-
subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
|
29 |
-
data = data.select(subsample_indices)
|
30 |
-
return data
|
31 |
-
|
32 |
-
|
33 |
-
# subsample labels to maximum number N per class and return indices
|
34 |
-
def subsample_by_class(labels, N):
|
35 |
-
label_indices = defaultdict(list)
|
36 |
-
# Gather indices for each label
|
37 |
-
for idx, label in enumerate(labels):
|
38 |
-
label_indices[label].append(idx)
|
39 |
-
selected_indices = []
|
40 |
-
# Select up to N indices for each label
|
41 |
-
for label, indices in label_indices.items():
|
42 |
-
if len(indices) > N:
|
43 |
-
selected_indices.extend(random.sample(indices, N))
|
44 |
-
else:
|
45 |
-
selected_indices.extend(indices)
|
46 |
-
return selected_indices
|
47 |
-
|
48 |
-
|
49 |
-
def rename_cols(data, state_key):
|
50 |
-
data = data.rename_column(state_key, "label")
|
51 |
-
return data
|
52 |
-
|
53 |
-
|
54 |
-
def validate_and_clean_cols(train_data, eval_data, classifier):
|
55 |
-
# validate that data has expected label column and remove others
|
56 |
-
if classifier == "cell":
|
57 |
-
label_col = "label"
|
58 |
-
elif classifier == "gene":
|
59 |
-
label_col = "labels"
|
60 |
-
|
61 |
-
cols_to_keep = [label_col] + ["input_ids", "length"]
|
62 |
-
if label_col not in train_data.column_names:
|
63 |
-
logger.error(f"train_data must contain column {label_col} with class labels.")
|
64 |
-
raise
|
65 |
-
else:
|
66 |
-
train_data = remove_cols(train_data, cols_to_keep)
|
67 |
-
|
68 |
-
if eval_data is not None:
|
69 |
-
if label_col not in eval_data.column_names:
|
70 |
-
logger.error(
|
71 |
-
f"eval_data must contain column {label_col} with class labels."
|
72 |
-
)
|
73 |
-
raise
|
74 |
-
else:
|
75 |
-
eval_data = remove_cols(eval_data, cols_to_keep)
|
76 |
-
return train_data, eval_data
|
77 |
-
|
78 |
-
|
79 |
-
def remove_cols(data, cols_to_keep):
|
80 |
-
other_cols = list(data.features.keys())
|
81 |
-
other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
|
82 |
-
data = data.remove_columns(other_cols)
|
83 |
-
return data
|
84 |
-
|
85 |
-
|
86 |
-
def remove_rare(data, rare_threshold, label, nproc):
|
87 |
-
if rare_threshold > 0:
|
88 |
-
total_cells = len(data)
|
89 |
-
label_counter = Counter(data[label])
|
90 |
-
nonrare_label_dict = {
|
91 |
-
label: [k for k, v in label_counter if (v / total_cells) > rare_threshold]
|
92 |
-
}
|
93 |
-
data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
|
94 |
-
return data
|
95 |
-
|
96 |
-
|
97 |
-
def label_classes(classifier, data, gene_class_dict, nproc):
|
98 |
-
if classifier == "cell":
|
99 |
-
label_set = set(data["label"])
|
100 |
-
elif classifier == "gene":
|
101 |
-
# remove cells without any of the target genes
|
102 |
-
def if_contains_label(example):
|
103 |
-
a = pu.flatten_list(gene_class_dict.values())
|
104 |
-
b = example["input_ids"]
|
105 |
-
return not set(a).isdisjoint(b)
|
106 |
-
|
107 |
-
data = data.filter(if_contains_label, num_proc=nproc)
|
108 |
-
label_set = gene_class_dict.keys()
|
109 |
-
|
110 |
-
if len(data) == 0:
|
111 |
-
logger.error(
|
112 |
-
"No cells remain after filtering for target genes. Check target gene list."
|
113 |
-
)
|
114 |
-
raise
|
115 |
-
|
116 |
-
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
117 |
-
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
118 |
-
|
119 |
-
def classes_to_ids(example):
|
120 |
-
if classifier == "cell":
|
121 |
-
example["label"] = class_id_dict[example["label"]]
|
122 |
-
elif classifier == "gene":
|
123 |
-
example["labels"] = label_gene_classes(
|
124 |
-
example, class_id_dict, gene_class_dict
|
125 |
-
)
|
126 |
-
return example
|
127 |
-
|
128 |
-
data = data.map(classes_to_ids, num_proc=nproc)
|
129 |
-
return data, id_class_dict
|
130 |
-
|
131 |
-
|
132 |
-
def label_gene_classes(example, class_id_dict, gene_class_dict):
|
133 |
-
return [
|
134 |
-
class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
|
135 |
-
for token_id in example["input_ids"]
|
136 |
-
]
|
137 |
-
|
138 |
-
|
139 |
-
def prep_gene_classifier_train_eval_split(
|
140 |
-
data,
|
141 |
-
targets,
|
142 |
-
labels,
|
143 |
-
train_index,
|
144 |
-
eval_index,
|
145 |
-
max_ncells,
|
146 |
-
iteration_num,
|
147 |
-
num_proc,
|
148 |
-
balance=False,
|
149 |
-
):
|
150 |
-
# generate cross-validation splits
|
151 |
-
train_data = prep_gene_classifier_split(
|
152 |
-
data,
|
153 |
-
targets,
|
154 |
-
labels,
|
155 |
-
train_index,
|
156 |
-
"train",
|
157 |
-
max_ncells,
|
158 |
-
iteration_num,
|
159 |
-
num_proc,
|
160 |
-
balance,
|
161 |
-
)
|
162 |
-
eval_data = prep_gene_classifier_split(
|
163 |
-
data,
|
164 |
-
targets,
|
165 |
-
labels,
|
166 |
-
eval_index,
|
167 |
-
"eval",
|
168 |
-
max_ncells,
|
169 |
-
iteration_num,
|
170 |
-
num_proc,
|
171 |
-
balance,
|
172 |
-
)
|
173 |
-
return train_data, eval_data
|
174 |
-
|
175 |
-
|
176 |
-
def prep_gene_classifier_split(
|
177 |
-
data,
|
178 |
-
targets,
|
179 |
-
labels,
|
180 |
-
index,
|
181 |
-
subset_name,
|
182 |
-
max_ncells,
|
183 |
-
iteration_num,
|
184 |
-
num_proc,
|
185 |
-
balance=False,
|
186 |
-
):
|
187 |
-
# generate cross-validation splits
|
188 |
-
targets = np.array(targets)
|
189 |
-
labels = np.array(labels)
|
190 |
-
targets_subset = targets[index]
|
191 |
-
labels_subset = labels[index]
|
192 |
-
label_dict_subset = dict(zip(targets_subset, labels_subset))
|
193 |
-
|
194 |
-
# function to filter by whether contains train or eval labels
|
195 |
-
def if_contains_subset_label(example):
|
196 |
-
a = targets_subset
|
197 |
-
b = example["input_ids"]
|
198 |
-
return not set(a).isdisjoint(b)
|
199 |
-
|
200 |
-
# filter dataset for examples containing classes for this split
|
201 |
-
logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
|
202 |
-
subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
|
203 |
-
logger.info(
|
204 |
-
f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
|
205 |
-
)
|
206 |
-
|
207 |
-
# balance gene subsets if train
|
208 |
-
if (subset_name == "train") and (balance is True):
|
209 |
-
subset_data, label_dict_subset = balance_gene_split(
|
210 |
-
subset_data, label_dict_subset, num_proc
|
211 |
-
)
|
212 |
-
|
213 |
-
# subsample to max_ncells
|
214 |
-
subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
|
215 |
-
|
216 |
-
# relabel genes for this split
|
217 |
-
def subset_classes_to_ids(example):
|
218 |
-
example["labels"] = [
|
219 |
-
label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
|
220 |
-
]
|
221 |
-
return example
|
222 |
-
|
223 |
-
subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
|
224 |
-
|
225 |
-
return subset_data
|
226 |
-
|
227 |
-
|
228 |
-
def prep_gene_classifier_all_data(
|
229 |
-
data, targets, labels, max_ncells, num_proc, balance=False
|
230 |
-
):
|
231 |
-
targets = np.array(targets)
|
232 |
-
labels = np.array(labels)
|
233 |
-
label_dict_train = dict(zip(targets, labels))
|
234 |
-
|
235 |
-
# function to filter by whether contains train labels
|
236 |
-
def if_contains_train_label(example):
|
237 |
-
a = targets
|
238 |
-
b = example["input_ids"]
|
239 |
-
return not set(a).isdisjoint(b)
|
240 |
-
|
241 |
-
# filter dataset for examples containing classes for this split
|
242 |
-
logger.info("Filtering training data for genes to classify.")
|
243 |
-
train_data = data.filter(if_contains_train_label, num_proc=num_proc)
|
244 |
-
logger.info(
|
245 |
-
f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
|
246 |
-
)
|
247 |
-
|
248 |
-
if balance is True:
|
249 |
-
train_data, label_dict_train = balance_gene_split(
|
250 |
-
train_data, label_dict_train, num_proc
|
251 |
-
)
|
252 |
-
|
253 |
-
# subsample to max_ncells
|
254 |
-
train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
|
255 |
-
|
256 |
-
# relabel genes for this split
|
257 |
-
def train_classes_to_ids(example):
|
258 |
-
example["labels"] = [
|
259 |
-
label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
|
260 |
-
]
|
261 |
-
return example
|
262 |
-
|
263 |
-
train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
|
264 |
-
|
265 |
-
return train_data
|
266 |
-
|
267 |
-
|
268 |
-
def balance_gene_split(subset_data, label_dict_subset, num_proc):
|
269 |
-
# count occurrence of genes in each label category
|
270 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
271 |
-
subset_data, label_dict_subset, num_proc
|
272 |
-
)
|
273 |
-
label_ratio_0to1 = label0_counts / label1_counts
|
274 |
-
|
275 |
-
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
276 |
-
# gene sets already balanced
|
277 |
-
logger.info(
|
278 |
-
"Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
|
279 |
-
)
|
280 |
-
return subset_data, label_dict_subset
|
281 |
-
else:
|
282 |
-
label_ratio_0to1_orig = label_ratio_0to1 + 0
|
283 |
-
label_dict_subset_orig = label_dict_subset.copy()
|
284 |
-
# balance gene sets
|
285 |
-
max_ntrials = 25
|
286 |
-
boost = 1
|
287 |
-
if label_ratio_0to1 > 10 / 8:
|
288 |
-
# downsample label 0
|
289 |
-
for i in range(max_ntrials):
|
290 |
-
label0 = 0
|
291 |
-
label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
|
292 |
-
label0_ngenes = len(label0_genes)
|
293 |
-
label0_nremove = max(
|
294 |
-
1,
|
295 |
-
int(
|
296 |
-
np.floor(
|
297 |
-
label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
|
298 |
-
)
|
299 |
-
),
|
300 |
-
)
|
301 |
-
random.seed(i)
|
302 |
-
label0_remove_genes = random.sample(label0_genes, label0_nremove)
|
303 |
-
label_dict_subset_new = {
|
304 |
-
k: v
|
305 |
-
for k, v in label_dict_subset.items()
|
306 |
-
if k not in label0_remove_genes
|
307 |
-
}
|
308 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
309 |
-
subset_data, label_dict_subset_new, num_proc
|
310 |
-
)
|
311 |
-
label_ratio_0to1 = label0_counts / label1_counts
|
312 |
-
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
313 |
-
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
314 |
-
return filter_data_balanced_genes(
|
315 |
-
subset_data, label_dict_subset_new, num_proc
|
316 |
-
)
|
317 |
-
elif label_ratio_0to1 > 10 / 8:
|
318 |
-
boost = boost * 1.1
|
319 |
-
elif label_ratio_0to1 < 8 / 10:
|
320 |
-
boost = boost * 0.9
|
321 |
-
else:
|
322 |
-
# downsample label 1
|
323 |
-
for i in range(max_ntrials):
|
324 |
-
label1 = 1
|
325 |
-
label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
|
326 |
-
label1_ngenes = len(label1_genes)
|
327 |
-
label1_nremove = max(
|
328 |
-
1,
|
329 |
-
int(
|
330 |
-
np.floor(
|
331 |
-
label1_ngenes
|
332 |
-
- label1_ngenes / ((1 / label_ratio_0to1) * boost)
|
333 |
-
)
|
334 |
-
),
|
335 |
-
)
|
336 |
-
random.seed(i)
|
337 |
-
label1_remove_genes = random.sample(label1_genes, label1_nremove)
|
338 |
-
label_dict_subset_new = {
|
339 |
-
k: v
|
340 |
-
for k, v in label_dict_subset.items()
|
341 |
-
if k not in label1_remove_genes
|
342 |
-
}
|
343 |
-
label0_counts, label1_counts = count_genes_for_balancing(
|
344 |
-
subset_data, label_dict_subset_new, num_proc
|
345 |
-
)
|
346 |
-
label_ratio_0to1 = label0_counts / label1_counts
|
347 |
-
if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
|
348 |
-
# if gene sets now balanced, return new filtered data and new label_dict_subset
|
349 |
-
return filter_data_balanced_genes(
|
350 |
-
subset_data, label_dict_subset_new, num_proc
|
351 |
-
)
|
352 |
-
elif label_ratio_0to1 < 8 / 10:
|
353 |
-
boost = boost * 1.1
|
354 |
-
elif label_ratio_0to1 > 10 / 8:
|
355 |
-
boost = boost * 0.9
|
356 |
-
|
357 |
-
assert i + 1 == max_ntrials
|
358 |
-
if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
|
359 |
-
10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
|
360 |
-
):
|
361 |
-
label_ratio_0to1 = label_ratio_0to1_orig
|
362 |
-
label_dict_subset_new = label_dict_subset_orig
|
363 |
-
logger.warning(
|
364 |
-
f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
|
365 |
-
)
|
366 |
-
return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
|
367 |
-
|
368 |
-
|
369 |
-
def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
|
370 |
-
def count_targets(example):
|
371 |
-
labels = [
|
372 |
-
label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
|
373 |
-
]
|
374 |
-
counter_labels = Counter(labels)
|
375 |
-
# get count of labels 0 or 1, or if absent, return 0
|
376 |
-
example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
|
377 |
-
return example
|
378 |
-
|
379 |
-
subset_data = subset_data.map(count_targets, num_proc=num_proc)
|
380 |
-
|
381 |
-
label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
|
382 |
-
label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
|
383 |
-
|
384 |
-
subset_data = subset_data.remove_columns("labels_counts")
|
385 |
-
|
386 |
-
return label0_counts, label1_counts
|
387 |
-
|
388 |
-
|
389 |
-
def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
|
390 |
-
# function to filter by whether contains labels
|
391 |
-
def if_contains_subset_label(example):
|
392 |
-
a = list(label_dict_subset.keys())
|
393 |
-
b = example["input_ids"]
|
394 |
-
return not set(a).isdisjoint(b)
|
395 |
-
|
396 |
-
# filter dataset for examples containing classes for this split
|
397 |
-
logger.info("Filtering data for balanced genes")
|
398 |
-
subset_data_len_orig = len(subset_data)
|
399 |
-
subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
|
400 |
-
logger.info(
|
401 |
-
f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
|
402 |
-
)
|
403 |
-
|
404 |
-
return subset_data, label_dict_subset
|
405 |
-
|
406 |
-
|
407 |
-
def balance_attr_splits(
|
408 |
-
data,
|
409 |
-
attr_to_split,
|
410 |
-
attr_to_balance,
|
411 |
-
eval_size,
|
412 |
-
max_trials,
|
413 |
-
pval_threshold,
|
414 |
-
state_key,
|
415 |
-
nproc,
|
416 |
-
):
|
417 |
-
metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
|
418 |
-
for attr in attr_to_balance:
|
419 |
-
if attr == state_key:
|
420 |
-
metadata_df[attr] = data["label"]
|
421 |
-
else:
|
422 |
-
metadata_df[attr] = data[attr]
|
423 |
-
metadata_df = metadata_df.drop_duplicates()
|
424 |
-
|
425 |
-
split_attr_ids = list(metadata_df["split_attr_ids"])
|
426 |
-
assert len(split_attr_ids) == len(set(split_attr_ids))
|
427 |
-
eval_num = round(len(split_attr_ids) * eval_size)
|
428 |
-
colnames = (
|
429 |
-
["trial_num", "train_ids", "eval_ids"]
|
430 |
-
+ pu.flatten_list(
|
431 |
-
[
|
432 |
-
[
|
433 |
-
f"{attr}_train_mean_or_counts",
|
434 |
-
f"{attr}_eval_mean_or_counts",
|
435 |
-
f"{attr}_pval",
|
436 |
-
]
|
437 |
-
for attr in attr_to_balance
|
438 |
-
]
|
439 |
-
)
|
440 |
-
+ ["mean_pval"]
|
441 |
-
)
|
442 |
-
balance_df = pd.DataFrame(columns=colnames)
|
443 |
-
data_dict = dict()
|
444 |
-
trial_num = 1
|
445 |
-
for i in range(max_trials):
|
446 |
-
if not all(
|
447 |
-
count > 1 for count in list(Counter(metadata_df[state_key]).values())
|
448 |
-
):
|
449 |
-
logger.error(
|
450 |
-
f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
|
451 |
-
)
|
452 |
-
raise
|
453 |
-
eval_base = []
|
454 |
-
for state in set(metadata_df[state_key]):
|
455 |
-
eval_base += list(
|
456 |
-
metadata_df.loc[
|
457 |
-
metadata_df[state_key][metadata_df[state_key].eq(state)]
|
458 |
-
.sample(1, random_state=i)
|
459 |
-
.index
|
460 |
-
]["split_attr_ids"]
|
461 |
-
)
|
462 |
-
non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
|
463 |
-
random.seed(i)
|
464 |
-
eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
|
465 |
-
train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
|
466 |
-
df_vals = [trial_num, train_ids, eval_ids]
|
467 |
-
pvals = []
|
468 |
-
for attr in attr_to_balance:
|
469 |
-
train_attr = list(
|
470 |
-
metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
|
471 |
-
)
|
472 |
-
eval_attr = list(
|
473 |
-
metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
|
474 |
-
)
|
475 |
-
if attr == state_key:
|
476 |
-
# ensure IDs are interpreted as categorical
|
477 |
-
train_attr = [str(item) for item in train_attr]
|
478 |
-
eval_attr = [str(item) for item in eval_attr]
|
479 |
-
if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
|
480 |
-
train_attr_mean = np.nanmean(train_attr)
|
481 |
-
eval_attr_mean = np.nanmean(eval_attr)
|
482 |
-
pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
|
483 |
-
df_vals += [train_attr_mean, eval_attr_mean, pval]
|
484 |
-
elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
|
485 |
-
obs_counts = Counter(train_attr)
|
486 |
-
exp_counts = Counter(eval_attr)
|
487 |
-
all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
|
488 |
-
obs = [obs_counts[cat] for cat in all_categ]
|
489 |
-
exp = [
|
490 |
-
exp_counts[cat] * sum(obs) / sum(exp_counts.values())
|
491 |
-
for cat in all_categ
|
492 |
-
]
|
493 |
-
pval = chisquare(f_obs=obs, f_exp=exp).pvalue
|
494 |
-
train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
|
495 |
-
eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
|
496 |
-
df_vals += [train_attr_counts, eval_attr_counts, pval]
|
497 |
-
else:
|
498 |
-
logger.error(
|
499 |
-
f"Inconsistent data types in attribute {attr}. "
|
500 |
-
"Cannot infer if continuous or categorical. "
|
501 |
-
"Must be all numeric (continuous) or all strings (categorical) to balance."
|
502 |
-
)
|
503 |
-
raise
|
504 |
-
pvals += [pval]
|
505 |
-
|
506 |
-
df_vals += [np.nanmean(pvals)]
|
507 |
-
balance_df_i = pd.DataFrame(df_vals, index=colnames).T
|
508 |
-
balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
|
509 |
-
valid_pvals = [
|
510 |
-
pval_i
|
511 |
-
for pval_i in pvals
|
512 |
-
if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
|
513 |
-
]
|
514 |
-
if all(i >= pval_threshold for i in valid_pvals):
|
515 |
-
data_dict["train"] = pu.filter_by_dict(
|
516 |
-
data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
|
517 |
-
)
|
518 |
-
data_dict["test"] = pu.filter_by_dict(
|
519 |
-
data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
|
520 |
-
)
|
521 |
-
return data_dict, balance_df
|
522 |
-
trial_num = trial_num + 1
|
523 |
-
balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
|
524 |
-
data_dict["train"] = pu.filter_by_dict(
|
525 |
-
data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
|
526 |
-
)
|
527 |
-
data_dict["test"] = pu.filter_by_dict(
|
528 |
-
data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
|
529 |
-
)
|
530 |
-
logger.warning(
|
531 |
-
f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
|
532 |
-
f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
|
533 |
-
)
|
534 |
-
return data_dict, balance_df
|
535 |
-
|
536 |
-
|
537 |
-
def get_num_classes(id_class_dict):
|
538 |
-
return len(set(id_class_dict.values()))
|
539 |
-
|
540 |
-
|
541 |
-
def compute_metrics(pred):
|
542 |
-
labels = pred.label_ids
|
543 |
-
preds = pred.predictions.argmax(-1)
|
544 |
-
|
545 |
-
# calculate accuracy and macro f1 using sklearn's function
|
546 |
-
if len(labels.shape) == 1:
|
547 |
-
acc = accuracy_score(labels, preds)
|
548 |
-
macro_f1 = f1_score(labels, preds, average="macro")
|
549 |
-
else:
|
550 |
-
flat_labels = labels.flatten().tolist()
|
551 |
-
flat_preds = preds.flatten().tolist()
|
552 |
-
logit_label_paired = [
|
553 |
-
item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
|
554 |
-
]
|
555 |
-
y_pred = [item[0] for item in logit_label_paired]
|
556 |
-
y_true = [item[1] for item in logit_label_paired]
|
557 |
-
|
558 |
-
acc = accuracy_score(y_true, y_pred)
|
559 |
-
macro_f1 = f1_score(y_true, y_pred, average="macro")
|
560 |
-
|
561 |
-
return {"accuracy": acc, "macro_f1": macro_f1}
|
562 |
-
|
563 |
-
|
564 |
-
def get_default_train_args(model, classifier, data, output_dir):
|
565 |
-
num_layers = pu.quant_layers(model)
|
566 |
-
freeze_layers = 0
|
567 |
-
batch_size = 12
|
568 |
-
if classifier == "cell":
|
569 |
-
epochs = 10
|
570 |
-
evaluation_strategy = "epoch"
|
571 |
-
load_best_model_at_end = True
|
572 |
-
else:
|
573 |
-
epochs = 1
|
574 |
-
evaluation_strategy = "no"
|
575 |
-
load_best_model_at_end = False
|
576 |
-
|
577 |
-
if num_layers == 6:
|
578 |
-
default_training_args = {
|
579 |
-
"learning_rate": 5e-5,
|
580 |
-
"lr_scheduler_type": "linear",
|
581 |
-
"warmup_steps": 500,
|
582 |
-
"per_device_train_batch_size": batch_size,
|
583 |
-
"per_device_eval_batch_size": batch_size,
|
584 |
-
}
|
585 |
-
else:
|
586 |
-
default_training_args = {
|
587 |
-
"per_device_train_batch_size": batch_size,
|
588 |
-
"per_device_eval_batch_size": batch_size,
|
589 |
-
}
|
590 |
-
|
591 |
-
training_args = {
|
592 |
-
"num_train_epochs": epochs,
|
593 |
-
"do_train": True,
|
594 |
-
"do_eval": True,
|
595 |
-
"evaluation_strategy": evaluation_strategy,
|
596 |
-
"logging_steps": np.floor(len(data) / batch_size / 8), # 8 evals per epoch
|
597 |
-
"save_strategy": "epoch",
|
598 |
-
"group_by_length": False,
|
599 |
-
"length_column_name": "length",
|
600 |
-
"disable_tqdm": False,
|
601 |
-
"weight_decay": 0.001,
|
602 |
-
"load_best_model_at_end": load_best_model_at_end,
|
603 |
-
}
|
604 |
-
training_args.update(default_training_args)
|
605 |
-
|
606 |
-
return training_args, freeze_layers
|
607 |
-
|
608 |
-
|
609 |
-
def load_best_model(directory, model_type, num_classes, mode="eval"):
|
610 |
-
file_dict = dict()
|
611 |
-
for subdir, dirs, files in os.walk(directory):
|
612 |
-
for file in files:
|
613 |
-
if file.endswith("result.json"):
|
614 |
-
with open(f"{subdir}/{file}", "rb") as fp:
|
615 |
-
result_json = json.load(fp)
|
616 |
-
file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
|
617 |
-
file_df = pd.DataFrame(
|
618 |
-
{"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
|
619 |
-
)
|
620 |
-
model_superdir = (
|
621 |
-
"run-"
|
622 |
-
+ file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
|
623 |
-
.split("_objective_")[2]
|
624 |
-
.split("_")[0]
|
625 |
-
)
|
626 |
-
|
627 |
-
for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
|
628 |
-
for file in files:
|
629 |
-
if file.endswith("model.safetensors"):
|
630 |
-
model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
|
631 |
-
return model
|
632 |
-
|
633 |
-
|
634 |
-
class StratifiedKFold3(StratifiedKFold):
|
635 |
-
def split(self, targets, labels, test_ratio=0.5, groups=None):
|
636 |
-
s = super().split(targets, labels, groups)
|
637 |
-
for train_indxs, test_indxs in s:
|
638 |
-
if test_ratio == 0:
|
639 |
-
yield train_indxs, test_indxs, None
|
640 |
-
else:
|
641 |
-
labels_test = np.array(labels)[test_indxs]
|
642 |
-
valid_indxs, test_indxs = train_test_split(
|
643 |
-
test_indxs,
|
644 |
-
stratify=labels_test,
|
645 |
-
test_size=test_ratio,
|
646 |
-
random_state=0,
|
647 |
-
)
|
648 |
-
yield train_indxs, valid_indxs, test_indxs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/collator_for_classification.py
CHANGED
@@ -1,22 +1,24 @@
|
|
1 |
"""
|
2 |
Geneformer collator for gene and cell classification.
|
|
|
3 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
4 |
"""
|
5 |
-
|
|
|
6 |
import warnings
|
7 |
from enum import Enum
|
8 |
from typing import Dict, List, Optional, Union
|
9 |
|
10 |
-
import numpy as np
|
11 |
-
import torch
|
12 |
from transformers import (
|
13 |
-
BatchEncoding,
|
14 |
DataCollatorForTokenClassification,
|
15 |
SpecialTokensMixin,
|
|
|
16 |
)
|
17 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
|
|
|
|
|
20 |
EncodedInput = List[int]
|
21 |
logger = logging.get_logger(__name__)
|
22 |
VERY_LARGE_INTEGER = int(
|
@@ -28,7 +30,6 @@ LARGE_INTEGER = int(
|
|
28 |
|
29 |
# precollator functions
|
30 |
|
31 |
-
|
32 |
class ExplicitEnum(Enum):
|
33 |
"""
|
34 |
Enum with more explicit error message for missing values.
|
@@ -41,7 +42,6 @@ class ExplicitEnum(Enum):
|
|
41 |
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
42 |
)
|
43 |
|
44 |
-
|
45 |
class TruncationStrategy(ExplicitEnum):
|
46 |
"""
|
47 |
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
@@ -54,6 +54,7 @@ class TruncationStrategy(ExplicitEnum):
|
|
54 |
DO_NOT_TRUNCATE = "do_not_truncate"
|
55 |
|
56 |
|
|
|
57 |
class PaddingStrategy(ExplicitEnum):
|
58 |
"""
|
59 |
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
@@ -65,6 +66,7 @@ class PaddingStrategy(ExplicitEnum):
|
|
65 |
DO_NOT_PAD = "do_not_pad"
|
66 |
|
67 |
|
|
|
68 |
class TensorType(ExplicitEnum):
|
69 |
"""
|
70 |
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
@@ -76,41 +78,21 @@ class TensorType(ExplicitEnum):
|
|
76 |
NUMPY = "np"
|
77 |
JAX = "jax"
|
78 |
|
79 |
-
|
80 |
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
self.token_dictionary.get("<pad>"),
|
92 |
-
]
|
93 |
-
|
94 |
-
@property
|
95 |
-
def all_special_ids(self):
|
96 |
-
return self._all_special_ids
|
97 |
-
|
98 |
-
@property
|
99 |
-
def mask_token_id(self):
|
100 |
-
return self._mask_token_id
|
101 |
-
|
102 |
-
@property
|
103 |
-
def pad_token_id(self):
|
104 |
-
return self._pad_token_id
|
105 |
|
106 |
def _get_padding_truncation_strategies(
|
107 |
-
self,
|
108 |
-
padding=True,
|
109 |
-
truncation=False,
|
110 |
-
max_length=None,
|
111 |
-
pad_to_multiple_of=None,
|
112 |
-
verbose=True,
|
113 |
-
**kwargs,
|
114 |
):
|
115 |
"""
|
116 |
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
@@ -123,9 +105,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
123 |
# If you only set max_length, it activates truncation for max_length
|
124 |
if max_length is not None and padding is False and truncation is False:
|
125 |
if verbose:
|
126 |
-
if not self.deprecation_warnings.get(
|
127 |
-
"Truncation-not-explicitly-activated", False
|
128 |
-
):
|
129 |
logger.warning(
|
130 |
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
131 |
"please use `truncation=True` to explicitly truncate examples to max length. "
|
@@ -153,9 +133,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
153 |
padding_strategy = PaddingStrategy.MAX_LENGTH
|
154 |
elif padding is not False:
|
155 |
if padding is True:
|
156 |
-
padding_strategy =
|
157 |
-
PaddingStrategy.LONGEST
|
158 |
-
) # Default to pad to the longest sequence in the batch
|
159 |
elif not isinstance(padding, PaddingStrategy):
|
160 |
padding_strategy = PaddingStrategy(padding)
|
161 |
elif isinstance(padding, PaddingStrategy):
|
@@ -195,9 +173,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
195 |
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
196 |
if self.model_max_length > LARGE_INTEGER:
|
197 |
if verbose:
|
198 |
-
if not self.deprecation_warnings.get(
|
199 |
-
"Asking-to-pad-to-max_length", False
|
200 |
-
):
|
201 |
logger.warning(
|
202 |
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
203 |
"Default to no padding."
|
@@ -210,24 +186,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
210 |
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
211 |
if self.model_max_length > LARGE_INTEGER:
|
212 |
if verbose:
|
213 |
-
if not self.deprecation_warnings.get(
|
214 |
-
"Asking-to-truncate-to-max_length", False
|
215 |
-
):
|
216 |
logger.warning(
|
217 |
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
218 |
"Default to no truncation."
|
219 |
)
|
220 |
-
self.deprecation_warnings[
|
221 |
-
"Asking-to-truncate-to-max_length"
|
222 |
-
] = True
|
223 |
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
224 |
else:
|
225 |
max_length = self.model_max_length
|
226 |
|
227 |
# Test if we have a padding token
|
228 |
-
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
|
229 |
-
not self.pad_token or self.pad_token_id < 0
|
230 |
-
):
|
231 |
raise ValueError(
|
232 |
"Asking to pad but the tokenizer does not have a padding token. "
|
233 |
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
@@ -258,7 +228,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
258 |
Dict[str, List[EncodedInput]],
|
259 |
List[Dict[str, EncodedInput]],
|
260 |
],
|
261 |
-
class_type,
|
262 |
padding: Union[bool, str, PaddingStrategy] = True,
|
263 |
max_length: Optional[int] = None,
|
264 |
pad_to_multiple_of: Optional[int] = None,
|
@@ -269,23 +239,29 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
269 |
"""
|
270 |
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
|
271 |
in the batch.
|
|
|
272 |
Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
|
273 |
``self.pad_token_id`` and ``self.pad_token_type_id``)
|
|
|
274 |
.. note::
|
|
|
275 |
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
|
276 |
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
|
277 |
case of PyTorch tensors, you will lose the specific device of your tensors however.
|
|
|
278 |
Args:
|
279 |
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
|
280 |
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
|
281 |
List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
|
282 |
List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
|
283 |
well as in a PyTorch Dataloader collate function.
|
|
|
284 |
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
|
285 |
see the note above for the return type.
|
286 |
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
287 |
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
288 |
index) among:
|
|
|
289 |
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
290 |
single sequence if provided).
|
291 |
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
@@ -296,14 +272,17 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
296 |
Maximum length of the returned list and optionally padding length (see above).
|
297 |
pad_to_multiple_of (:obj:`int`, `optional`):
|
298 |
If set will pad the sequence to a multiple of the provided value.
|
|
|
299 |
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
300 |
>= 7.5 (Volta).
|
301 |
return_attention_mask (:obj:`bool`, `optional`):
|
302 |
Whether to return the attention mask. If left to the default, will return the attention mask according
|
303 |
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
|
|
304 |
`What are attention masks? <../glossary.html#attention-mask>`__
|
305 |
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
306 |
If set, will return tensors instead of list of python integers. Acceptable values are:
|
|
|
307 |
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
308 |
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
309 |
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
@@ -312,13 +291,8 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
312 |
"""
|
313 |
# If we have a list of dicts, let's convert it in a dict of lists
|
314 |
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
315 |
-
if isinstance(encoded_inputs, (list, tuple)) and isinstance(
|
316 |
-
encoded_inputs[0]
|
317 |
-
):
|
318 |
-
encoded_inputs = {
|
319 |
-
key: [example[key] for example in encoded_inputs]
|
320 |
-
for key in encoded_inputs[0].keys()
|
321 |
-
}
|
322 |
|
323 |
# The model's main input name, usually `input_ids`, has be passed for padding
|
324 |
if self.model_input_names[0] not in encoded_inputs:
|
@@ -412,7 +386,7 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
412 |
def _pad(
|
413 |
self,
|
414 |
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
415 |
-
class_type,
|
416 |
max_length: Optional[int] = None,
|
417 |
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
418 |
pad_to_multiple_of: Optional[int] = None,
|
@@ -420,15 +394,18 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
420 |
) -> dict:
|
421 |
"""
|
422 |
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
|
|
423 |
Args:
|
424 |
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
425 |
max_length: maximum length of the returned list and optionally padding length (see below).
|
426 |
Will truncate by taking into account the special tokens.
|
427 |
padding_strategy: PaddingStrategy to use for padding.
|
|
|
428 |
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
429 |
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
430 |
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
431 |
The tokenizer padding sides are defined in self.padding_side:
|
|
|
432 |
- 'left': pads on the left of the sequences
|
433 |
- 'right': pads on the right of the sequences
|
434 |
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
@@ -445,73 +422,46 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
445 |
if padding_strategy == PaddingStrategy.LONGEST:
|
446 |
max_length = len(required_input)
|
447 |
|
448 |
-
if (
|
449 |
-
max_length is not None
|
450 |
-
and pad_to_multiple_of is not None
|
451 |
-
and (max_length % pad_to_multiple_of != 0)
|
452 |
-
):
|
453 |
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
454 |
|
455 |
-
needs_to_be_padded = (
|
456 |
-
padding_strategy != PaddingStrategy.DO_NOT_PAD
|
457 |
-
and len(required_input) != max_length
|
458 |
-
)
|
459 |
|
460 |
if needs_to_be_padded:
|
461 |
difference = max_length - len(required_input)
|
462 |
if self.padding_side == "right":
|
463 |
if return_attention_mask:
|
464 |
-
encoded_inputs["attention_mask"] = [1] * len(required_input) + [
|
465 |
-
0
|
466 |
-
] * difference
|
467 |
if "token_type_ids" in encoded_inputs:
|
468 |
encoded_inputs["token_type_ids"] = (
|
469 |
-
encoded_inputs["token_type_ids"]
|
470 |
-
+ [self.pad_token_type_id] * difference
|
471 |
)
|
472 |
if "special_tokens_mask" in encoded_inputs:
|
473 |
-
encoded_inputs["special_tokens_mask"] =
|
474 |
-
|
475 |
-
)
|
476 |
-
encoded_inputs[self.model_input_names[0]] = (
|
477 |
-
required_input + [self.pad_token_id] * difference
|
478 |
-
)
|
479 |
if class_type == "gene":
|
480 |
-
encoded_inputs["labels"] =
|
481 |
-
encoded_inputs["labels"] + [-100] * difference
|
482 |
-
)
|
483 |
elif self.padding_side == "left":
|
484 |
if return_attention_mask:
|
485 |
-
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
|
486 |
-
required_input
|
487 |
-
)
|
488 |
if "token_type_ids" in encoded_inputs:
|
489 |
-
encoded_inputs["token_type_ids"] = [
|
490 |
-
|
491 |
-
]
|
492 |
if "special_tokens_mask" in encoded_inputs:
|
493 |
-
encoded_inputs["special_tokens_mask"] = [
|
494 |
-
|
495 |
-
] * difference + encoded_inputs["special_tokens_mask"]
|
496 |
-
encoded_inputs[self.model_input_names[0]] = [
|
497 |
-
self.pad_token_id
|
498 |
-
] * difference + required_input
|
499 |
if class_type == "gene":
|
500 |
-
encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
|
501 |
-
"labels"
|
502 |
-
]
|
503 |
else:
|
504 |
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
505 |
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
506 |
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
507 |
-
|
508 |
return encoded_inputs
|
509 |
|
510 |
def get_special_tokens_mask(
|
511 |
-
self,
|
512 |
-
token_ids_0: List[int],
|
513 |
-
token_ids_1: Optional[List[int]] = None,
|
514 |
-
already_has_special_tokens: bool = False,
|
515 |
) -> List[int]:
|
516 |
"""
|
517 |
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
@@ -535,15 +485,11 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
535 |
|
536 |
all_special_ids = self.all_special_ids # cache the property
|
537 |
|
538 |
-
special_tokens_mask = [
|
539 |
-
1 if token in all_special_ids else 0 for token in token_ids_0
|
540 |
-
]
|
541 |
|
542 |
return special_tokens_mask
|
543 |
|
544 |
-
def convert_tokens_to_ids(
|
545 |
-
self, tokens: Union[str, List[str]]
|
546 |
-
) -> Union[int, List[int]]:
|
547 |
"""
|
548 |
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
549 |
vocabulary.
|
@@ -567,15 +513,14 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
567 |
if token is None:
|
568 |
return None
|
569 |
|
570 |
-
return
|
571 |
|
572 |
def __len__(self):
|
573 |
-
return len(
|
574 |
|
575 |
|
576 |
# collator functions
|
577 |
|
578 |
-
|
579 |
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
580 |
"""
|
581 |
Data collator that will dynamically pad the inputs received, as well as the labels.
|
@@ -601,33 +546,25 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
601 |
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
602 |
"""
|
603 |
|
|
|
604 |
class_type = "gene"
|
605 |
padding: Union[bool, str, PaddingStrategy] = True
|
606 |
max_length: Optional[int] = None
|
607 |
pad_to_multiple_of: Optional[int] = None
|
608 |
label_pad_token_id: int = -100
|
609 |
-
|
610 |
def __init__(self, *args, **kwargs) -> None:
|
611 |
-
self.token_dictionary = kwargs.pop("token_dictionary")
|
612 |
super().__init__(
|
613 |
-
tokenizer=
|
614 |
-
token_dictionary=self.token_dictionary
|
615 |
-
),
|
616 |
padding=self.padding,
|
617 |
max_length=self.max_length,
|
618 |
pad_to_multiple_of=self.pad_to_multiple_of,
|
619 |
label_pad_token_id=self.label_pad_token_id,
|
620 |
-
*args,
|
621 |
-
**kwargs,
|
622 |
-
)
|
623 |
|
624 |
def _prepare_batch(self, features):
|
625 |
label_name = "label" if "label" in features[0].keys() else "labels"
|
626 |
-
labels = (
|
627 |
-
[feature[label_name] for feature in features]
|
628 |
-
if label_name in features[0].keys()
|
629 |
-
else None
|
630 |
-
)
|
631 |
batch = self.tokenizer.pad(
|
632 |
features,
|
633 |
class_type=self.class_type,
|
@@ -637,31 +574,29 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
637 |
return_tensors="pt",
|
638 |
)
|
639 |
return batch
|
640 |
-
|
641 |
def __call__(self, features):
|
642 |
batch = self._prepare_batch(features)
|
643 |
|
644 |
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
645 |
return batch
|
646 |
|
647 |
-
|
648 |
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
|
|
649 |
class_type = "cell"
|
650 |
|
651 |
def _prepare_batch(self, features):
|
|
|
652 |
batch = super()._prepare_batch(features)
|
653 |
-
|
654 |
# Special handling for labels.
|
655 |
# Ensure that tensor is created with the correct type
|
656 |
# (it should be automatically the case, but let's make sure of it.)
|
657 |
first = features[0]
|
658 |
if "label" in first and first["label"] is not None:
|
659 |
-
label = (
|
660 |
-
first["label"].item()
|
661 |
-
if isinstance(first["label"], torch.Tensor)
|
662 |
-
else first["label"]
|
663 |
-
)
|
664 |
dtype = torch.long if isinstance(label, int) else torch.float
|
665 |
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
666 |
-
|
667 |
return batch
|
|
|
1 |
"""
|
2 |
Geneformer collator for gene and cell classification.
|
3 |
+
|
4 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
"""
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
import warnings
|
9 |
from enum import Enum
|
10 |
from typing import Dict, List, Optional, Union
|
11 |
|
|
|
|
|
12 |
from transformers import (
|
|
|
13 |
DataCollatorForTokenClassification,
|
14 |
SpecialTokensMixin,
|
15 |
+
BatchEncoding,
|
16 |
)
|
17 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
|
20 |
+
from .pretrainer import token_dictionary
|
21 |
+
|
22 |
EncodedInput = List[int]
|
23 |
logger = logging.get_logger(__name__)
|
24 |
VERY_LARGE_INTEGER = int(
|
|
|
30 |
|
31 |
# precollator functions
|
32 |
|
|
|
33 |
class ExplicitEnum(Enum):
|
34 |
"""
|
35 |
Enum with more explicit error message for missing values.
|
|
|
42 |
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
|
43 |
)
|
44 |
|
|
|
45 |
class TruncationStrategy(ExplicitEnum):
|
46 |
"""
|
47 |
Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
|
|
54 |
DO_NOT_TRUNCATE = "do_not_truncate"
|
55 |
|
56 |
|
57 |
+
|
58 |
class PaddingStrategy(ExplicitEnum):
|
59 |
"""
|
60 |
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
|
|
|
66 |
DO_NOT_PAD = "do_not_pad"
|
67 |
|
68 |
|
69 |
+
|
70 |
class TensorType(ExplicitEnum):
|
71 |
"""
|
72 |
Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
|
|
|
78 |
NUMPY = "np"
|
79 |
JAX = "jax"
|
80 |
|
81 |
+
|
82 |
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
83 |
+
mask_token = "<mask>"
|
84 |
+
mask_token_id = token_dictionary.get("<mask>")
|
85 |
+
pad_token = "<pad>"
|
86 |
+
pad_token_id = token_dictionary.get("<pad>")
|
87 |
+
padding_side = "right"
|
88 |
+
all_special_ids = [
|
89 |
+
token_dictionary.get("<mask>"),
|
90 |
+
token_dictionary.get("<pad>")
|
91 |
+
]
|
92 |
+
model_input_names = ["input_ids"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
def _get_padding_truncation_strategies(
|
95 |
+
self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
):
|
97 |
"""
|
98 |
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
|
|
|
105 |
# If you only set max_length, it activates truncation for max_length
|
106 |
if max_length is not None and padding is False and truncation is False:
|
107 |
if verbose:
|
108 |
+
if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
|
|
|
|
|
109 |
logger.warning(
|
110 |
"Truncation was not explicitly activated but `max_length` is provided a specific value, "
|
111 |
"please use `truncation=True` to explicitly truncate examples to max length. "
|
|
|
133 |
padding_strategy = PaddingStrategy.MAX_LENGTH
|
134 |
elif padding is not False:
|
135 |
if padding is True:
|
136 |
+
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
|
|
|
|
|
137 |
elif not isinstance(padding, PaddingStrategy):
|
138 |
padding_strategy = PaddingStrategy(padding)
|
139 |
elif isinstance(padding, PaddingStrategy):
|
|
|
173 |
if padding_strategy == PaddingStrategy.MAX_LENGTH:
|
174 |
if self.model_max_length > LARGE_INTEGER:
|
175 |
if verbose:
|
176 |
+
if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
|
|
|
|
|
177 |
logger.warning(
|
178 |
"Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
179 |
"Default to no padding."
|
|
|
186 |
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
187 |
if self.model_max_length > LARGE_INTEGER:
|
188 |
if verbose:
|
189 |
+
if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
|
|
|
|
|
190 |
logger.warning(
|
191 |
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
|
192 |
"Default to no truncation."
|
193 |
)
|
194 |
+
self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
|
|
|
|
|
195 |
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
196 |
else:
|
197 |
max_length = self.model_max_length
|
198 |
|
199 |
# Test if we have a padding token
|
200 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
|
|
|
|
|
201 |
raise ValueError(
|
202 |
"Asking to pad but the tokenizer does not have a padding token. "
|
203 |
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
|
|
|
228 |
Dict[str, List[EncodedInput]],
|
229 |
List[Dict[str, EncodedInput]],
|
230 |
],
|
231 |
+
class_type, # options: "gene" or "cell"
|
232 |
padding: Union[bool, str, PaddingStrategy] = True,
|
233 |
max_length: Optional[int] = None,
|
234 |
pad_to_multiple_of: Optional[int] = None,
|
|
|
239 |
"""
|
240 |
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
|
241 |
in the batch.
|
242 |
+
|
243 |
Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
|
244 |
``self.pad_token_id`` and ``self.pad_token_type_id``)
|
245 |
+
|
246 |
.. note::
|
247 |
+
|
248 |
If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
|
249 |
result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
|
250 |
case of PyTorch tensors, you will lose the specific device of your tensors however.
|
251 |
+
|
252 |
Args:
|
253 |
encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
|
254 |
Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
|
255 |
List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
|
256 |
List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
|
257 |
well as in a PyTorch Dataloader collate function.
|
258 |
+
|
259 |
Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
|
260 |
see the note above for the return type.
|
261 |
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
262 |
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
263 |
index) among:
|
264 |
+
|
265 |
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
266 |
single sequence if provided).
|
267 |
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
|
|
272 |
Maximum length of the returned list and optionally padding length (see above).
|
273 |
pad_to_multiple_of (:obj:`int`, `optional`):
|
274 |
If set will pad the sequence to a multiple of the provided value.
|
275 |
+
|
276 |
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
277 |
>= 7.5 (Volta).
|
278 |
return_attention_mask (:obj:`bool`, `optional`):
|
279 |
Whether to return the attention mask. If left to the default, will return the attention mask according
|
280 |
to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
|
281 |
+
|
282 |
`What are attention masks? <../glossary.html#attention-mask>`__
|
283 |
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
284 |
If set, will return tensors instead of list of python integers. Acceptable values are:
|
285 |
+
|
286 |
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
287 |
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
288 |
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
|
|
291 |
"""
|
292 |
# If we have a list of dicts, let's convert it in a dict of lists
|
293 |
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
294 |
+
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
295 |
+
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
# The model's main input name, usually `input_ids`, has be passed for padding
|
298 |
if self.model_input_names[0] not in encoded_inputs:
|
|
|
386 |
def _pad(
|
387 |
self,
|
388 |
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
389 |
+
class_type, # options: "gene" or "cell"
|
390 |
max_length: Optional[int] = None,
|
391 |
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
392 |
pad_to_multiple_of: Optional[int] = None,
|
|
|
394 |
) -> dict:
|
395 |
"""
|
396 |
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
397 |
+
|
398 |
Args:
|
399 |
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
400 |
max_length: maximum length of the returned list and optionally padding length (see below).
|
401 |
Will truncate by taking into account the special tokens.
|
402 |
padding_strategy: PaddingStrategy to use for padding.
|
403 |
+
|
404 |
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
405 |
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
406 |
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
407 |
The tokenizer padding sides are defined in self.padding_side:
|
408 |
+
|
409 |
- 'left': pads on the left of the sequences
|
410 |
- 'right': pads on the right of the sequences
|
411 |
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
|
|
422 |
if padding_strategy == PaddingStrategy.LONGEST:
|
423 |
max_length = len(required_input)
|
424 |
|
425 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
|
|
|
|
|
|
|
|
426 |
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
427 |
|
428 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
|
|
|
|
|
|
429 |
|
430 |
if needs_to_be_padded:
|
431 |
difference = max_length - len(required_input)
|
432 |
if self.padding_side == "right":
|
433 |
if return_attention_mask:
|
434 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
|
|
|
|
435 |
if "token_type_ids" in encoded_inputs:
|
436 |
encoded_inputs["token_type_ids"] = (
|
437 |
+
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
|
|
438 |
)
|
439 |
if "special_tokens_mask" in encoded_inputs:
|
440 |
+
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
441 |
+
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
|
|
|
|
|
|
|
|
442 |
if class_type == "gene":
|
443 |
+
encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
|
|
|
|
|
444 |
elif self.padding_side == "left":
|
445 |
if return_attention_mask:
|
446 |
+
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
|
|
|
|
447 |
if "token_type_ids" in encoded_inputs:
|
448 |
+
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
449 |
+
"token_type_ids"
|
450 |
+
]
|
451 |
if "special_tokens_mask" in encoded_inputs:
|
452 |
+
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
453 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
|
|
|
|
|
|
|
|
454 |
if class_type == "gene":
|
455 |
+
encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
|
|
|
|
|
456 |
else:
|
457 |
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
458 |
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
459 |
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
460 |
+
|
461 |
return encoded_inputs
|
462 |
|
463 |
def get_special_tokens_mask(
|
464 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
|
|
|
|
|
|
465 |
) -> List[int]:
|
466 |
"""
|
467 |
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
|
|
485 |
|
486 |
all_special_ids = self.all_special_ids # cache the property
|
487 |
|
488 |
+
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
|
|
|
|
|
489 |
|
490 |
return special_tokens_mask
|
491 |
|
492 |
+
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
|
|
|
|
493 |
"""
|
494 |
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
495 |
vocabulary.
|
|
|
513 |
if token is None:
|
514 |
return None
|
515 |
|
516 |
+
return token_dictionary.get(token)
|
517 |
|
518 |
def __len__(self):
|
519 |
+
return len(token_dictionary)
|
520 |
|
521 |
|
522 |
# collator functions
|
523 |
|
|
|
524 |
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
525 |
"""
|
526 |
Data collator that will dynamically pad the inputs received, as well as the labels.
|
|
|
546 |
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
547 |
"""
|
548 |
|
549 |
+
tokenizer = PrecollatorForGeneAndCellClassification()
|
550 |
class_type = "gene"
|
551 |
padding: Union[bool, str, PaddingStrategy] = True
|
552 |
max_length: Optional[int] = None
|
553 |
pad_to_multiple_of: Optional[int] = None
|
554 |
label_pad_token_id: int = -100
|
555 |
+
|
556 |
def __init__(self, *args, **kwargs) -> None:
|
|
|
557 |
super().__init__(
|
558 |
+
tokenizer=self.tokenizer,
|
|
|
|
|
559 |
padding=self.padding,
|
560 |
max_length=self.max_length,
|
561 |
pad_to_multiple_of=self.pad_to_multiple_of,
|
562 |
label_pad_token_id=self.label_pad_token_id,
|
563 |
+
*args, **kwargs)
|
|
|
|
|
564 |
|
565 |
def _prepare_batch(self, features):
|
566 |
label_name = "label" if "label" in features[0].keys() else "labels"
|
567 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
|
|
|
|
|
|
|
|
568 |
batch = self.tokenizer.pad(
|
569 |
features,
|
570 |
class_type=self.class_type,
|
|
|
574 |
return_tensors="pt",
|
575 |
)
|
576 |
return batch
|
577 |
+
|
578 |
def __call__(self, features):
|
579 |
batch = self._prepare_batch(features)
|
580 |
|
581 |
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
582 |
return batch
|
583 |
|
584 |
+
|
585 |
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
|
586 |
+
|
587 |
class_type = "cell"
|
588 |
|
589 |
def _prepare_batch(self, features):
|
590 |
+
|
591 |
batch = super()._prepare_batch(features)
|
592 |
+
|
593 |
# Special handling for labels.
|
594 |
# Ensure that tensor is created with the correct type
|
595 |
# (it should be automatically the case, but let's make sure of it.)
|
596 |
first = features[0]
|
597 |
if "label" in first and first["label"] is not None:
|
598 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
|
|
|
|
|
|
|
|
599 |
dtype = torch.long if isinstance(label, int) else torch.float
|
600 |
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
601 |
+
|
602 |
return batch
|
geneformer/emb_extractor.py
CHANGED
@@ -1,419 +1,253 @@
|
|
1 |
"""
|
2 |
Geneformer embedding extractor.
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
"""
|
11 |
|
12 |
# imports
|
13 |
import logging
|
14 |
-
import pickle
|
15 |
-
from collections import Counter
|
16 |
-
from pathlib import Path
|
17 |
-
|
18 |
import anndata
|
19 |
import matplotlib.pyplot as plt
|
|
|
20 |
import pandas as pd
|
|
|
|
|
21 |
import scanpy as sc
|
22 |
import seaborn as sns
|
23 |
import torch
|
24 |
-
from
|
25 |
-
from
|
|
|
|
|
26 |
|
27 |
-
from . import TOKEN_DICTIONARY_FILE
|
28 |
-
from . import perturber_utils as pu
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
|
|
32 |
|
33 |
# extract embeddings
|
34 |
-
def get_embs(
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
summary_stat=None,
|
44 |
-
silent=False,
|
45 |
-
):
|
46 |
-
model_input_size = pu.get_model_input_size(model)
|
47 |
total_batch_length = len(filtered_input_data)
|
48 |
-
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
52 |
-
#
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
gene_set = list(
|
59 |
-
{
|
60 |
-
element
|
61 |
-
for sublist in filtered_input_data["input_ids"]
|
62 |
-
for element in sublist
|
63 |
-
}
|
64 |
-
)
|
65 |
-
# initiate dict with genes as keys and tdigests for # of emb dims as values
|
66 |
-
embs_tdigests_dict = {
|
67 |
-
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
-
}
|
69 |
-
|
70 |
-
# Check if CLS and EOS token is present in the token dictionary
|
71 |
-
cls_present = any("<cls>" in value for value in token_gene_dict.values())
|
72 |
-
eos_present = any("<eos>" in value for value in token_gene_dict.values())
|
73 |
-
if emb_mode == "cls":
|
74 |
-
assert cls_present, "<cls> token missing in token dictionary"
|
75 |
-
# Check to make sure that the first token of the filtered input data is cls token
|
76 |
-
gene_token_dict = {v: k for k, v in token_gene_dict.items()}
|
77 |
-
cls_token_id = gene_token_dict["<cls>"]
|
78 |
-
assert (
|
79 |
-
filtered_input_data["input_ids"][0][0] == cls_token_id
|
80 |
-
), "First token is not <cls> token value"
|
81 |
-
elif emb_mode == "cell":
|
82 |
-
if cls_present:
|
83 |
-
logger.warning(
|
84 |
-
"CLS token present in token dictionary, excluding from average."
|
85 |
-
)
|
86 |
-
if eos_present:
|
87 |
-
logger.warning(
|
88 |
-
"EOS token present in token dictionary, excluding from average."
|
89 |
-
)
|
90 |
-
|
91 |
-
overall_max_len = 0
|
92 |
|
93 |
-
for i in trange(0, total_batch_length, forward_batch_size
|
94 |
-
max_range = min(i
|
95 |
|
96 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
97 |
-
|
98 |
-
|
99 |
-
original_lens = torch.tensor(minibatch["length"], device="cuda")
|
100 |
minibatch.set_format(type="torch")
|
101 |
|
102 |
input_data_minibatch = minibatch["input_ids"]
|
103 |
-
input_data_minibatch =
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
107 |
with torch.no_grad():
|
108 |
outputs = model(
|
109 |
-
input_ids=input_data_minibatch.to("cuda"),
|
110 |
-
attention_mask=
|
111 |
)
|
112 |
|
113 |
embs_i = outputs.hidden_states[layer_to_quant]
|
114 |
-
|
115 |
if emb_mode == "cell":
|
116 |
-
|
117 |
-
non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
|
118 |
-
if eos_present:
|
119 |
-
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
120 |
-
else:
|
121 |
-
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
|
122 |
-
else:
|
123 |
-
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
124 |
if summary_stat is None:
|
125 |
-
embs_list
|
126 |
elif summary_stat is not None:
|
127 |
# update tdigests with current batch for each emb dim
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
if summary_stat is None:
|
132 |
-
embs_list.append(embs_i)
|
133 |
-
elif summary_stat is not None:
|
134 |
-
for h in trange(len(minibatch)):
|
135 |
-
length_h = minibatch[h]["length"]
|
136 |
-
input_ids_h = minibatch[h]["input_ids"][0:length_h]
|
137 |
-
|
138 |
-
# double check dimensions before unsqueezing
|
139 |
-
embs_i_dim = embs_i.dim()
|
140 |
-
if embs_i_dim != 3:
|
141 |
-
logger.error(
|
142 |
-
f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
|
143 |
-
)
|
144 |
-
raise
|
145 |
-
|
146 |
-
embs_h = embs_i[h, :, :].unsqueeze(dim=1)
|
147 |
-
dict_h = dict(zip(input_ids_h, embs_h))
|
148 |
-
for k in dict_h.keys():
|
149 |
-
accumulate_tdigests(
|
150 |
-
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
151 |
-
)
|
152 |
-
del embs_h
|
153 |
-
del dict_h
|
154 |
-
elif emb_mode == "cls":
|
155 |
-
cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer
|
156 |
-
embs_list.append(cls_embs)
|
157 |
-
del cls_embs
|
158 |
-
|
159 |
-
overall_max_len = max(overall_max_len, max_len)
|
160 |
del outputs
|
161 |
del minibatch
|
162 |
del input_data_minibatch
|
163 |
del embs_i
|
164 |
-
|
165 |
-
torch.cuda.empty_cache()
|
166 |
-
|
167 |
if summary_stat is None:
|
168 |
-
|
169 |
-
embs_stack = torch.cat(embs_list, dim=0)
|
170 |
-
elif emb_mode == "gene":
|
171 |
-
embs_stack = pu.pad_tensor_list(
|
172 |
-
embs_list,
|
173 |
-
overall_max_len,
|
174 |
-
pad_token_id,
|
175 |
-
model_input_size,
|
176 |
-
1,
|
177 |
-
pu.pad_3d_tensor,
|
178 |
-
)
|
179 |
-
|
180 |
# calculate summary stat embs from approximated tdigests
|
181 |
elif summary_stat is not None:
|
182 |
-
if
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
embs_stack = torch.tensor(summary_emb_list)
|
188 |
-
elif emb_mode == "gene":
|
189 |
-
if summary_stat == "mean":
|
190 |
-
[
|
191 |
-
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
192 |
-
for gene in embs_tdigests_dict.keys()
|
193 |
-
]
|
194 |
-
elif summary_stat == "median":
|
195 |
-
[
|
196 |
-
update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
|
197 |
-
for gene in embs_tdigests_dict.keys()
|
198 |
-
]
|
199 |
-
return embs_tdigests_dict
|
200 |
|
201 |
return embs_stack
|
202 |
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
[
|
207 |
-
embs_tdigests[j].update(mean_embs[i, j].item())
|
208 |
-
for i in range(mean_embs.size(0))
|
209 |
-
for j in range(emb_dims)
|
210 |
-
]
|
211 |
-
|
212 |
-
|
213 |
-
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
|
214 |
-
embs_tdigests_dict[gene] = accumulate_tdigests(
|
215 |
-
embs_tdigests_dict[gene], gene_embs, emb_dims
|
216 |
-
)
|
217 |
-
|
218 |
-
|
219 |
-
def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
|
220 |
-
embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
|
221 |
-
|
222 |
-
|
223 |
-
def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
|
224 |
-
embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
|
225 |
-
|
226 |
-
|
227 |
-
def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
|
228 |
-
length_h = minibatch[h]["length"]
|
229 |
-
input_ids_h = minibatch[h]["input_ids"][0:length_h]
|
230 |
-
embs_h = embs_i[h, :, :].unsqueeze(dim=1)
|
231 |
-
dict_h = dict(zip(input_ids_h, embs_h))
|
232 |
-
[
|
233 |
-
update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
|
234 |
-
for k in dict_h.keys()
|
235 |
-
]
|
236 |
-
|
237 |
-
|
238 |
-
def tdigest_mean(embs_tdigests, emb_dims):
|
239 |
-
return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
|
240 |
-
|
241 |
-
|
242 |
-
def tdigest_median(embs_tdigests, emb_dims):
|
243 |
-
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
244 |
-
|
245 |
|
246 |
-
def
|
247 |
-
embs_df = pd.DataFrame(embs.cpu()
|
248 |
if emb_labels is not None:
|
249 |
for label in emb_labels:
|
250 |
emb_label = downsampled_data[label]
|
251 |
embs_df[label] = emb_label
|
252 |
return embs_df
|
253 |
|
254 |
-
|
255 |
-
|
256 |
-
gene_set = {
|
257 |
-
element for sublist in downsampled_data["input_ids"] for element in sublist
|
258 |
-
}
|
259 |
-
gene_emb_dict = {k: [] for k in gene_set}
|
260 |
-
for i in range(embs.size()[0]):
|
261 |
-
length = downsampled_data[i]["length"]
|
262 |
-
dict_i = dict(
|
263 |
-
zip(
|
264 |
-
downsampled_data[i]["input_ids"][0:length],
|
265 |
-
embs[i, :, :].unsqueeze(dim=1),
|
266 |
-
)
|
267 |
-
)
|
268 |
-
for k in dict_i.keys():
|
269 |
-
gene_emb_dict[k].append(dict_i[k])
|
270 |
-
for k in gene_emb_dict.keys():
|
271 |
-
gene_emb_dict[k] = (
|
272 |
-
torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
|
273 |
-
.cpu()
|
274 |
-
.numpy()
|
275 |
-
)
|
276 |
-
embs_df = pd.DataFrame(gene_emb_dict).T
|
277 |
-
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
278 |
-
return embs_df
|
279 |
-
|
280 |
-
|
281 |
-
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
282 |
-
only_embs_df = embs_df.iloc[:, :emb_dims]
|
283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
284 |
-
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
285 |
-
str
|
286 |
-
)
|
287 |
vars_dict = {"embs": only_embs_df.columns}
|
288 |
-
obs_dict = {"cell_id": list(only_embs_df.index),
|
|
|
289 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
290 |
-
sc.tl.pca(adata, svd_solver=
|
291 |
-
sc.pp.neighbors(adata
|
292 |
-
sc.tl.umap(adata
|
293 |
-
sns.set(rc={
|
294 |
sns.set_style("white")
|
295 |
-
default_kwargs_dict = {"size":
|
296 |
if kwargs_dict is not None:
|
297 |
default_kwargs_dict.update(kwargs_dict)
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
with plt.rc_context():
|
302 |
-
ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
|
303 |
-
ax.legend(
|
304 |
-
markerscale=2,
|
305 |
-
frameon=False,
|
306 |
-
loc="center left",
|
307 |
-
bbox_to_anchor=(1, 0.5),
|
308 |
-
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
309 |
-
)
|
310 |
-
plt.show()
|
311 |
-
plt.savefig(output_file, bbox_inches="tight")
|
312 |
-
|
313 |
|
314 |
def gen_heatmap_class_colors(labels, df):
|
315 |
-
pal = sns.cubehelix_palette(
|
316 |
-
len(Counter(labels).keys()),
|
317 |
-
light=0.9,
|
318 |
-
dark=0.1,
|
319 |
-
hue=1,
|
320 |
-
reverse=True,
|
321 |
-
start=1,
|
322 |
-
rot=-2,
|
323 |
-
)
|
324 |
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
325 |
colors = pd.Series(labels, index=df.index).map(lut)
|
326 |
return colors
|
327 |
-
|
328 |
-
|
329 |
def gen_heatmap_class_dict(classes, label_colors_series):
|
330 |
-
class_color_dict_df = pd.DataFrame(
|
331 |
-
{"classes": classes, "color": label_colors_series}
|
332 |
-
)
|
333 |
class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
|
334 |
-
return dict(zip(class_color_dict_df["classes"],
|
335 |
-
|
336 |
-
|
337 |
def make_colorbar(embs_df, label):
|
338 |
-
labels = list(embs_df[label])
|
339 |
|
|
|
|
|
340 |
cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
|
341 |
label_colors = pd.DataFrame(cell_type_colors, columns=[label])
|
342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
# create dictionary for colors and classes
|
344 |
label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
|
345 |
return label_colors, label_color_dict
|
346 |
-
|
347 |
-
|
348 |
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
349 |
sns.set_style("white")
|
350 |
sns.set(font_scale=2)
|
351 |
plt.figure(figsize=(15, 15), dpi=150)
|
352 |
label_colors, label_color_dict = make_colorbar(embs_df, label)
|
353 |
-
|
354 |
-
default_kwargs_dict = {
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
}
|
366 |
-
|
367 |
if kwargs_dict is not None:
|
368 |
default_kwargs_dict.update(kwargs_dict)
|
369 |
-
g = sns.clustermap(
|
370 |
-
embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
|
371 |
-
)
|
372 |
|
373 |
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
374 |
|
375 |
for label_color in list(label_color_dict.keys()):
|
376 |
-
g.ax_col_dendrogram.bar(
|
377 |
-
0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
|
378 |
-
)
|
379 |
|
380 |
-
g.ax_col_dendrogram.legend(
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
facecolor="white",
|
386 |
-
)
|
387 |
-
plt.show()
|
388 |
-
logger.info(f"Output file: {output_file}")
|
389 |
-
plt.savefig(output_file, bbox_inches="tight")
|
390 |
|
|
|
391 |
|
392 |
class EmbExtractor:
|
393 |
valid_option_dict = {
|
394 |
-
"model_type": {"Pretrained",
|
395 |
"num_classes": {int},
|
396 |
-
"emb_mode": {"
|
397 |
"cell_emb_style": {"mean_pool"},
|
398 |
-
"gene_emb_style": {"mean_pool"},
|
399 |
"filter_data": {None, dict},
|
400 |
"max_ncells": {None, int},
|
401 |
"emb_layer": {-1, 0},
|
402 |
"emb_label": {None, list},
|
403 |
"labels_to_plot": {None, list},
|
404 |
"forward_batch_size": {int},
|
405 |
-
"token_dictionary_file": {None, str},
|
406 |
"nproc": {int},
|
407 |
-
"summary_stat": {None, "mean", "median"
|
408 |
}
|
409 |
-
|
410 |
def __init__(
|
411 |
self,
|
412 |
model_type="Pretrained",
|
413 |
num_classes=0,
|
414 |
-
emb_mode="
|
415 |
cell_emb_style="mean_pool",
|
416 |
-
gene_emb_style="mean_pool",
|
417 |
filter_data=None,
|
418 |
max_ncells=1000,
|
419 |
emb_layer=-1,
|
@@ -422,442 +256,238 @@ class EmbExtractor:
|
|
422 |
forward_batch_size=100,
|
423 |
nproc=4,
|
424 |
summary_stat=None,
|
425 |
-
token_dictionary_file=
|
426 |
):
|
427 |
"""
|
428 |
Initialize embedding extractor.
|
429 |
|
430 |
-
|
431 |
-
|
432 |
-
model_type : {"Pretrained",
|
433 |
-
|
434 |
num_classes : int
|
435 |
-
|
436 |
-
|
437 |
-
emb_mode : {"
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
| Currently only option is mean pooling of gene embeddings for given cell.
|
443 |
-
gene_emb_style : "mean_pool"
|
444 |
-
| Method for summarizing gene embeddings.
|
445 |
-
| Currently only option is mean pooling of contextual gene embeddings for given gene.
|
446 |
filter_data : None, dict
|
447 |
-
|
448 |
-
|
449 |
max_ncells : None, int
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
emb_layer : {-1, 0}
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
emb_label : None, list
|
460 |
-
|
461 |
labels_to_plot : None, list
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
forward_batch_size : int
|
467 |
-
|
468 |
nproc : int
|
469 |
-
|
470 |
-
summary_stat : {None, "mean", "median"
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
| Non-exact is slower but more memory-efficient.
|
475 |
token_dictionary_file : Path
|
476 |
-
|
477 |
-
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
478 |
-
|
479 |
-
**Examples:**
|
480 |
-
|
481 |
-
.. code-block :: python
|
482 |
-
|
483 |
-
>>> from geneformer import EmbExtractor
|
484 |
-
>>> embex = EmbExtractor(model_type="CellClassifier",
|
485 |
-
... num_classes=3,
|
486 |
-
... emb_mode="cell",
|
487 |
-
... filter_data={"cell_type":["cardiomyocyte"]},
|
488 |
-
... max_ncells=1000,
|
489 |
-
... emb_layer=-1,
|
490 |
-
... emb_label=["disease", "cell_type"],
|
491 |
-
... labels_to_plot=["disease", "cell_type"])
|
492 |
-
|
493 |
"""
|
494 |
|
495 |
self.model_type = model_type
|
496 |
self.num_classes = num_classes
|
497 |
self.emb_mode = emb_mode
|
498 |
self.cell_emb_style = cell_emb_style
|
499 |
-
self.gene_emb_style = gene_emb_style
|
500 |
self.filter_data = filter_data
|
501 |
self.max_ncells = max_ncells
|
502 |
self.emb_layer = emb_layer
|
503 |
self.emb_label = emb_label
|
504 |
self.labels_to_plot = labels_to_plot
|
505 |
-
self.token_dictionary_file = token_dictionary_file
|
506 |
self.forward_batch_size = forward_batch_size
|
507 |
self.nproc = nproc
|
508 |
-
|
509 |
-
self.summary_stat = None
|
510 |
-
self.exact_summary_stat = summary_stat
|
511 |
-
else:
|
512 |
-
self.summary_stat = summary_stat
|
513 |
-
self.exact_summary_stat = None
|
514 |
|
515 |
self.validate_options()
|
516 |
|
517 |
# load token dictionary (Ensembl IDs:token)
|
518 |
-
if self.token_dictionary_file is None:
|
519 |
-
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
520 |
with open(token_dictionary_file, "rb") as f:
|
521 |
self.gene_token_dict = pickle.load(f)
|
522 |
|
523 |
-
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
524 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
525 |
-
|
|
|
526 |
def validate_options(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
# confirm arguments are within valid options and compatible with each other
|
528 |
-
for attr_name,
|
529 |
attr_value = self.__dict__[attr_name]
|
530 |
-
if
|
531 |
if attr_value in valid_options:
|
532 |
continue
|
533 |
valid_type = False
|
534 |
for option in valid_options:
|
535 |
-
if (option in [int,
|
536 |
-
attr_value, option
|
537 |
-
):
|
538 |
valid_type = True
|
539 |
break
|
540 |
if valid_type:
|
541 |
continue
|
542 |
logger.error(
|
543 |
-
f"Invalid option for {attr_name}. "
|
544 |
f"Valid options for {attr_name}: {valid_options}"
|
545 |
)
|
546 |
raise
|
547 |
-
|
548 |
if self.filter_data is not None:
|
549 |
-
for key,
|
550 |
-
if
|
551 |
self.filter_data[key] = [value]
|
552 |
logger.warning(
|
553 |
-
"Values in filter_data dict must be lists. "
|
554 |
-
f"Changing {key} value to list ([{value}])."
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
output_directory,
|
562 |
-
output_prefix,
|
563 |
-
output_torch_embs=False,
|
564 |
-
cell_state=None,
|
565 |
-
):
|
566 |
"""
|
567 |
Extract embeddings from input data and save as results in output_directory.
|
568 |
|
569 |
-
|
570 |
-
|
571 |
model_directory : Path
|
572 |
-
|
573 |
input_data_file : Path
|
574 |
-
|
575 |
output_directory : Path
|
576 |
-
|
577 |
output_prefix : str
|
578 |
-
|
579 |
-
output_torch_embs : bool
|
580 |
-
| Whether or not to also output the embeddings as a tensor.
|
581 |
-
| Note, if true, will output embeddings as both dataframe and tensor.
|
582 |
-
cell_state : dict
|
583 |
-
| Cell state key and value for state embedding extraction.
|
584 |
-
|
585 |
-
**Examples:**
|
586 |
-
|
587 |
-
.. code-block :: python
|
588 |
-
|
589 |
-
>>> embs = embex.extract_embs("path/to/model",
|
590 |
-
... "path/to/input_data",
|
591 |
-
... "path/to/output_directory",
|
592 |
-
... "output_prefix")
|
593 |
-
|
594 |
"""
|
595 |
|
596 |
-
filtered_input_data =
|
597 |
-
|
598 |
-
)
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
)
|
613 |
-
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
614 |
-
embs = get_embs(
|
615 |
-
model=model,
|
616 |
-
filtered_input_data=downsampled_data,
|
617 |
-
emb_mode=self.emb_mode,
|
618 |
-
layer_to_quant=layer_to_quant,
|
619 |
-
pad_token_id=self.pad_token_id,
|
620 |
-
forward_batch_size=self.forward_batch_size,
|
621 |
-
token_gene_dict=self.token_gene_dict,
|
622 |
-
summary_stat=self.summary_stat,
|
623 |
-
)
|
624 |
-
|
625 |
-
if self.emb_mode == "cell":
|
626 |
-
if self.summary_stat is None:
|
627 |
-
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
628 |
-
elif self.summary_stat is not None:
|
629 |
-
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
630 |
-
elif self.emb_mode == "gene":
|
631 |
-
if self.summary_stat is None:
|
632 |
-
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
|
633 |
-
elif self.summary_stat is not None:
|
634 |
-
embs_df = pd.DataFrame(embs).T
|
635 |
-
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
636 |
-
elif self.emb_mode == "cls":
|
637 |
-
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
638 |
|
639 |
# save embeddings to output_path
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
embs = torch.median(embs, dim=0)[0]
|
653 |
-
emb_dims = pu.get_model_emb_dims(model)
|
654 |
-
embs_df = pd.DataFrame(
|
655 |
-
embs_df[0 : emb_dims - 1].median(axis="rows"),
|
656 |
-
columns=[self.exact_summary_stat],
|
657 |
-
).T
|
658 |
-
|
659 |
-
if cell_state is not None:
|
660 |
-
return embs
|
661 |
-
else:
|
662 |
-
if output_torch_embs:
|
663 |
-
return embs_df, embs
|
664 |
-
else:
|
665 |
-
return embs_df
|
666 |
-
|
667 |
-
def get_state_embs(
|
668 |
-
self,
|
669 |
-
cell_states_to_model,
|
670 |
-
model_directory,
|
671 |
-
input_data_file,
|
672 |
-
output_directory,
|
673 |
-
output_prefix,
|
674 |
-
output_torch_embs=True,
|
675 |
-
):
|
676 |
-
"""
|
677 |
-
Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
|
678 |
-
|
679 |
-
**Parameters:**
|
680 |
-
|
681 |
-
cell_states_to_model : None, dict
|
682 |
-
| Cell states to model if testing perturbations that achieve goal state change.
|
683 |
-
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
684 |
-
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
685 |
-
| start_state: value in the state_key column that specifies the start state
|
686 |
-
| goal_state: value in the state_key column taht specifies the goal end state
|
687 |
-
| alt_states: list of values in the state_key column that specify the alternate end states
|
688 |
-
| For example:
|
689 |
-
| {"state_key": "disease",
|
690 |
-
| "start_state": "dcm",
|
691 |
-
| "goal_state": "nf",
|
692 |
-
| "alt_states": ["hcm", "other1", "other2"]}
|
693 |
-
model_directory : Path
|
694 |
-
| Path to directory containing model
|
695 |
-
input_data_file : Path
|
696 |
-
| Path to directory containing .dataset inputs
|
697 |
-
output_directory : Path
|
698 |
-
| Path to directory where embedding data will be saved as csv
|
699 |
-
output_prefix : str
|
700 |
-
| Prefix for output file
|
701 |
-
output_torch_embs : bool
|
702 |
-
| Whether or not to also output the embeddings as a tensor.
|
703 |
-
| Note, if true, will output embeddings as both dataframe and tensor.
|
704 |
-
|
705 |
-
**Outputs**
|
706 |
-
|
707 |
-
| Outputs state_embs_dict for use with in silico perturber.
|
708 |
-
| Format is dictionary of embedding positions of each cell state to model shifts from/towards.
|
709 |
-
| Keys specify each possible cell state to model.
|
710 |
-
| Values are target embedding positions as torch.tensor.
|
711 |
-
| For example:
|
712 |
-
| {"nf": emb_nf,
|
713 |
-
| "hcm": emb_hcm,
|
714 |
-
| "dcm": emb_dcm,
|
715 |
-
| "other1": emb_other1,
|
716 |
-
| "other2": emb_other2}
|
717 |
-
"""
|
718 |
-
|
719 |
-
pu.validate_cell_states_to_model(cell_states_to_model)
|
720 |
-
valid_summary_stats = ["exact_mean", "exact_median"]
|
721 |
-
if self.exact_summary_stat not in valid_summary_stats:
|
722 |
-
logger.error(
|
723 |
-
"For extracting state embs, summary_stat in EmbExtractor "
|
724 |
-
f"must be set to option in {valid_summary_stats}"
|
725 |
-
)
|
726 |
-
raise
|
727 |
-
|
728 |
-
if self.emb_label is not None:
|
729 |
-
logger.error(
|
730 |
-
"For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
|
731 |
-
)
|
732 |
-
raise
|
733 |
|
734 |
-
state_embs_dict = dict()
|
735 |
-
state_key = cell_states_to_model["state_key"]
|
736 |
-
for k, v in cell_states_to_model.items():
|
737 |
-
if k == "state_key":
|
738 |
-
continue
|
739 |
-
elif (k == "start_state") or (k == "goal_state"):
|
740 |
-
state_embs_dict[v] = self.extract_embs(
|
741 |
-
model_directory,
|
742 |
-
input_data_file,
|
743 |
-
output_directory,
|
744 |
-
output_prefix,
|
745 |
-
output_torch_embs,
|
746 |
-
cell_state={state_key: v},
|
747 |
-
)
|
748 |
-
else: # k == "alt_states"
|
749 |
-
for alt_state in v:
|
750 |
-
state_embs_dict[alt_state] = self.extract_embs(
|
751 |
-
model_directory,
|
752 |
-
input_data_file,
|
753 |
-
output_directory,
|
754 |
-
output_prefix,
|
755 |
-
output_torch_embs,
|
756 |
-
cell_state={state_key: alt_state},
|
757 |
-
)
|
758 |
-
|
759 |
-
output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
|
760 |
-
with open(output_path, "wb") as fp:
|
761 |
-
pickle.dump(state_embs_dict, fp)
|
762 |
-
|
763 |
-
return state_embs_dict
|
764 |
-
|
765 |
-
def plot_embs(
|
766 |
-
self,
|
767 |
-
embs,
|
768 |
-
plot_style,
|
769 |
-
output_directory,
|
770 |
-
output_prefix,
|
771 |
-
max_ncells_to_plot=1000,
|
772 |
-
kwargs_dict=None,
|
773 |
-
):
|
774 |
"""
|
775 |
Plot embeddings, coloring by provided labels.
|
776 |
|
777 |
-
|
778 |
-
|
779 |
embs : pandas.core.frame.DataFrame
|
780 |
-
|
781 |
plot_style : str
|
782 |
-
|
783 |
output_directory : Path
|
784 |
-
|
785 |
output_prefix : str
|
786 |
-
|
787 |
max_ncells_to_plot : None, int
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
kwargs_dict : dict
|
792 |
-
|
793 |
-
|
794 |
-
**Examples:**
|
795 |
-
|
796 |
-
.. code-block :: python
|
797 |
-
|
798 |
-
>>> embex.plot_embs(embs=embs,
|
799 |
-
... plot_style="heatmap",
|
800 |
-
... output_directory="path/to/output_directory",
|
801 |
-
... output_prefix="output_prefix")
|
802 |
-
|
803 |
"""
|
804 |
-
|
805 |
-
if plot_style not in ["heatmap",
|
806 |
logger.error(
|
807 |
-
"Invalid option for 'plot_style'. "
|
|
|
808 |
)
|
809 |
raise
|
810 |
-
|
811 |
if (plot_style == "umap") and (self.labels_to_plot is None):
|
812 |
-
logger.error(
|
|
|
|
|
813 |
raise
|
814 |
-
|
815 |
-
if max_ncells_to_plot
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
if self.emb_label is None:
|
826 |
label_len = 0
|
827 |
else:
|
828 |
label_len = len(self.emb_label)
|
829 |
-
|
830 |
emb_dims = embs.shape[1] - label_len
|
831 |
-
|
832 |
if self.emb_label is None:
|
833 |
emb_labels = None
|
834 |
else:
|
835 |
emb_labels = embs.columns[emb_dims:]
|
836 |
-
|
837 |
if plot_style == "umap":
|
838 |
for label in self.labels_to_plot:
|
839 |
if label not in emb_labels:
|
840 |
logger.warning(
|
841 |
-
f"Label {label} from labels_to_plot "
|
842 |
-
f"not present in provided embeddings dataframe."
|
843 |
-
)
|
844 |
continue
|
845 |
-
output_prefix_label = output_prefix + f"_umap_{label}"
|
846 |
-
output_file = (
|
847 |
-
|
848 |
-
|
849 |
-
plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
|
850 |
-
|
851 |
if plot_style == "heatmap":
|
852 |
for label in self.labels_to_plot:
|
853 |
if label not in emb_labels:
|
854 |
logger.warning(
|
855 |
-
f"Label {label} from labels_to_plot "
|
856 |
-
f"not present in provided embeddings dataframe."
|
857 |
-
)
|
858 |
continue
|
859 |
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
860 |
-
output_file = (
|
861 |
-
|
862 |
-
).with_suffix(".pdf")
|
863 |
-
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
|
|
1 |
"""
|
2 |
Geneformer embedding extractor.
|
3 |
|
4 |
+
Usage:
|
5 |
+
from geneformer import EmbExtractor
|
6 |
+
embex = EmbExtractor(model_type="CellClassifier",
|
7 |
+
num_classes=3,
|
8 |
+
emb_mode="cell",
|
9 |
+
cell_emb_style="mean_pool",
|
10 |
+
filter_data={"cell_type":["cardiomyocyte"]},
|
11 |
+
max_ncells=1000,
|
12 |
+
max_ncells_to_plot=1000,
|
13 |
+
emb_layer=-1,
|
14 |
+
emb_label=["disease","cell_type"],
|
15 |
+
labels_to_plot=["disease","cell_type"],
|
16 |
+
forward_batch_size=100,
|
17 |
+
nproc=16,
|
18 |
+
summary_stat=None)
|
19 |
+
embs = embex.extract_embs("path/to/model",
|
20 |
+
"path/to/input_data",
|
21 |
+
"path/to/output_directory",
|
22 |
+
"output_prefix")
|
23 |
+
embex.plot_embs(embs=embs,
|
24 |
+
plot_style="heatmap",
|
25 |
+
output_directory="path/to/output_directory",
|
26 |
+
output_prefix="output_prefix")
|
27 |
+
|
28 |
"""
|
29 |
|
30 |
# imports
|
31 |
import logging
|
|
|
|
|
|
|
|
|
32 |
import anndata
|
33 |
import matplotlib.pyplot as plt
|
34 |
+
import numpy as np
|
35 |
import pandas as pd
|
36 |
+
import pickle
|
37 |
+
from tdigest import TDigest
|
38 |
import scanpy as sc
|
39 |
import seaborn as sns
|
40 |
import torch
|
41 |
+
from collections import Counter
|
42 |
+
from pathlib import Path
|
43 |
+
from tqdm.notebook import trange
|
44 |
+
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
45 |
|
46 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
|
|
47 |
|
48 |
+
from .in_silico_perturber import downsample_and_sort, \
|
49 |
+
gen_attention_mask, \
|
50 |
+
get_model_input_size, \
|
51 |
+
load_and_filter, \
|
52 |
+
load_model, \
|
53 |
+
mean_nonpadding_embs, \
|
54 |
+
pad_tensor_list, \
|
55 |
+
quant_layers
|
56 |
|
57 |
+
logger = logging.getLogger(__name__)
|
58 |
|
59 |
# extract embeddings
|
60 |
+
def get_embs(model,
|
61 |
+
filtered_input_data,
|
62 |
+
emb_mode,
|
63 |
+
layer_to_quant,
|
64 |
+
pad_token_id,
|
65 |
+
forward_batch_size,
|
66 |
+
summary_stat):
|
67 |
+
|
68 |
+
model_input_size = get_model_input_size(model)
|
|
|
|
|
|
|
|
|
69 |
total_batch_length = len(filtered_input_data)
|
70 |
+
|
71 |
if summary_stat is None:
|
72 |
embs_list = []
|
73 |
elif summary_stat is not None:
|
74 |
+
# test embedding extraction for example cell and extract # emb dims
|
75 |
+
example = filtered_input_data.select([i for i in range(1)])
|
76 |
+
example.set_format(type="torch")
|
77 |
+
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
78 |
+
# initiate tdigests for # of emb dims
|
79 |
+
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
+
for i in trange(0, total_batch_length, forward_batch_size):
|
82 |
+
max_range = min(i+forward_batch_size, total_batch_length)
|
83 |
|
84 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
85 |
+
max_len = max(minibatch["length"])
|
86 |
+
original_lens = torch.tensor(minibatch["length"]).to("cuda")
|
|
|
87 |
minibatch.set_format(type="torch")
|
88 |
|
89 |
input_data_minibatch = minibatch["input_ids"]
|
90 |
+
input_data_minibatch = pad_tensor_list(input_data_minibatch,
|
91 |
+
max_len,
|
92 |
+
pad_token_id,
|
93 |
+
model_input_size)
|
94 |
+
|
95 |
with torch.no_grad():
|
96 |
outputs = model(
|
97 |
+
input_ids = input_data_minibatch.to("cuda"),
|
98 |
+
attention_mask = gen_attention_mask(minibatch)
|
99 |
)
|
100 |
|
101 |
embs_i = outputs.hidden_states[layer_to_quant]
|
102 |
+
|
103 |
if emb_mode == "cell":
|
104 |
+
mean_embs = mean_nonpadding_embs(embs_i, original_lens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if summary_stat is None:
|
106 |
+
embs_list += [mean_embs]
|
107 |
elif summary_stat is not None:
|
108 |
# update tdigests with current batch for each emb dim
|
109 |
+
# note: tdigest batch update known to be slow so updating serially
|
110 |
+
[embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
|
111 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
del outputs
|
113 |
del minibatch
|
114 |
del input_data_minibatch
|
115 |
del embs_i
|
116 |
+
del mean_embs
|
117 |
+
torch.cuda.empty_cache()
|
118 |
+
|
119 |
if summary_stat is None:
|
120 |
+
embs_stack = torch.cat(embs_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
# calculate summary stat embs from approximated tdigests
|
122 |
elif summary_stat is not None:
|
123 |
+
if summary_stat == "mean":
|
124 |
+
summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
|
125 |
+
elif summary_stat == "median":
|
126 |
+
summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
127 |
+
embs_stack = torch.tensor(summary_emb_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
return embs_stack
|
130 |
|
131 |
+
def test_emb(model, example, layer_to_quant):
|
132 |
+
with torch.no_grad():
|
133 |
+
outputs = model(
|
134 |
+
input_ids = example.to("cuda")
|
135 |
+
)
|
136 |
|
137 |
+
embs_test = outputs.hidden_states[layer_to_quant]
|
138 |
+
return embs_test.size()[2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
+
def label_embs(embs, downsampled_data, emb_labels):
|
141 |
+
embs_df = pd.DataFrame(embs.cpu())
|
142 |
if emb_labels is not None:
|
143 |
for label in emb_labels:
|
144 |
emb_label = downsampled_data[label]
|
145 |
embs_df[label] = emb_label
|
146 |
return embs_df
|
147 |
|
148 |
+
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
149 |
+
only_embs_df = embs_df.iloc[:,:emb_dims]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
151 |
+
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(str)
|
|
|
|
|
152 |
vars_dict = {"embs": only_embs_df.columns}
|
153 |
+
obs_dict = {"cell_id": list(only_embs_df.index),
|
154 |
+
f"{label}": list(embs_df[label])}
|
155 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
156 |
+
sc.tl.pca(adata, svd_solver='arpack')
|
157 |
+
sc.pp.neighbors(adata)
|
158 |
+
sc.tl.umap(adata)
|
159 |
+
sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
|
160 |
sns.set_style("white")
|
161 |
+
default_kwargs_dict = {"palette":"Set2", "size":200}
|
162 |
if kwargs_dict is not None:
|
163 |
default_kwargs_dict.update(kwargs_dict)
|
164 |
+
|
165 |
+
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
def gen_heatmap_class_colors(labels, df):
|
168 |
+
pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
170 |
colors = pd.Series(labels, index=df.index).map(lut)
|
171 |
return colors
|
172 |
+
|
|
|
173 |
def gen_heatmap_class_dict(classes, label_colors_series):
|
174 |
+
class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series})
|
|
|
|
|
175 |
class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
|
176 |
+
return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"]))
|
177 |
+
|
|
|
178 |
def make_colorbar(embs_df, label):
|
|
|
179 |
|
180 |
+
labels = list(embs_df[label])
|
181 |
+
|
182 |
cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
|
183 |
label_colors = pd.DataFrame(cell_type_colors, columns=[label])
|
184 |
|
185 |
+
for i,row in label_colors.iterrows():
|
186 |
+
colors=row[0]
|
187 |
+
if len(colors)!=3 or any(np.isnan(colors)):
|
188 |
+
print(i,colors)
|
189 |
+
|
190 |
+
label_colors.isna().sum()
|
191 |
+
|
192 |
# create dictionary for colors and classes
|
193 |
label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
|
194 |
return label_colors, label_color_dict
|
195 |
+
|
|
|
196 |
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
197 |
sns.set_style("white")
|
198 |
sns.set(font_scale=2)
|
199 |
plt.figure(figsize=(15, 15), dpi=150)
|
200 |
label_colors, label_color_dict = make_colorbar(embs_df, label)
|
201 |
+
|
202 |
+
default_kwargs_dict = {"row_cluster": True,
|
203 |
+
"col_cluster": True,
|
204 |
+
"row_colors": label_colors,
|
205 |
+
"standard_scale": 1,
|
206 |
+
"linewidths": 0,
|
207 |
+
"xticklabels": False,
|
208 |
+
"yticklabels": False,
|
209 |
+
"figsize": (15,15),
|
210 |
+
"center": 0,
|
211 |
+
"cmap": "magma"}
|
212 |
+
|
|
|
|
|
213 |
if kwargs_dict is not None:
|
214 |
default_kwargs_dict.update(kwargs_dict)
|
215 |
+
g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)
|
|
|
|
|
216 |
|
217 |
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
218 |
|
219 |
for label_color in list(label_color_dict.keys()):
|
220 |
+
g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0)
|
|
|
|
|
221 |
|
222 |
+
l1 = g.ax_col_dendrogram.legend(title=f"{label}",
|
223 |
+
loc="lower center",
|
224 |
+
ncol=4,
|
225 |
+
bbox_to_anchor=(0.5, 1),
|
226 |
+
facecolor="white")
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
+
plt.savefig(output_file, bbox_inches='tight')
|
229 |
|
230 |
class EmbExtractor:
|
231 |
valid_option_dict = {
|
232 |
+
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
233 |
"num_classes": {int},
|
234 |
+
"emb_mode": {"cell","gene"},
|
235 |
"cell_emb_style": {"mean_pool"},
|
|
|
236 |
"filter_data": {None, dict},
|
237 |
"max_ncells": {None, int},
|
238 |
"emb_layer": {-1, 0},
|
239 |
"emb_label": {None, list},
|
240 |
"labels_to_plot": {None, list},
|
241 |
"forward_batch_size": {int},
|
|
|
242 |
"nproc": {int},
|
243 |
+
"summary_stat": {None, "mean", "median"},
|
244 |
}
|
|
|
245 |
def __init__(
|
246 |
self,
|
247 |
model_type="Pretrained",
|
248 |
num_classes=0,
|
249 |
+
emb_mode="cell",
|
250 |
cell_emb_style="mean_pool",
|
|
|
251 |
filter_data=None,
|
252 |
max_ncells=1000,
|
253 |
emb_layer=-1,
|
|
|
256 |
forward_batch_size=100,
|
257 |
nproc=4,
|
258 |
summary_stat=None,
|
259 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
260 |
):
|
261 |
"""
|
262 |
Initialize embedding extractor.
|
263 |
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
model_type : {"Pretrained","GeneClassifier","CellClassifier"}
|
267 |
+
Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
268 |
num_classes : int
|
269 |
+
If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
270 |
+
For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
271 |
+
emb_mode : {"cell","gene"}
|
272 |
+
Whether to output cell or gene embeddings.
|
273 |
+
cell_emb_style : "mean_pool"
|
274 |
+
Method for summarizing cell embeddings.
|
275 |
+
Currently only option is mean pooling of gene embeddings for given cell.
|
|
|
|
|
|
|
|
|
276 |
filter_data : None, dict
|
277 |
+
Default is to extract embeddings from all input data.
|
278 |
+
Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
279 |
max_ncells : None, int
|
280 |
+
Maximum number of cells to extract embeddings from.
|
281 |
+
Default is 1000 cells randomly sampled from input data.
|
282 |
+
If None, will extract embeddings from all cells.
|
283 |
emb_layer : {-1, 0}
|
284 |
+
Embedding layer to extract.
|
285 |
+
The last layer is most specifically weighted to optimize the given learning objective.
|
286 |
+
Generally, it is best to extract the 2nd to last layer to get a more general representation.
|
287 |
+
-1: 2nd to last layer
|
288 |
+
0: last layer
|
289 |
emb_label : None, list
|
290 |
+
List of column name(s) in .dataset to add as labels to embedding output.
|
291 |
labels_to_plot : None, list
|
292 |
+
Cell labels to plot.
|
293 |
+
Shown as color bar in heatmap.
|
294 |
+
Shown as cell color in umap.
|
295 |
+
Plotting umap requires labels to plot.
|
296 |
forward_batch_size : int
|
297 |
+
Batch size for forward pass.
|
298 |
nproc : int
|
299 |
+
Number of CPU processes to use.
|
300 |
+
summary_stat : {None, "mean", "median"}
|
301 |
+
If not None, outputs only approximated mean or median embedding of input data.
|
302 |
+
Recommended if encountering memory constraints while generating goal embedding positions.
|
303 |
+
Slower but more memory-efficient.
|
|
|
304 |
token_dictionary_file : Path
|
305 |
+
Path to pickle file containing token dictionary (Ensembl ID:token).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
"""
|
307 |
|
308 |
self.model_type = model_type
|
309 |
self.num_classes = num_classes
|
310 |
self.emb_mode = emb_mode
|
311 |
self.cell_emb_style = cell_emb_style
|
|
|
312 |
self.filter_data = filter_data
|
313 |
self.max_ncells = max_ncells
|
314 |
self.emb_layer = emb_layer
|
315 |
self.emb_label = emb_label
|
316 |
self.labels_to_plot = labels_to_plot
|
|
|
317 |
self.forward_batch_size = forward_batch_size
|
318 |
self.nproc = nproc
|
319 |
+
self.summary_stat = summary_stat
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
self.validate_options()
|
322 |
|
323 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
324 |
with open(token_dictionary_file, "rb") as f:
|
325 |
self.gene_token_dict = pickle.load(f)
|
326 |
|
|
|
327 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
328 |
+
|
329 |
+
|
330 |
def validate_options(self):
|
331 |
+
# first disallow options under development
|
332 |
+
if self.emb_mode == "gene":
|
333 |
+
logger.error(
|
334 |
+
"Extraction and plotting of gene-level embeddings currently under development. " \
|
335 |
+
"Current valid option for 'emb_mode': 'cell'"
|
336 |
+
)
|
337 |
+
raise
|
338 |
+
|
339 |
# confirm arguments are within valid options and compatible with each other
|
340 |
+
for attr_name,valid_options in self.valid_option_dict.items():
|
341 |
attr_value = self.__dict__[attr_name]
|
342 |
+
if type(attr_value) not in {list, dict}:
|
343 |
if attr_value in valid_options:
|
344 |
continue
|
345 |
valid_type = False
|
346 |
for option in valid_options:
|
347 |
+
if (option in [int,list,dict]) and isinstance(attr_value, option):
|
|
|
|
|
348 |
valid_type = True
|
349 |
break
|
350 |
if valid_type:
|
351 |
continue
|
352 |
logger.error(
|
353 |
+
f"Invalid option for {attr_name}. " \
|
354 |
f"Valid options for {attr_name}: {valid_options}"
|
355 |
)
|
356 |
raise
|
357 |
+
|
358 |
if self.filter_data is not None:
|
359 |
+
for key,value in self.filter_data.items():
|
360 |
+
if type(value) != list:
|
361 |
self.filter_data[key] = [value]
|
362 |
logger.warning(
|
363 |
+
"Values in filter_data dict must be lists. " \
|
364 |
+
f"Changing {key} value to list ([{value}]).")
|
365 |
+
|
366 |
+
def extract_embs(self,
|
367 |
+
model_directory,
|
368 |
+
input_data_file,
|
369 |
+
output_directory,
|
370 |
+
output_prefix):
|
|
|
|
|
|
|
|
|
|
|
371 |
"""
|
372 |
Extract embeddings from input data and save as results in output_directory.
|
373 |
|
374 |
+
Parameters
|
375 |
+
----------
|
376 |
model_directory : Path
|
377 |
+
Path to directory containing model
|
378 |
input_data_file : Path
|
379 |
+
Path to directory containing .dataset inputs
|
380 |
output_directory : Path
|
381 |
+
Path to directory where embedding data will be saved as csv
|
382 |
output_prefix : str
|
383 |
+
Prefix for output file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
"""
|
385 |
|
386 |
+
filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
|
387 |
+
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
388 |
+
model = load_model(self.model_type, self.num_classes, model_directory)
|
389 |
+
layer_to_quant = quant_layers(model)+self.emb_layer
|
390 |
+
embs = get_embs(model,
|
391 |
+
downsampled_data,
|
392 |
+
self.emb_mode,
|
393 |
+
layer_to_quant,
|
394 |
+
self.pad_token_id,
|
395 |
+
self.forward_batch_size,
|
396 |
+
self.summary_stat)
|
397 |
+
|
398 |
+
if self.summary_stat is None:
|
399 |
+
embs_df = label_embs(embs, downsampled_data, self.emb_label)
|
400 |
+
elif self.summary_stat is not None:
|
401 |
+
embs_df = pd.DataFrame(embs.cpu()).T
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
|
403 |
# save embeddings to output_path
|
404 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
405 |
+
embs_df.to_csv(output_path)
|
406 |
+
|
407 |
+
return embs_df
|
408 |
+
|
409 |
+
def plot_embs(self,
|
410 |
+
embs,
|
411 |
+
plot_style,
|
412 |
+
output_directory,
|
413 |
+
output_prefix,
|
414 |
+
max_ncells_to_plot=1000,
|
415 |
+
kwargs_dict=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
"""
|
418 |
Plot embeddings, coloring by provided labels.
|
419 |
|
420 |
+
Parameters
|
421 |
+
----------
|
422 |
embs : pandas.core.frame.DataFrame
|
423 |
+
Pandas dataframe containing embeddings output from extract_embs
|
424 |
plot_style : str
|
425 |
+
Style of plot: "heatmap" or "umap"
|
426 |
output_directory : Path
|
427 |
+
Path to directory where plots will be saved as pdf
|
428 |
output_prefix : str
|
429 |
+
Prefix for output file
|
430 |
max_ncells_to_plot : None, int
|
431 |
+
Maximum number of cells to plot.
|
432 |
+
Default is 1000 cells randomly sampled from embeddings.
|
433 |
+
If None, will plot embeddings from all cells.
|
434 |
kwargs_dict : dict
|
435 |
+
Dictionary of kwargs to pass to plotting function.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
"""
|
437 |
+
|
438 |
+
if plot_style not in ["heatmap","umap"]:
|
439 |
logger.error(
|
440 |
+
"Invalid option for 'plot_style'. " \
|
441 |
+
"Valid options: {'heatmap','umap'}"
|
442 |
)
|
443 |
raise
|
444 |
+
|
445 |
if (plot_style == "umap") and (self.labels_to_plot is None):
|
446 |
+
logger.error(
|
447 |
+
"Plotting UMAP requires 'labels_to_plot'. "
|
448 |
+
)
|
449 |
raise
|
450 |
+
|
451 |
+
if max_ncells_to_plot > self.max_ncells:
|
452 |
+
max_ncells_to_plot = self.max_ncells
|
453 |
+
logger.warning(
|
454 |
+
"max_ncells_to_plot must be <= max_ncells. " \
|
455 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}.")
|
456 |
+
|
457 |
+
if (max_ncells_to_plot is not None) \
|
458 |
+
and (max_ncells_to_plot < self.max_ncells):
|
459 |
+
embs = embs.sample(max_ncells_to_plot, axis=0)
|
460 |
+
|
461 |
if self.emb_label is None:
|
462 |
label_len = 0
|
463 |
else:
|
464 |
label_len = len(self.emb_label)
|
465 |
+
|
466 |
emb_dims = embs.shape[1] - label_len
|
467 |
+
|
468 |
if self.emb_label is None:
|
469 |
emb_labels = None
|
470 |
else:
|
471 |
emb_labels = embs.columns[emb_dims:]
|
472 |
+
|
473 |
if plot_style == "umap":
|
474 |
for label in self.labels_to_plot:
|
475 |
if label not in emb_labels:
|
476 |
logger.warning(
|
477 |
+
f"Label {label} from labels_to_plot " \
|
478 |
+
f"not present in provided embeddings dataframe.")
|
|
|
479 |
continue
|
480 |
+
output_prefix_label = "_" + output_prefix + f"_umap_{label}"
|
481 |
+
output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
|
482 |
+
plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
|
483 |
+
|
|
|
|
|
484 |
if plot_style == "heatmap":
|
485 |
for label in self.labels_to_plot:
|
486 |
if label not in emb_labels:
|
487 |
logger.warning(
|
488 |
+
f"Label {label} from labels_to_plot " \
|
489 |
+
f"not present in provided embeddings dataframe.")
|
|
|
490 |
continue
|
491 |
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
492 |
+
output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
|
493 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
|
|
|
geneformer/ensembl_mapping_dict_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
|
3 |
-
size 3957652
|
|
|
|
|
|
|
|
geneformer/evaluation_utils.py
DELETED
@@ -1,287 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import math
|
3 |
-
import pickle
|
4 |
-
from pathlib import Path
|
5 |
-
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
import numpy as np
|
8 |
-
import pandas as pd
|
9 |
-
import seaborn as sns
|
10 |
-
import torch
|
11 |
-
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
12 |
-
from sklearn import preprocessing
|
13 |
-
from sklearn.metrics import (
|
14 |
-
ConfusionMatrixDisplay,
|
15 |
-
accuracy_score,
|
16 |
-
auc,
|
17 |
-
confusion_matrix,
|
18 |
-
f1_score,
|
19 |
-
roc_curve,
|
20 |
-
)
|
21 |
-
from tqdm.auto import trange
|
22 |
-
|
23 |
-
from . import TOKEN_DICTIONARY_FILE
|
24 |
-
from .emb_extractor import make_colorbar
|
25 |
-
|
26 |
-
logger = logging.getLogger(__name__)
|
27 |
-
|
28 |
-
|
29 |
-
def preprocess_classifier_batch(cell_batch, max_len, label_name):
|
30 |
-
if max_len is None:
|
31 |
-
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
32 |
-
|
33 |
-
# load token dictionary (Ensembl IDs:token)
|
34 |
-
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
35 |
-
gene_token_dict = pickle.load(f)
|
36 |
-
|
37 |
-
def pad_label_example(example):
|
38 |
-
example[label_name] = np.pad(
|
39 |
-
example[label_name],
|
40 |
-
(0, max_len - len(example["input_ids"])),
|
41 |
-
mode="constant",
|
42 |
-
constant_values=-100,
|
43 |
-
)
|
44 |
-
example["input_ids"] = np.pad(
|
45 |
-
example["input_ids"],
|
46 |
-
(0, max_len - len(example["input_ids"])),
|
47 |
-
mode="constant",
|
48 |
-
constant_values=gene_token_dict.get("<pad>"),
|
49 |
-
)
|
50 |
-
example["attention_mask"] = (
|
51 |
-
example["input_ids"] != gene_token_dict.get("<pad>")
|
52 |
-
).astype(int)
|
53 |
-
return example
|
54 |
-
|
55 |
-
padded_batch = cell_batch.map(pad_label_example)
|
56 |
-
return padded_batch
|
57 |
-
|
58 |
-
|
59 |
-
# Function to find the largest number smaller
|
60 |
-
# than or equal to N that is divisible by k
|
61 |
-
def find_largest_div(N, K):
|
62 |
-
rem = N % K
|
63 |
-
if rem == 0:
|
64 |
-
return N
|
65 |
-
else:
|
66 |
-
return N - rem
|
67 |
-
|
68 |
-
|
69 |
-
def vote(logit_list):
|
70 |
-
m = max(logit_list)
|
71 |
-
logit_list.index(m)
|
72 |
-
indices = [i for i, x in enumerate(logit_list) if x == m]
|
73 |
-
if len(indices) > 1:
|
74 |
-
return "tie"
|
75 |
-
else:
|
76 |
-
return indices[0]
|
77 |
-
|
78 |
-
|
79 |
-
def py_softmax(vector):
|
80 |
-
e = np.exp(vector)
|
81 |
-
return e / e.sum()
|
82 |
-
|
83 |
-
|
84 |
-
def classifier_predict(model, classifier_type, evalset, forward_batch_size):
|
85 |
-
if classifier_type == "gene":
|
86 |
-
label_name = "labels"
|
87 |
-
elif classifier_type == "cell":
|
88 |
-
label_name = "label"
|
89 |
-
|
90 |
-
predict_logits = []
|
91 |
-
predict_labels = []
|
92 |
-
model.eval()
|
93 |
-
|
94 |
-
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
95 |
-
evalset_len = len(evalset)
|
96 |
-
max_divisible = find_largest_div(evalset_len, forward_batch_size)
|
97 |
-
if len(evalset) - max_divisible == 1:
|
98 |
-
evalset_len = max_divisible
|
99 |
-
|
100 |
-
max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
|
101 |
-
|
102 |
-
disable_progress_bar() # disable progress bar for preprocess_classifier_batch mapping
|
103 |
-
for i in trange(0, evalset_len, forward_batch_size):
|
104 |
-
max_range = min(i + forward_batch_size, evalset_len)
|
105 |
-
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
106 |
-
padded_batch = preprocess_classifier_batch(
|
107 |
-
batch_evalset, max_evalset_len, label_name
|
108 |
-
)
|
109 |
-
padded_batch.set_format(type="torch")
|
110 |
-
|
111 |
-
input_data_batch = padded_batch["input_ids"]
|
112 |
-
attn_msk_batch = padded_batch["attention_mask"]
|
113 |
-
label_batch = padded_batch[label_name]
|
114 |
-
with torch.no_grad():
|
115 |
-
outputs = model(
|
116 |
-
input_ids=input_data_batch.to("cuda"),
|
117 |
-
attention_mask=attn_msk_batch.to("cuda"),
|
118 |
-
labels=label_batch.to("cuda"),
|
119 |
-
)
|
120 |
-
predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
|
121 |
-
predict_labels += [torch.squeeze(label_batch.to("cpu"))]
|
122 |
-
|
123 |
-
enable_progress_bar()
|
124 |
-
logits_by_cell = torch.cat(predict_logits)
|
125 |
-
last_dim = len(logits_by_cell.shape) - 1
|
126 |
-
all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
|
127 |
-
labels_by_cell = torch.cat(predict_labels)
|
128 |
-
all_labels = torch.flatten(labels_by_cell)
|
129 |
-
logit_label_paired = [
|
130 |
-
item
|
131 |
-
for item in list(zip(all_logits.tolist(), all_labels.tolist()))
|
132 |
-
if item[1] != -100
|
133 |
-
]
|
134 |
-
y_pred = [vote(item[0]) for item in logit_label_paired]
|
135 |
-
y_true = [item[1] for item in logit_label_paired]
|
136 |
-
logits_list = [item[0] for item in logit_label_paired]
|
137 |
-
return y_pred, y_true, logits_list
|
138 |
-
|
139 |
-
|
140 |
-
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
|
141 |
-
conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
|
142 |
-
macro_f1 = f1_score(y_true, y_pred, average="macro")
|
143 |
-
acc = accuracy_score(y_true, y_pred)
|
144 |
-
roc_metrics = None # roc metrics not reported for multiclass
|
145 |
-
if num_classes == 2:
|
146 |
-
y_score = [py_softmax(item)[1] for item in logits_list]
|
147 |
-
fpr, tpr, _ = roc_curve(y_true, y_score)
|
148 |
-
mean_fpr = np.linspace(0, 1, 100)
|
149 |
-
interp_tpr = np.interp(mean_fpr, fpr, tpr)
|
150 |
-
interp_tpr[0] = 0.0
|
151 |
-
tpr_wt = len(tpr)
|
152 |
-
roc_auc = auc(fpr, tpr)
|
153 |
-
roc_metrics = {
|
154 |
-
"fpr": fpr,
|
155 |
-
"tpr": tpr,
|
156 |
-
"interp_tpr": interp_tpr,
|
157 |
-
"auc": roc_auc,
|
158 |
-
"tpr_wt": tpr_wt,
|
159 |
-
}
|
160 |
-
return conf_mat, macro_f1, acc, roc_metrics
|
161 |
-
|
162 |
-
|
163 |
-
# get cross-validated mean and sd metrics
|
164 |
-
def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
|
165 |
-
wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
|
166 |
-
all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
|
167 |
-
mean_tpr = np.sum(all_weighted_tpr, axis=0)
|
168 |
-
mean_tpr[-1] = 1.0
|
169 |
-
all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
|
170 |
-
roc_auc = np.sum(all_weighted_roc_auc)
|
171 |
-
roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
|
172 |
-
return mean_tpr, roc_auc, roc_auc_sd
|
173 |
-
|
174 |
-
|
175 |
-
# plot ROC curve
|
176 |
-
def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
|
177 |
-
fig = plt.figure()
|
178 |
-
fig.set_size_inches(10, 8)
|
179 |
-
sns.set(font_scale=2)
|
180 |
-
sns.set_style("white")
|
181 |
-
lw = 3
|
182 |
-
for model_name in roc_metric_dict.keys():
|
183 |
-
mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
|
184 |
-
mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
|
185 |
-
roc_auc = roc_metric_dict[model_name]["roc_auc"]
|
186 |
-
roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
|
187 |
-
color = model_style_dict[model_name]["color"]
|
188 |
-
linestyle = model_style_dict[model_name]["linestyle"]
|
189 |
-
if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
|
190 |
-
label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
|
191 |
-
else:
|
192 |
-
label = f"{model_name} (AUC {roc_auc:0.2f})"
|
193 |
-
plt.plot(
|
194 |
-
mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
|
195 |
-
)
|
196 |
-
|
197 |
-
plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
|
198 |
-
plt.xlim([0.0, 1.0])
|
199 |
-
plt.ylim([0.0, 1.05])
|
200 |
-
plt.xlabel("False Positive Rate")
|
201 |
-
plt.ylabel("True Positive Rate")
|
202 |
-
plt.title(title)
|
203 |
-
plt.legend(loc="lower right")
|
204 |
-
|
205 |
-
output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
|
206 |
-
plt.savefig(output_file, bbox_inches="tight")
|
207 |
-
plt.show()
|
208 |
-
|
209 |
-
|
210 |
-
# plot confusion matrix
|
211 |
-
def plot_confusion_matrix(
|
212 |
-
conf_mat_df, title, output_dir, output_prefix, custom_class_order
|
213 |
-
):
|
214 |
-
fig = plt.figure()
|
215 |
-
fig.set_size_inches(10, 10)
|
216 |
-
sns.set(font_scale=1)
|
217 |
-
sns.set_style("whitegrid", {"axes.grid": False})
|
218 |
-
if custom_class_order is not None:
|
219 |
-
conf_mat_df = conf_mat_df.reindex(
|
220 |
-
index=custom_class_order, columns=custom_class_order
|
221 |
-
)
|
222 |
-
display_labels = generate_display_labels(conf_mat_df)
|
223 |
-
conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
|
224 |
-
display = ConfusionMatrixDisplay(
|
225 |
-
confusion_matrix=conf_mat, display_labels=display_labels
|
226 |
-
)
|
227 |
-
display.plot(cmap="Blues", values_format=".2g")
|
228 |
-
plt.title(title)
|
229 |
-
plt.show()
|
230 |
-
|
231 |
-
output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
|
232 |
-
display.figure_.savefig(output_file, bbox_inches="tight")
|
233 |
-
|
234 |
-
|
235 |
-
def generate_display_labels(conf_mat_df):
|
236 |
-
display_labels = []
|
237 |
-
i = 0
|
238 |
-
for label in conf_mat_df.index:
|
239 |
-
display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
|
240 |
-
i = i + 1
|
241 |
-
return display_labels
|
242 |
-
|
243 |
-
|
244 |
-
def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
|
245 |
-
sns.set(font_scale=2)
|
246 |
-
plt.figure(figsize=(10, 10), dpi=150)
|
247 |
-
label_colors, label_color_dict = make_colorbar(predictions_df, "true")
|
248 |
-
predictions_df = predictions_df.drop(columns=["true"])
|
249 |
-
predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
|
250 |
-
predict_label_list = [label for label in predictions_df.columns]
|
251 |
-
predict_colors = pd.DataFrame(
|
252 |
-
pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
|
253 |
-
)
|
254 |
-
|
255 |
-
default_kwargs_dict = {
|
256 |
-
"row_cluster": False,
|
257 |
-
"col_cluster": False,
|
258 |
-
"row_colors": label_colors,
|
259 |
-
"col_colors": predict_colors,
|
260 |
-
"linewidths": 0,
|
261 |
-
"xticklabels": False,
|
262 |
-
"yticklabels": False,
|
263 |
-
"center": 0,
|
264 |
-
"cmap": "vlag",
|
265 |
-
}
|
266 |
-
|
267 |
-
if kwargs_dict is not None:
|
268 |
-
default_kwargs_dict.update(kwargs_dict)
|
269 |
-
g = sns.clustermap(predictions_df, **default_kwargs_dict)
|
270 |
-
|
271 |
-
plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
|
272 |
-
|
273 |
-
for label_color in list(label_color_dict.keys()):
|
274 |
-
g.ax_col_dendrogram.bar(
|
275 |
-
0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
|
276 |
-
)
|
277 |
-
|
278 |
-
g.ax_col_dendrogram.legend(
|
279 |
-
title=f"{title}",
|
280 |
-
loc="lower center",
|
281 |
-
ncol=4,
|
282 |
-
bbox_to_anchor=(0.5, 1),
|
283 |
-
facecolor="white",
|
284 |
-
)
|
285 |
-
|
286 |
-
output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
|
287 |
-
plt.savefig(output_file, bbox_inches="tight")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:eac0fb0b3007267871b6305ac0003ceba19d4f28d85686cb9067ecf142787869
|
3 |
-
size 584125
|
|
|
|
|
|
|
|