---
datasets:
- Matthijs/snacks
---

# Vision Transformer fine-tuned on the `Matthijs/snacks` dataset

Vision Transformer (ViT) model pre-trained on ImageNet-21k and fine-tuned on the [**Matthijs/snacks**](https://huggingface.co/datasets/Matthijs/snacks) dataset for 5 epochs using various data augmentation transforms from `torchvision`.

The model achieves **94.97%** accuracy on the validation set and **94.43%** accuracy on the test set.

## Data augmentation pipeline

The code block below shows the transforms applied during pre-processing to augment the original dataset.
The augmented images were generated on-the-fly with the `set_transform` method.

```python
from transformers import ViTFeatureExtractor
from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    ToTensor
)

checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)

# train: random crop, horizontal flip, and sharpness jitter
train_aug_transforms = Compose([
    RandomResizedCrop(size=feature_extractor.size),
    RandomHorizontalFlip(p=0.5),
    RandomAdjustSharpness(sharpness_factor=5, p=0.5),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])

# validation/test: deterministic resize only
valid_aug_transforms = Compose([
    Resize(size=(feature_extractor.size, feature_extractor.size)),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])
```