---
metrics:
  - matthews_correlation
  - f1
tags:
  - biology
  - medical
---

DNABERT-2 (modified)

Modified so that the use of Flash Attention can be configured.

The notes below come from the original repository and from jaandoui.

This version of DNABERT-2 has been changed so that it can also output the attention, for attention analysis.

To the author of DNABERT-2: feel free to use these modifications.

To get access to the attention, use --model_name_or_path jaandoui/DNABERT2-AttentionExtracted instead of the original repository.

Most of the modifications were made in Bert_Layer.py. It has been modified specifically for fine-tuning and has not been tried for pretraining. Each modification is marked with "JAANDOUI" just before or next to it, so to see all modifications, search for "JAANDOUI". "JAANDOUI TODO" means that if that part is going to be used, something might be missing.
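Searching for the markers can also be scripted. Below is a minimal sketch; find_markers is a hypothetical helper (not part of this repository), and the file name you pass to it is an assumption to adjust to the actual file in your checkout.

```python
from pathlib import Path


def find_markers(path, marker="JAANDOUI"):
    """Return (line_number, is_todo, text) for every line mentioning `marker`.

    Lines containing "<marker> TODO" are flagged, since those parts may be
    incomplete according to the note above. The match is case-sensitive.
    """
    hits = []
    text = Path(path).read_text(encoding="utf-8")
    for lineno, line in enumerate(text.splitlines(), start=1):
        if marker in line:
            is_todo = f"{marker} TODO" in line
            hits.append((lineno, is_todo, line.strip()))
    return hits
```

For example, find_markers("Bert_Layer.py") lists each marked line with its line number and whether it is a "JAANDOUI TODO". From a shell, grep -n JAANDOUI Bert_Layer.py gives a similar listing without the TODO flag.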

Next, in the Trainer (or in a CustomTrainer, if compute_loss is overridden), activate the extraction of attention inside compute_loss(..) by calling the model with output_attentions=True (return_dict=True is optional): outputs = model(**inputs, return_dict=True, output_attentions=True). The attention can then be read from outputs.attentions. Note that the output has a third dimension, usually of size 12, that refers to the layer, so outputs.attentions[-1] is the attention of the last layer. Read more about model outputs here: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/output#transformers.utils.ModelOutput
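The compute_loss override described above could look like the following minimal sketch. It assumes a model loaded from jaandoui/DNABERT2-AttentionExtracted (for example with a sequence-classification head) whose output exposes .loss and .attentions, and that every batch contains labels so the model returns a loss. The attribute last_attentions is a hypothetical name used here to keep the last layer's attention for later analysis; it is not part of the Trainer API.

```python
import torch
from transformers import Trainer


class CustomTrainer(Trainer):
    """Trainer that requests attention weights on every forward pass in compute_loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that
        # newer transformers versions pass to compute_loss.
        outputs = model(**inputs, return_dict=True, output_attentions=True)
        loss = outputs.loss  # assumes `inputs` includes labels

        attentions = getattr(outputs, "attentions", None)
        if attentions is not None:
            # attentions[-1] is the last layer (see the note above); detach it
            # so the stored copy does not keep the autograd graph alive.
            self.last_attentions = attentions[-1].detach()

        return (loss, outputs) if return_outputs else loss
```

This override skips the extras of the default compute_loss, such as label smoothing. During evaluation, Trainer's prediction step collects every non-loss output field as predictions, so passing ignore_keys=["attentions"] to trainer.evaluate(...) may be needed to keep the attentions out of the logits.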

I'm also not using Triton, so I cannot guarantee that this works with it.

I also read that there were some problems extracting attention when using Flash Attention (https://github.com/huggingface/transformers/issues/28903). I'm not sure whether that is relevant here, since the issue is about Mistral models.

I'm still exploring this attention, so please don't assume it works 100%. I'll update the repository once I'm sure.

The official DNABERT-2 reference: DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome.

README of the official DNABERT-2:

We sincerely appreciate the MosaicML team for the MosaicBERT implementation, which serves as the base of DNABERT-2 development.

DNABERT-2 is a transformer-based genome foundation model trained on multi-species genomes.

To load the model from Hugging Face:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
```

To calculate the embedding of a DNA sequence:

```python
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors="pt")["input_ids"]
hidden_states = model(inputs)[0]  # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape)  # expect to be 768

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape)  # expect to be 768
```
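The pooling above handles a single sequence. When several sequences are tokenized together with padding, the padded positions should be left out of the mean. Below is a minimal sketch; masked_mean_pool is a hypothetical helper, and it assumes the tokenizer returns an attention_mask with 1 for real tokens and 0 for padding, and that the model's forward accepts attention_mask.

```python
import torch


def masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings while ignoring padding positions.

    hidden_states:  [batch, seq_len, hidden]
    attention_mask: [batch, seq_len], 1 for real tokens and 0 for padding
    returns:        [batch, hidden]
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # [batch, seq_len, 1]
    summed = (hidden_states * mask).sum(dim=1)                   # padding contributes 0
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # avoid division by zero
    return summed / counts


# Usage sketch (requires the tokenizer and model loaded above):
# batch = tokenizer([dna, dna[:30]], return_tensors="pt", padding=True)
# hidden = model(batch["input_ids"], attention_mask=batch["attention_mask"])[0]
# embeddings = masked_mean_pool(hidden, batch["attention_mask"])  # [2, 768]
```

Dividing by the number of real tokens, rather than by seq_len, keeps the embedding of a short sequence independent of how much padding the batch added.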