CellrepDINO Model
This is a custom DINO model for extracting rich representations of cell microscopy images as compact embedding vectors. The forward method of the cellrepDINO model returns embeddings that can be used for downstream tasks such as perturbation prediction, mechanism of action (MoA) classification, and nucleus size and shape estimation. For these tasks, simply train a basic linear or logistic regression model on the embeddings.
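As a rough illustration of that linear-probe workflow, here is a minimal sketch that fits a multinomial logistic regression on fixed embedding vectors with plain NumPy. It rests on several assumptions. The synthetic clusters stand in for real embeddings, which you would obtain with something like `mean_embeddings.cpu().numpy()`. The helper names `fit_logistic_probe` and `predict_probe` are hypothetical, and the hyperparameters are illustrative rather than tuned. scikit-learn's `LogisticRegression` is a common off-the-shelf alternative.

```python
import numpy as np


def fit_logistic_probe(X, y, num_classes, lr=0.1, epochs=500, l2=1e-4):
    """Fit a softmax (multinomial logistic) classifier on fixed embeddings.

    Uses full-batch gradient descent on mean cross-entropy plus an
    (l2 / 2) * ||W||^2 penalty. Returns the standardization stats and weights.
    """
    n, d = X.shape
    # Standardize each feature so a single learning rate works across dimensions.
    mu = X.mean(axis=0)
    sigma = X.std(axis=0) + 1e-8
    Xs = (X - mu) / sigma

    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    Y = np.eye(num_classes)[y]  # one-hot labels, shape (n, num_classes)

    for _ in range(epochs):
        logits = Xs @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        err = probs - Y  # n times the gradient of mean cross-entropy w.r.t. logits
        W -= lr * (Xs.T @ err / n + l2 * W)
        b -= lr * err.mean(axis=0)
    return mu, sigma, W, b


def predict_probe(params, X):
    """Return the predicted class index for each row of X."""
    mu, sigma, W, b = params
    return np.argmax(((X - mu) / sigma) @ W + b, axis=1)


# Synthetic stand-in for embeddings: 3 well-separated clusters in 16 dimensions.
rng = np.random.default_rng(0)
centers = rng.normal(scale=5.0, size=(3, 16))
y = rng.integers(0, 3, size=300)
X = centers[y] + rng.normal(size=(300, 16))

params = fit_logistic_probe(X, y, num_classes=3)
accuracy = (predict_probe(params, X) == y).mean()
print(f"training accuracy: {accuracy:.2f}")
```

In practice, report accuracy on held-out images rather than the training set. For a continuous target such as nucleus size, a linear (for example ridge) regression would take the place of the softmax.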
Model Details
- Architecture: DINOv2
- Default model size: Giant (1.1B parameters)
- Patch size: 14
- Default image size: 1024
- Default resize size: 518
- Default center-crop size: 518
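With the default settings, the crop size is an exact multiple of the patch size (518 = 14 × 37), so each image becomes a 37 × 37 grid of patches. The short calculation below makes this explicit. It assumes a standard ViT token layout with a single class token and no extra register tokens, which this card does not confirm. `patch_grid` is a hypothetical helper.

```python
from typing import Tuple


def patch_grid(crop_size: int, patch_size: int) -> Tuple[int, int]:
    """Return (patches per side, total patch tokens) for a square crop."""
    # This sketch rejects crops that do not divide evenly into patches.
    if crop_size % patch_size != 0:
        raise ValueError(
            f"crop_size {crop_size} is not a multiple of patch_size {patch_size}"
        )
    per_side = crop_size // patch_size
    return per_side, per_side * per_side


per_side, num_patches = patch_grid(crop_size=518, patch_size=14)
num_tokens = num_patches + 1  # assumed: one class token prepended, no register tokens
print(per_side, num_patches, num_tokens)  # 37 1369 1370
```

If you pass a different `crop_size` to the processor, choosing a multiple of 14 keeps the patch grid exact.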
Setup
Create a new Python environment, activate it, and then install the dependencies:
pip install torch transformers Pillow numpy pandas torchvision omegaconf
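After installing, you can confirm that every package is importable from the active environment by looking each one up with `importlib.util.find_spec`. This is a minimal sketch. The module list mirrors the pip command above but uses import names (`PIL` for Pillow), and `check_environment` is a hypothetical helper name.

```python
import importlib.util

# Import names for the packages in the pip command (Pillow is imported as PIL).
REQUIRED_MODULES = ["torch", "transformers", "PIL", "numpy", "pandas", "torchvision", "omegaconf"]


def check_environment(modules=REQUIRED_MODULES):
    """Map each top-level module name to True if it can be found for import."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


status = check_environment()
for name, found in status.items():
    print(f"{name:12s} {'ok' if found else 'MISSING'}")
```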
Example Usage
The model returns several types of embeddings. We recommend the mean- or median-pooled embeddings over the patch tokens, or the class token embedding.
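To make the pooling options concrete, the following sketch reduces a token array in the way the recommended embeddings are described: the class token on its own, and the mean or median across patch tokens. It assumes the final-layer tokens have shape `(batch, 1 + num_patches, dim)` with the class token at index 0. How the model's remote code actually computes `mean_pooled`, `median_pooled`, and `cls_token` (for example, whether it normalizes first) is not documented here, so treat `pool_tokens` as a hypothetical illustration.

```python
import numpy as np


def pool_tokens(tokens: np.ndarray) -> dict:
    """Reduce a (batch, 1 + num_patches, dim) token array to per-image vectors.

    Assumes the class token is at index 0 and every other token is a patch token.
    """
    if tokens.ndim != 3 or tokens.shape[1] < 2:
        raise ValueError("expected shape (batch, 1 + num_patches, dim)")
    cls_token = tokens[:, 0, :]       # (batch, dim)
    patch_tokens = tokens[:, 1:, :]   # (batch, num_patches, dim)
    return {
        "cls_token": cls_token,
        "mean_pooled": patch_tokens.mean(axis=1),             # average over patches
        "median_pooled": np.median(patch_tokens, axis=1),     # per-dimension median over patches
    }


# Toy example: 2 images, 1 class token + 4 patch tokens, 3-dimensional embeddings.
rng = np.random.default_rng(0)
toy_tokens = rng.normal(size=(2, 5, 3))
pooled = pool_tokens(toy_tokens)
for name, vec in pooled.items():
    print(name, vec.shape)  # each is (2, 3)
```

The median is less sensitive than the mean to a minority of atypical patches, which is one reason to compare both on your downstream task.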
from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and processor
model = AutoModel.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)
# Move model to device
model = model.to(device)
model.eval()
# For multiple images:
image_paths = ["image1.png", "image2.png"]
images = [Image.open(path) for path in image_paths]
# Process batch of images
# To use different resize and center-crop sizes, set the resize_size and crop_size parameters below
batch_inputs = processor.preprocess(images=images, resize_size=518, crop_size=518, return_tensors="pt")
# Move image tensors to device
batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}
# Generate embeddings for batch
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
    mean_embeddings = batch_outputs['mean_pooled']
    median_embeddings = batch_outputs['median_pooled']
    cls_embeddings = batch_outputs['cls_token']
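Before training a downstream model, it is convenient to move the embeddings off the GPU and store them alongside their image paths. Here is a minimal sketch using NumPy's `.npz` format. The helper names `save_embeddings` and `load_embeddings` are hypothetical. It assumes each pooled output is a `(batch, dim)` tensor that you convert with `.cpu().numpy()`.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_embeddings(out_path, image_paths, embeddings):
    """Save a (num_images, dim) embedding array together with its source image paths."""
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if embeddings.ndim != 2 or embeddings.shape[0] != len(image_paths):
        raise ValueError("expected one embedding row per image path")
    # Store paths as a fixed-width string array so loading needs no pickle support.
    np.savez(out_path, embeddings=embeddings, image_paths=np.array(image_paths, dtype=str))


def load_embeddings(in_path):
    """Return (list of image paths, embedding array) from a file written by save_embeddings."""
    with np.load(in_path) as data:
        return data["image_paths"].tolist(), data["embeddings"]


# With real outputs you would pass, e.g., mean_embeddings.cpu().numpy() instead of np.ones.
with tempfile.TemporaryDirectory() as tmp_dir:
    out_file = Path(tmp_dir) / "embeddings.npz"
    save_embeddings(out_file, ["image1.png", "image2.png"], np.ones((2, 4)))
    loaded_paths, loaded_embeddings = load_embeddings(out_file)

print(loaded_paths, loaded_embeddings.shape)
```

Keeping the paths next to the vectors makes it straightforward to join the embeddings with per-image labels (for example, MoA annotations) before fitting a probe.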