---
base_model: WinKawaks/vit-tiny-patch16-224
datasets:
- 0-ma/geometric-shapes
license: apache-2.0
metrics:
- accuracy
pipeline_tag: image-classification
---

# Model Card for ViT Geometric Shapes Dataset Tiny

## Training Dataset

- **Repository:** https://huggingface.co/datasets/0-ma/geometric-shapes

## Base Model

- **Repository:** https://huggingface.co/models/WinKawaks/vit-tiny-patch16-224

## Accuracy

- Accuracy on the 0-ma/geometric-shapes test split: 0.9138095238095238
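
Under the usual definition, the accuracy reported above is the fraction of test images whose predicted class matches the reference label. The card does not include its evaluation script, so the following is only a minimal sketch of that metric, assuming both predictions and references are integer class ids. The helper name `accuracy` is hypothetical and not part of this repository.

```python
import numpy as np


def accuracy(predicted_ids, true_ids):
    """Hypothetical helper: fraction of samples whose predicted id equals the true id."""
    predicted_ids = np.asarray(predicted_ids)
    true_ids = np.asarray(true_ids)
    if predicted_ids.shape != true_ids.shape:
        raise ValueError("predictions and references must have the same shape")
    # Element-wise comparison gives booleans; their mean is the share of correct predictions
    return float(np.mean(predicted_ids == true_ids))


# Toy class ids for illustration only (not real test-set data)
print(accuracy([1, 2, 3, 0], [1, 2, 4, 0]))  # 0.75
```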

## Loading and using the model

    import numpy as np
    import requests
    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    # Class names in the order of the model's output indices
    labels = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]

    # Download two example images from the geometric-shape-detector repository
    urls = [
        "https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_circle.jpg",
        "https://raw.githubusercontent.com/0-ma/geometric-shape-detector/main/input/exemple_pentagone.jpg",
    ]
    images = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in urls]

    # Load the image processor (resizing and normalization) and the fine-tuned model
    processor = AutoImageProcessor.from_pretrained("0-ma/vit-geometric-shapes-tiny")
    model = AutoModelForImageClassification.from_pretrained("0-ma/vit-geometric-shapes-tiny")
    model.eval()

    # Preprocess the images and run inference without tracking gradients
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits.cpu().numpy()

    # The predicted class is the index with the highest logit
    predictions = np.argmax(logits, axis=1)
    predicted_labels = [labels[prediction] for prediction in predictions]
    print(predicted_labels)
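
The loading example keeps only the arg-max class. If a per-image confidence is also useful, the raw logits can be turned into probabilities with a softmax. This is a minimal sketch, not part of the original project: the helper name `decode_logits` is hypothetical, it assumes the same label order as the list in the loading example, and softmax scores are not guaranteed to be calibrated probabilities.

```python
import numpy as np

# Assumed to match the label order used in the loading example
LABELS = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]


def decode_logits(logits, labels=LABELS):
    """Hypothetical helper: return (label, softmax probability) for each row of a (batch, num_labels) array."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row maximum before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    best = probs.argmax(axis=1)
    return [(labels[i], float(probs[row, i])) for row, i in enumerate(best)]


# Made-up logits for two images (not real model output)
example = np.array([[0.1, 4.0, 0.2, 0.3, 0.1, 0.0],
                    [0.0, 0.2, 0.1, 0.4, 3.5, 1.0]])
for label, p in decode_logits(example):
    print(f"{label}: {p:.3f}")
```

In the loading example, the `logits` array returned by the model can be passed directly to `decode_logits`.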

## Model generation
The model was created using the `train_shape_detector.py` script from the project https://github.com/0-ma/geometric-shape-detector. No external code sources were used.