hsarfraz commited on
Commit
176bacd
1 Parent(s): 546fdfb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +57 -0
README.md CHANGED
@@ -38,5 +38,62 @@ The config.json file has assocociated label2id updated to reflect all labels tha
38
 
39
  For inference use image size with width: 1920 px and height: 2560 px
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  [base]: https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip
 
38
 
39
  For inference use image size with width: 1920 px and height: 2560 px
40
 
41
+ ## Sample Code for Document Inference
42
+ ```python
43
+ # load dependencies
44
+ import torch
45
+ from transformers import DonutSwinModel, DonutSwinPreTrainedModel,DonutProcessor
46
+ from torch import nn
47
+ from PIL import Image
48
+
49
+ #
50
+ class DonutForImageClassification(DonutSwinPreTrainedModel):
51
+ def __init__(self, config):
52
+ super().__init__(config)
53
+ self.num_labels = config.num_labels
54
+ self.swin = DonutSwinModel(config)
55
+ self.dropout = nn.Dropout(0.5)
56
+ self.classifier = nn.Linear(self.swin.num_features, config.num_labels)
57
+
58
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
59
+ outputs = self.swin(pixel_values)
60
+ pooled_output = outputs[1]
61
+ pooled_output = self.dropout(pooled_output)
62
+ logits = self.classifier(pooled_output)
63
+ return logits
64
+
65
+ sModelName = 'hsarfraz/donut-irs-tax-docs-classifier'
66
+ processor = DonutProcessor.from_pretrained(sModelName)
67
+ model = DonutForImageClassification.from_pretrained(sModelName)
68
+
69
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
70
+ model.to(device)
71
+
72
+ model.eval()
73
+
74
+ # load test image
75
+ sTestImagePath ='replace this with document image path' # i.e.
76
+ # open image
77
+ img = Image.open(sTestImagePath)
78
+ # resize image to width 1920 and height 2560 - fine tuned model is trained with this width and height
79
+ img_new = img.resize((1920,2560),Image.Resampling.LANCZOS)
80
+
81
+ # perfoem inference
82
+ predicted_label = ''
83
+ with torch.no_grad():
84
+ pixel_values = processor(img_new.convert("RGB"), return_tensors="pt").pixel_values
85
+ print(pixel_values.shape)
86
+ pixel_values = pixel_values.to(device)
87
+ outputs = model(pixel_values)
88
+ logits, predicted = torch.max(outputs.data, 1)
89
+ pval = predicted.cpu().numpy()[0]
90
+ predicted_label = model.config.id2label[pval]
91
+
92
+ print('---------------------------------- ')
93
+ print('Document Image Classification: ',predicted_label)
94
+
95
+
96
+ ```
97
+
98
 
99
  [base]: https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip