Update README.md
Browse files
README.md
CHANGED
@@ -38,5 +38,62 @@ The config.json file has assocociated label2id updated to reflect all labels tha
|
|
38 |
|
39 |
For inference use image size with width: 1920 px and height: 2560 px
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
[base]: https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip
|
|
|
38 |
|
39 |
For inference use image size with width: 1920 px and height: 2560 px
|
40 |
|
41 |
+
## Sample Code for Document Inference
|
42 |
+
```python
|
43 |
+
# load dependencies
|
44 |
+
import torch
|
45 |
+
from transformers import DonutSwinModel, DonutSwinPreTrainedModel,DonutProcessor
|
46 |
+
from torch import nn
|
47 |
+
from PIL import Image
|
48 |
+
|
49 |
+
#
|
50 |
+
class DonutForImageClassification(DonutSwinPreTrainedModel):
|
51 |
+
def __init__(self, config):
|
52 |
+
super().__init__(config)
|
53 |
+
self.num_labels = config.num_labels
|
54 |
+
self.swin = DonutSwinModel(config)
|
55 |
+
self.dropout = nn.Dropout(0.5)
|
56 |
+
self.classifier = nn.Linear(self.swin.num_features, config.num_labels)
|
57 |
+
|
58 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
59 |
+
outputs = self.swin(pixel_values)
|
60 |
+
pooled_output = outputs[1]
|
61 |
+
pooled_output = self.dropout(pooled_output)
|
62 |
+
logits = self.classifier(pooled_output)
|
63 |
+
return logits
|
64 |
+
|
65 |
+
sModelName = 'hsarfraz/donut-irs-tax-docs-classifier'
|
66 |
+
processor = DonutProcessor.from_pretrained(sModelName)
|
67 |
+
model = DonutForImageClassification.from_pretrained(sModelName)
|
68 |
+
|
69 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
70 |
+
model.to(device)
|
71 |
+
|
72 |
+
model.eval()
|
73 |
+
|
74 |
+
# load test image
|
75 |
+
sTestImagePath ='replace this with document image path' # i.e.
|
76 |
+
# open image
|
77 |
+
img = Image.open(sTestImagePath)
|
78 |
+
# resize image to width 1920 and height 2560 - fine tuned model is trained with this width and height
|
79 |
+
img_new = img.resize((1920,2560),Image.Resampling.LANCZOS)
|
80 |
+
|
81 |
+
# perfoem inference
|
82 |
+
predicted_label = ''
|
83 |
+
with torch.no_grad():
|
84 |
+
pixel_values = processor(img_new.convert("RGB"), return_tensors="pt").pixel_values
|
85 |
+
print(pixel_values.shape)
|
86 |
+
pixel_values = pixel_values.to(device)
|
87 |
+
outputs = model(pixel_values)
|
88 |
+
logits, predicted = torch.max(outputs.data, 1)
|
89 |
+
pval = predicted.cpu().numpy()[0]
|
90 |
+
predicted_label = model.config.id2label[pval]
|
91 |
+
|
92 |
+
print('---------------------------------- ')
|
93 |
+
print('Document Image Classification: ',predicted_label)
|
94 |
+
|
95 |
+
|
96 |
+
```
|
97 |
+
|
98 |
|
99 |
[base]: https://huggingface.co/naver-clova-ix/donut-base-finetuned-rvlcdip
|