jmbrito commited on
Commit
fc62d33
1 Parent(s): abfeced

add handler.py for HF Dedicated Inference

Browse files

This PR adds a [Inference Handler](https://huggingface.co/docs/inference-endpoints/guides/custom_handler) to this model.
This is required for using it in a HuggingFace Dedicated Endpoint, since this product does not has a Image-To-Text task available out-of-the-box.

The handler implements the inference as described in the (model docs)[https://huggingface.co/docs/transformers/model_doc/nougat] but with the StoppingCriteria implemented on the notebook example.

Files changed (2) hide show
  1. handler.py +102 -0
  2. requirements.txt +2 -0
handler.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
2
+ import torch.cuda
3
+ import io
4
+ import base64
5
+ from PIL import Image
6
+ from typing import Dict, Any
7
+ from collections import defaultdict
8
+
9
+ class RunningVarTorch:
10
+ def __init__(self, L=15, norm=False):
11
+ self.values = None
12
+ self.L = L
13
+ self.norm = norm
14
+
15
+ def push(self, x: torch.Tensor):
16
+ assert x.dim() == 1
17
+ if self.values is None:
18
+ self.values = x[:, None]
19
+ elif self.values.shape[1] < self.L:
20
+ self.values = torch.cat((self.values, x[:, None]), 1)
21
+ else:
22
+ self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
23
+
24
+ def variance(self):
25
+ if self.values is None:
26
+ return
27
+ if self.norm:
28
+ return torch.var(self.values, 1) / self.values.shape[1]
29
+ else:
30
+ return torch.var(self.values, 1)
31
+
32
+ class StoppingCriteriaScores(StoppingCriteria):
33
+ def __init__(self, threshold: float = 0.015, window_size: int = 200):
34
+ super().__init__()
35
+ self.threshold = threshold
36
+ self.vars = RunningVarTorch(norm=True)
37
+ self.varvars = RunningVarTorch(L=window_size)
38
+ self.stop_inds = defaultdict(int)
39
+ self.stopped = defaultdict(bool)
40
+ self.size = 0
41
+ self.window_size = window_size
42
+
43
+ @torch.no_grad()
44
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
45
+ last_scores = scores[-1]
46
+ self.vars.push(last_scores.max(1)[0].float().cpu())
47
+ self.varvars.push(self.vars.variance())
48
+ self.size += 1
49
+ if self.size < self.window_size:
50
+ return False
51
+
52
+ varvar = self.varvars.variance()
53
+ for b in range(len(last_scores)):
54
+ if varvar[b] < self.threshold:
55
+ if self.stop_inds[b] > 0 and not self.stopped[b]:
56
+ self.stopped[b] = self.stop_inds[b] >= self.size
57
+ else:
58
+ self.stop_inds[b] = int(
59
+ min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
60
+ )
61
+ else:
62
+ self.stop_inds[b] = 0
63
+ self.stopped[b] = False
64
+ return all(self.stopped.values()) and len(self.stopped) > 0
65
+
66
+ class EndpointHandler():
67
+ def __init__(self, path="facebook/nougat-base"):
68
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
69
+ self.processor = NougatProcessor.from_pretrained(path)
70
+ self.model = VisionEncoderDecoderModel.from_pretrained(path)
71
+ self.model = self.model.to(self.device)
72
+
73
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
74
+ """
75
+ Args:
76
+ data (Dict): The payload with the text prompt
77
+ and generation parameters.
78
+ """
79
+ # Get inputs
80
+ input = data.pop("inputs", None)
81
+ parameters = data.pop("parameters", None)
82
+ fix_markdown = data.pop("fix_markdown", None)
83
+ if input is None:
84
+ raise ValueError("Missing image.")
85
+ # autoregressively generate tokens, with custom stopping criteria (as defined by the Nougat authors)
86
+ binary_data = base64.b64decode(input)
87
+
88
+ image = Image.open(io.BytesIO(binary_data))
89
+ pixel_values = self.processor(images= image, return_tensors="pt").pixel_values
90
+ outputs = self.model.generate(
91
+ pixel_values=pixel_values.to(self.model.device),
92
+ min_length=1,
93
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
94
+ return_dict_in_generate=True,
95
+ output_scores=True,
96
+ stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
97
+ **parameters,
98
+ )
99
+ generated = self.processor.batch_decode(outputs[0], skip_special_tokens=True)[0]
100
+ prediction = self.processor.post_process_generation(generated, fix_markdown=fix_markdown)
101
+
102
+ return {"generated_text": prediction}
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ python-Levenshtein
2
+ nltk