add handler.py for HF Dedicated Inference
Browse filesThis PR adds a [Inference Handler](https://huggingface.co/docs/inference-endpoints/guides/custom_handler) to this model.
This is required for using it in a HuggingFace Dedicated Endpoint, since this product does not has a Image-To-Text task available out-of-the-box.
The handler implements the inference as described in the (model docs)[https://huggingface.co/docs/transformers/model_doc/nougat] but with the StoppingCriteria implemented on the notebook example.
- handler.py +102 -0
- requirements.txt +2 -0
handler.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
|
2 |
+
import torch.cuda
|
3 |
+
import io
|
4 |
+
import base64
|
5 |
+
from PIL import Image
|
6 |
+
from typing import Dict, Any
|
7 |
+
from collections import defaultdict
|
8 |
+
|
9 |
+
class RunningVarTorch:
|
10 |
+
def __init__(self, L=15, norm=False):
|
11 |
+
self.values = None
|
12 |
+
self.L = L
|
13 |
+
self.norm = norm
|
14 |
+
|
15 |
+
def push(self, x: torch.Tensor):
|
16 |
+
assert x.dim() == 1
|
17 |
+
if self.values is None:
|
18 |
+
self.values = x[:, None]
|
19 |
+
elif self.values.shape[1] < self.L:
|
20 |
+
self.values = torch.cat((self.values, x[:, None]), 1)
|
21 |
+
else:
|
22 |
+
self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
|
23 |
+
|
24 |
+
def variance(self):
|
25 |
+
if self.values is None:
|
26 |
+
return
|
27 |
+
if self.norm:
|
28 |
+
return torch.var(self.values, 1) / self.values.shape[1]
|
29 |
+
else:
|
30 |
+
return torch.var(self.values, 1)
|
31 |
+
|
32 |
+
class StoppingCriteriaScores(StoppingCriteria):
|
33 |
+
def __init__(self, threshold: float = 0.015, window_size: int = 200):
|
34 |
+
super().__init__()
|
35 |
+
self.threshold = threshold
|
36 |
+
self.vars = RunningVarTorch(norm=True)
|
37 |
+
self.varvars = RunningVarTorch(L=window_size)
|
38 |
+
self.stop_inds = defaultdict(int)
|
39 |
+
self.stopped = defaultdict(bool)
|
40 |
+
self.size = 0
|
41 |
+
self.window_size = window_size
|
42 |
+
|
43 |
+
@torch.no_grad()
|
44 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
45 |
+
last_scores = scores[-1]
|
46 |
+
self.vars.push(last_scores.max(1)[0].float().cpu())
|
47 |
+
self.varvars.push(self.vars.variance())
|
48 |
+
self.size += 1
|
49 |
+
if self.size < self.window_size:
|
50 |
+
return False
|
51 |
+
|
52 |
+
varvar = self.varvars.variance()
|
53 |
+
for b in range(len(last_scores)):
|
54 |
+
if varvar[b] < self.threshold:
|
55 |
+
if self.stop_inds[b] > 0 and not self.stopped[b]:
|
56 |
+
self.stopped[b] = self.stop_inds[b] >= self.size
|
57 |
+
else:
|
58 |
+
self.stop_inds[b] = int(
|
59 |
+
min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
self.stop_inds[b] = 0
|
63 |
+
self.stopped[b] = False
|
64 |
+
return all(self.stopped.values()) and len(self.stopped) > 0
|
65 |
+
|
66 |
+
class EndpointHandler():
|
67 |
+
def __init__(self, path="facebook/nougat-base"):
|
68 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
69 |
+
self.processor = NougatProcessor.from_pretrained(path)
|
70 |
+
self.model = VisionEncoderDecoderModel.from_pretrained(path)
|
71 |
+
self.model = self.model.to(self.device)
|
72 |
+
|
73 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
data (Dict): The payload with the text prompt
|
77 |
+
and generation parameters.
|
78 |
+
"""
|
79 |
+
# Get inputs
|
80 |
+
input = data.pop("inputs", None)
|
81 |
+
parameters = data.pop("parameters", None)
|
82 |
+
fix_markdown = data.pop("fix_markdown", None)
|
83 |
+
if input is None:
|
84 |
+
raise ValueError("Missing image.")
|
85 |
+
# autoregressively generate tokens, with custom stopping criteria (as defined by the Nougat authors)
|
86 |
+
binary_data = base64.b64decode(input)
|
87 |
+
|
88 |
+
image = Image.open(io.BytesIO(binary_data))
|
89 |
+
pixel_values = self.processor(images= image, return_tensors="pt").pixel_values
|
90 |
+
outputs = self.model.generate(
|
91 |
+
pixel_values=pixel_values.to(self.model.device),
|
92 |
+
min_length=1,
|
93 |
+
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
|
94 |
+
return_dict_in_generate=True,
|
95 |
+
output_scores=True,
|
96 |
+
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
|
97 |
+
**parameters,
|
98 |
+
)
|
99 |
+
generated = self.processor.batch_decode(outputs[0], skip_special_tokens=True)[0]
|
100 |
+
prediction = self.processor.post_process_generation(generated, fix_markdown=fix_markdown)
|
101 |
+
|
102 |
+
return {"generated_text": prediction}
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
python-Levenshtein
|
2 |
+
nltk
|