Combined Multimodal Model

This model classifies medical images and generates reports from them. It uses a custom architecture that combines a video model (a 3D ResNet) with a text generation model (BioBART).

Model Details

  • Architecture: Custom model combining a 3D ResNet (r3d_18) and BioBART.
  • Tasks:
    • Classification: Classifies medical images into one of four classes: acute, normal, chronic, or lacunar.
    • Report Generation: Generates medical reports based on the input images.
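How the two tasks might share one visual feature vector can be sketched as follows. This is a minimal, hypothetical sketch, not the repository's actual `CombinedModel` or `ImageToTextProjector`: the class name `SketchHeads`, the single linear projection, the separate linear classification head, the label order, and `d_model=768` are all assumptions made for illustration.

```python
import torch
from torch import nn

# Assumption: label order follows the list above; the real mapping may differ.
CLASS_NAMES = ["acute", "normal", "chronic", "lacunar"]


class SketchHeads(nn.Module):
    """Hypothetical heads on top of a 512-dim video feature vector."""

    def __init__(self, feature_dim=512, d_model=768, num_classes=4):
        super().__init__()
        # Classification head: features -> one logit per class
        self.classifier = nn.Linear(feature_dim, num_classes)
        # Projection into the text model's embedding width (assumed 768)
        self.projector = nn.Linear(feature_dim, d_model)

    def forward(self, features):
        logits = self.classifier(features)              # (batch, num_classes)
        embeds = self.projector(features).unsqueeze(1)  # (batch, 1, d_model)
        return logits, embeds


heads = SketchHeads().eval()
features = torch.randn(2, 512)  # stand-in for video backbone output
with torch.no_grad():
    logits, embeds = heads(features)

# Map each predicted class index to its (assumed) label name
predicted = [CLASS_NAMES[i] for i in logits.argmax(dim=1).tolist()]
print(logits.shape, embeds.shape, predicted)
```

In a design like this, the projected embedding could be passed to the text model as a one-token visual prefix; whether the actual model does so is not stated here.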

Usage

import torch
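from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Requires a local model.py that defines CombinedModel and ImageToTextProjector
from model import CombinedModel, ImageToTextProjector
from torchvision import models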

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("YOUR_HF_USERNAME/combined-multimodal-model")

# Initialize models
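# `pretrained=True` is deprecated in torchvision >= 0.13; use the `weights` argument instead
video_model = models.video.r3d_18(weights=models.video.R3D_18_Weights.DEFAULT)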
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")

projector = ImageToTextProjector(512, report_generator.config.d_model)

num_classes = 4
combined_model = CombinedModel(video_model, report_generator, num_classes, projector)

# Load state dict
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/YOUR_HF_USERNAME/combined-multimodal-model/resolve/main/pytorch_model.bin",
    map_location=torch.device('cpu')
)
combined_model.load_state_dict(state_dict)
combined_model.eval()

# Now you can use combined_model for inference