Model Overview: The model presented here builds on the BigBird architecture, following an approach similar to the one described in our paper "Leveraging Large Language Models for Metagenomic Analysis." It adapts BigBird to long gene-sequence inputs. Trained specifically on gene sequences, it is intended to extract useful information from metagenomic data and is evaluated on several tasks, including classification and sequence embedding.
Model Architecture:
- Base Model: BigBird transformer architecture
- Tokenizer: Custom K-mer Tokenizer with k-mer length of 6 and overlapping tokens
- Training: Trained on a diverse dataset of 497 genes from 2000 bacterial and archaeal genomes
- Embeddings: Generates sequence embeddings using mean pooling of hidden states
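To make "k-mer length of 6 and overlapping tokens" concrete, the sketch below splits a sequence into overlapping 6-mers (stride 1) and, for contrast, non-overlapping 6-mers (stride 6). It illustrates the splitting only; the helper names are hypothetical, and the real KmerTokenizer's vocabulary, special tokens, and ID mapping are not reproduced here.

```python
from typing import List


def overlapping_kmers(seq: str, k: int = 6) -> List[str]:
    """Hypothetical helper: k-mers with stride 1, giving len(seq) - k + 1 tokens."""
    seq = seq.upper()
    return [seq[i:i + k] for i in range(len(seq) - k + 1)]


def non_overlapping_kmers(seq: str, k: int = 6) -> List[str]:
    """Hypothetical helper: k-mers with stride k; a trailing fragment shorter than k is dropped."""
    seq = seq.upper()
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, k)]


example = "ATCGATGC"
print(overlapping_kmers(example))      # ['ATCGAT', 'TCGATG', 'CGATGC']
print(non_overlapping_kmers(example))  # ['ATCGAT']
```

Overlapping tokenization yields roughly one token per base, so a sequence of length L produces L - 5 six-mers. This is one reason a sparse-attention model with a long context, such as BigBird, suits gene-length inputs.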
Dataset: Scorpio Gene-Taxa Benchmark Dataset: https://zenodo.org/records/12964684
Steps to Use the Model:
Install KmerTokenizer:
pip install git+https://github.com/MsAlEhR/KmerTokenizer.git
Example Code:
    from KmerTokenizer import KmerTokenizer
    from transformers import AutoModel
    import torch

    # Example gene sequence
    seq = "ATTTTTTTTTTTCCCCCCCCCCCGGGGGGGGATCGATGC"

    # Initialize the tokenizer (6-mers, overlapping, padded/truncated to 4096 tokens)
    tokenizer = KmerTokenizer(kmerlen=6, overlapping=True, maxlen=4096)
    tokenized_output = tokenizer.kmer_tokenize(seq)
    pad_token_id = 2  # Pad token ID

    # Create attention mask (1 for real tokens, 0 for padding) and add a batch dimension
    attention_mask = torch.tensor(
        [1 if token != pad_token_id else 0 for token in tokenized_output], dtype=torch.long
    ).unsqueeze(0)

    # Convert tokenized output to a LongTensor with a batch dimension
    inputs = torch.tensor([tokenized_output], dtype=torch.long)

    # Load the pre-trained BigBird model
    model = AutoModel.from_pretrained("MsAlEhR/MetaBERTa-bigbird-gene", output_hidden_states=True)

    # Generate hidden states (inference only, so gradients are not tracked)
    with torch.no_grad():
        outputs = model(input_ids=inputs, attention_mask=attention_mask)

    # Get token embeddings from the last hidden state: (batch, seq_len, hidden_size)
    embeddings = outputs.hidden_states[-1]

    # Expand attention mask to (batch, seq_len, 1) so it broadcasts over the hidden dimension
    expanded_attention_mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)

    # Mean-pool over real tokens only, giving one embedding per sequence: (batch, hidden_size)
    mean_sequence_embeddings = torch.sum(expanded_attention_mask * embeddings, dim=1) / torch.sum(
        expanded_attention_mask, dim=1
    )
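The final steps of the example compute a masked mean: padded positions are zeroed out before summing, and the sum is divided by the number of real tokens rather than by the padded length. The sketch below reproduces that arithmetic for a single sequence on plain Python lists so the effect of the mask is visible. The function name is hypothetical and the numbers are toy values, not model outputs.

```python
from typing import List


def masked_mean_pool(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Hypothetical helper: average token vectors where mask == 1, skipping padding.

    Mirrors sum(mask * embeddings, dim=1) / sum(mask, dim=1) for one sequence.
    """
    if len(hidden) != len(mask):
        raise ValueError("hidden states and mask must have the same length")
    n_real = sum(mask)
    if n_real == 0:
        raise ValueError("mask contains no real tokens")
    dim = len(hidden[0])
    totals = [0.0] * dim
    for vec, m in zip(hidden, mask):
        if m:  # padded positions (m == 0) contribute nothing
            for j in range(dim):
                totals[j] += vec[j]
    return [t / n_real for t in totals]


# Toy hidden states: 3 real tokens followed by 2 padding positions.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0], [9.0, 9.0]]
mask = [1, 1, 1, 0, 0]
print(masked_mean_pool(hidden, mask))  # [3.0, 4.0]
```

An unmasked mean over all five positions would give [5.4, 6.0], because the padding vectors leak into the average. Masking keeps a sequence's embedding independent of how much padding the tokenizer added to reach `maxlen`.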
Citation: For a detailed overview of leveraging large language models for metagenomic analysis, refer to our papers:
Refahi, M.S., Sokhansanj, B.A., & Rosen, G.L. (2023). Leveraging Large Language Models for Metagenomic Analysis. IEEE SPMB.
Refahi, M., Sokhansanj, B.A., Mell, J.C., Brown, J., Yoo, H., Hearne, G., & Rosen, G. (2024). Scorpio: Enhancing Embeddings to Improve Downstream Analysis of DNA Sequences. bioRxiv (2024-07).