# Source: garbage/inference.py
# Uploaded by nastasev in commit 98b671b ("Upload 6 files"); original size 632 bytes.
import sys

import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Run a single-image classification with a fine-tuned ViT checkpoint and
# print the predicted label.
#
# Usage: python inference.py [IMAGE_PATH]
# When IMAGE_PATH is omitted, the placeholder path below is used.

# Checkpoint name/directory passed to from_pretrained for both the model
# and its preprocessing configuration.
model_name = "saved_model"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor; consider switching when upgrading.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Switch to evaluation mode (affects layers such as dropout).
model.eval()

# Image to classify: first command-line argument if given, else the placeholder.
image_path = sys.argv[1] if len(sys.argv) > 1 else '/path/'
# Open inside a context manager so the file handle is released promptly;
# convert() returns a new, fully loaded image that remains valid after close.
with Image.open(image_path) as img:
    image = img.convert('RGB')

# Preprocess into PyTorch tensors expected by the model.
inputs = feature_extractor(images=image, return_tensors="pt")
# No gradients are needed for inference.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# Highest-scoring class index; .item() requires exactly one prediction,
# which holds here because a single image is passed.
predicted_class_idx = logits.argmax(-1).item()
# Mapping from class index to human-readable label stored in the model config.
classes = model.config.id2label
print(f"Predicted class: {classes[predicted_class_idx]}")