---
base_model: nvidia/mit-b0
datasets:
- 0-ma/geometric-shapes
license: other
metrics:
- accuracy
pipeline_tag: image-classification
---
# Model Card for Mit-B0 Geometric Shapes Dataset
## Training Dataset
- **Repository:** https://huggingface.co/datasets/0-ma/geometric-shapes
## Base Model
- **Repository:** https://huggingface.co/nvidia/mit-b0
## Accuracy
- Accuracy on dataset 0-ma/geometric-shapes [test] : ???
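
For reference, accuracy here is taken to mean the fraction of test images whose predicted class matches the true class. A minimal sketch of that computation follows; the `accuracy` helper and the toy arrays are hypothetical illustrations, not results for this model.

```python
import numpy as np


def accuracy(predictions: np.ndarray, targets: np.ndarray) -> float:
    """Fraction of samples whose predicted class id equals the true class id."""
    return float(np.mean(predictions == targets))


# Hypothetical class ids for five test images (toy values, not results for this model).
toy_predictions = np.array([1, 4, 2, 3, 5])
toy_targets = np.array([1, 4, 2, 0, 5])
toy_accuracy = accuracy(toy_predictions, toy_targets)
```

In practice, `predictions` would come from running the model over the `test` split and `targets` from the dataset's label column.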
## Loading and using the model
```python
import numpy as np
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Class names, indexed by the class id that the model predicts.
labels = [
    "None",
    "Circle",
    "Triangle",
    "Square",
    "Pentagon",
    "Hexagon",
]

# Download two example images.
images = [
    Image.open(requests.get("https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_circle.jpg", stream=True).raw),
    Image.open(requests.get("https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_pentagone.jpg", stream=True).raw),
]

# Load the image processor and the fine-tuned classifier from the Hub.
feature_extractor = AutoImageProcessor.from_pretrained("0-ma/mit-b0-geometric-shapes")
model = AutoModelForImageClassification.from_pretrained("0-ma/mit-b0-geometric-shapes")

# Preprocess the batch, run a forward pass, and keep the highest-scoring class per image.
inputs = feature_extractor(images=images, return_tensors="pt")
logits = model(**inputs)["logits"].cpu().detach().numpy()
predictions = np.argmax(logits, axis=1)
predicted_labels = [labels[prediction] for prediction in predictions]
print(predicted_labels)
```
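
The snippet above reports only the top class for each image. If per-class confidence scores are also useful, the logits can be passed through a softmax. The sketch below shows one way to do this with NumPy; the `softmax` helper and the `example_logits` values are hypothetical stand-ins, and the array has the same shape as `logits` would for two images and six classes.

```python
import numpy as np

labels = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical logits for a batch of two images (illustrative values, not real model output).
example_logits = np.array([
    [0.1, 4.0, 0.3, 0.2, 0.5, 0.4],
    [0.2, 0.3, 0.1, 0.4, 3.5, 1.0],
])

probabilities = softmax(example_logits)
for row in probabilities:
    best = int(row.argmax())
    print(f"{labels[best]}: {row[best]:.2f}")
```

With real model output, replace `example_logits` with the `logits` array computed earlier.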
## License
The license for this model can be found [here](https://github.com/NVlabs/SegFormer/blob/master/LICENSE).